【大作业-65】从零开始掌握CLIP模型:理论到实战完全指南
·
从零开始掌握CLIP模型:理论到实战完全指南
作者:肆十二
整理日期:2026年4月
📚 目录
一、理论篇:CLIP模型详解
1.1 什么是CLIP?
CLIP(Contrastive Language-Image Pre-training)是OpenAI在2021年发布的一个多模态模型,它的核心能力是理解图像和文本之间的关系。
🎯 CLIP能做什么?
┌─────────────────────────────────────────────────────────────┐
│ │
│ 🖼️ 图像输入 │
│ ┌─────────┐ │
│ │ 🐱 │ │
│ └────┬────┘ │
│ │ │
│ ▼ │
│ CLIP模型 │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ 理解内容 │ ──→ "一只可爱的橘猫" │
│ │ 匹配文本 │ ──→ "a photo of a cat" │
│ │ 零样本分类 │ ──→ 识别新类别,无需训练 │
│ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
💡 一句话理解CLIP
CLIP = 图像编码器 + 文本编码器 + 对比学习
让"猫的图片"和"a photo of a cat"在特征空间中靠得越近越好!
1.2 CLIP论文核心内容
论文信息
| 项目 | 内容 |
|---|---|
| 标题 | Learning Transferable Visual Models From Natural Language Supervision |
| 作者 | Alec Radford et al. (OpenAI) |
| 年份 | 2021 |
| 会议 | ICML |
核心创新点
🎯 创新一:用自然语言监督信号训练视觉模型
传统方法:
图像 → CNN → 分类(需要人工标注的类别)
CLIP的方法:
图像 + 描述文本 → 对比学习 → 理解任意图像-文本关系
优势:不再受限于固定的类别,可以识别互联网上的任意概念!
🎯 创新二:超大规模预训练
| 数据集 | 规模 |
|---|---|
| MS-COCO | 10万张 |
| Visual Genome | 10万张 |
| ImageNet | 120万张 |
| WIT (CLIP) | 4亿对 |
🎯 创新三:零样本迁移能力
训练时的类别: 测试时的类别:
┌──────────────────┐ ┌──────────────────┐
│ cat, dog, bird │ │ shark, whale │
│ car, bus, truck │ → │ piano, guitar │
│ apple, banana... │ │ 任意的文本描述! │
└──────────────────┘ └──────────────────┘
CLIP的工作原理
对比学习目标
┌─────────────────────────────────────┐
│ 最大化对角线相似度,最小化其他 │
│ │
│ 文本特征 │
│ T1 T2 T3 T4 │
│ ┌───┬───┬───┬───┐ │
│ I1│ ✅│ │ │ │ │
图像特征 │ ├───┼───┼───┼───┤ │
┌───┐ │ I2│ │ ✅│ │ │ │
I1 │ 🐱│ ───────→ │ ├───┼───┼───┼───┤ │
├───┤ │ I3│ │ │ ✅│ │ │
I2 │ 🐕│ ───────→ │ ├───┼───┼───┼───┤ │
├───┤ │ I4│ │ │ │ ✅│ │
I3 │ 🐦│ ───────→ │ └───┴───┴───┴───┘ │
└───┘ │
└─────────────────────────────────────┘
✅ I1-T1配对正确 → 高相似度
❌ I1-T2配对错误 → 低相似度
损失函数
CLIP使用对称对比损失:
# 图像→文本损失
loss_i2t = cross_entropy(logits_image_to_text, labels)
# 文本→图像损失
loss_t2i = cross_entropy(logits_text_to_image, labels)
# 总损失
loss = (loss_i2t + loss_t2i) / 2
1.3 CLIP的应用场景
| 应用 | 说明 | 示例 |
|---|---|---|
| 零样本分类 | 不需要训练数据就能分类 | 识别新物种 |
| 图像检索 | 用文字搜索图像 | “找一张猫的图片” |
| 图文匹配 | 判断图像和文本是否匹配 | 图像描述验证 |
| 图像生成引导 | 作为生成模型的判别器 | DALL-E的评价器 |
| 迁移学习 | 预训练模型微调 | 下游任务适配 |
二、实战篇:代码逐行解析
2.1 CLIP核心模块解析 (clip/clip.py)
📁 文件路径:
c:\quick\kaggle\CLIP-main\clip\clip.py
2.1.1 整体架构
clip/
├── __init__.py # 包入口
├── clip.py # ⭐ 主模块(本文重点)
├── model.py # 模型定义
└── simple_tokenizer.py # 分词器
2.1.2 可用模型列表
_MODELS = {
"RN50": "https://.../RN50.pt", # ResNet-50骨干
"RN101": "https://.../RN101.pt", # ResNet-101骨干
"ViT-B/32": "https://.../ViT-B-32.pt", # ViT-B/32骨干 ⭐推荐
"ViT-B/16": "https://.../ViT-B-16.pt", # ViT-B/16骨干
"ViT-L/14": "https://.../ViT-L-14.pt", # ViT-L/14大模型
"ViT-L/14@336px": "https://.../ViT-L-14-336px.pt", # 高分辨率版
}
2.1.3 图像预处理流程
def _transform(n_px):
"""
图像预处理管道
输入:任意尺寸的PIL图像
输出:标准化后的张量 [3, n_px, n_px]
"""
return Compose([
Resize(n_px, interpolation=BICUBIC), # ① 调整大小到n_px
CenterCrop(n_px), # ② 中心裁剪为正方形
_convert_image_to_rgb, # ③ 转RGB格式
ToTensor(), # ④ 转张量 [0,1]
Normalize( # ⑤ ImageNet标准化
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711]
),
])
图示:
原始图像 调整大小 中心裁剪
┌─────────────┐ ┌───────────┐ ┌───────┐
│ │ → │ │ → │ │
│ 任意尺寸 │ │ 短边=n_px│ │ n_px │
│ │ │ │ │ │
└─────────────┘ └───────────┘ └───────┘
↓
标准化 → [0.48, 0.45, ...]
2.1.4 模型加载函数
def load(name, device="cuda", jit=False, download_root=None):
"""
加载CLIP模型
参数:
name: 模型名称(ViT-B/32等)或本地文件路径
device: 设备(cuda/cpu)
jit: 是否使用JIT优化
download_root: 模型下载目录
返回:
(model, preprocess): 模型和预处理函数
"""
# 1️⃣ 获取模型路径(下载或本地)
if name in _MODELS:
model_path = _download(_MODELS[name], download_root)
else:
model_path = name # 本地路径
# 2️⃣ 加载模型文件
with open(model_path, 'rb') as f:
state_dict = torch.load(f, map_location="cpu")
# 3️⃣ 构建模型
model = build_model(state_dict).to(device)
# 4️⃣ 返回模型和预处理
return model, _transform(model.visual.input_resolution)
2.1.5 文本分词函数
def tokenize(texts, context_length=77, truncate=False):
"""
文本分词
参数:
texts: 单个字符串或字符串列表
context_length: 上下文长度(固定77)
truncate: 超长是否截断
返回:
torch.Tensor: 形状 [batch, 77] 的token ID
"""
# 输入处理
if isinstance(texts, str):
texts = [texts]
# 添加特殊token
sot_token = tokenizer.encoder["<|startoftext|>"] # 序列开始
eot_token = tokenizer.encoder["<|endoftext|>"] # 序列结束
# 编码:起始 + 文本 + 结束
all_tokens = [
[sot_token] + encode(text) + [eot_token]
for text in texts
]
# 填充/截断到固定长度77
result = torch.zeros(len(all_tokens), 77)
for i, tokens in enumerate(all_tokens):
if len(tokens) > 77:
if truncate:
tokens = tokens[:77]
tokens[-1] = eot_token # 确保以结束token结尾
else:
raise RuntimeError("文本过长")
result[i, :len(tokens)] = torch.tensor(tokens)
return result
分词示例:
输入: "a cat"
↓
编码: [4] + [3206, 994] + [49407]
↓
填充: [4, 3206, 994, 49407, 0, 0, ..., 0]
(长度为77)
2.2 数据集加载器 (flickr8k_dataset.py)
📁 文件路径:
c:\quick\kaggle\CLIP-main\flickr8k_dataset.py
2.2.1 Flickr8k数据集简介
Flickr8k是一个经典的图像描述数据集,包含:
- 8,000张训练图像
- 1,000张验证图像
- 1,000张测试图像
- 每张图像有5个不同的英文描述
2.2.2 数据格式
{
"image": "flickr8k-images/2513260012_03d33305cf.jpg",
"caption": [
"A black dog is running after a white dog in the snow .",
"Black dog chasing brown dog through snow",
"Two dogs chase each other across the snowy ground .",
"Two dogs play together in the snow .",
"Two dogs running through a low lying body of water ."
]
}
2.2.3 数据集类详解
class Flickr8kDataset(Dataset):
def __init__(self, json_path, image_dir, preprocess,
max_samples=None, use_all_captions=True):
"""
初始化数据集
参数:
json_path: JSON标注文件路径
image_dir: 图像目录
preprocess: CLIP的图像预处理函数
max_samples: 最大样本数(调试用)
use_all_captions: 是否使用所有5个caption
"""
# 1️⃣ 加载JSON
with open(json_path, 'r') as f:
self.data = json.load(f)
# 2️⃣ 路径处理(兼容不同格式)
for item in self.data:
image_rel_path = item['image']
if image_rel_path.startswith('flickr8k-images'):
# 提取文件名
image_filename = os.path.basename(image_rel_path)
image_path = image_filename
# 3️⃣ 构建样本列表
self.samples = []
for item in self.data:
if use_all_captions:
# 训练模式:每个caption都是独立样本
for caption in item['caption']:
self.samples.append({
'image_path': image_path,
'caption': caption
})
else:
# 验证模式:每张图只用第一个caption
self.samples.append({
'image_path': image_path,
'caption': item['caption'][0]
})
def __getitem__(self, idx):
"""获取单个样本"""
sample = self.samples[idx]
# 加载图像
image = Image.open(self.image_dir / sample['image_path'])
image_tensor = self.preprocess(image)
return image_tensor, sample['caption']
2.2.4 数据加载流程图
Dataset类
│
▼
┌──────────────────────────┐
│ __init__ │
│ - 加载JSON │
│ - 构建样本列表 │
│ - 路径处理 │
└──────────┬───────────────┘
│
▼
┌──────────────────────────┐
│ __getitem__ │
│ - 打开图像 │
│ - CLIP预处理 │
│ - 返回(图像, 文本) │
└──────────┬───────────────┘
│
▼
DataLoader
│
▼
┌──────────────────────────┐
│ 批处理 │
│ - 批量读取 │
│ - 封装为Batch │
│ - 送入模型训练 │
└──────────────────────────┘
2.3 模型训练流程 (clip_finetune.py)
📁 文件路径:
c:\quick\kaggle\CLIP-main\clip_finetune.py
2.3.1 训练流程总览
┌─────────────────────────────────────────────────────────────┐
│ 训练主流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ for epoch in range(epochs): │
│ │ │
│ ├── ① train_epoch() 训练一个epoch │
│ │ │ │
│ │ ├── 前向传播(图像+文本→特征) │
│ │ ├── 计算对比损失 │
│ │ ├── 反向传播 + 梯度更新 │
│ │ └── 记录训练指标 │
│ │ │
│ ├── ② validate() 验证(每N个epoch) │
│ │ │ │
│ │ └── 计算验证损失 │
│ │ │
│ └── ③ save_checkpoint() 保存模型(每N个epoch) │
│ │
└─────────────────────────────────────────────────────────────┘
2.3.2 模型加载
def load_model_and_data(self):
"""加载模型和数据"""
# 1️⃣ 加载CLIP模型
model, preprocess = clip.load(
self.args.model_name,
device=self.device,
jit=False
)
model.train() # 训练模式
# 2️⃣ 可选:冻结视觉编码器
if self.args.freeze_vision:
for param in model.visual.parameters():
param.requires_grad = False
# 3️⃣ 限制logit_scale范围(防止NaN)
self.logit_scale_max = 100.0
model.logit_scale.clamp_(max=self.logit_scale_max)
# 4️⃣ 创建数据集和DataLoader
train_dataset = Flickr8kDataset(...)
self.train_loader = DataLoader(train_dataset, batch_size=32, ...)
2.3.3 对比损失计算
def compute_loss(self, image_features, text_features,
logits_per_image, logits_per_text):
"""计算对称对比损失"""
batch_size = image_features.shape[0]
# 创建标签:[0, 1, 2, 3, ...] 对角线为正样本
labels = torch.arange(batch_size, device=self.device)
# logits裁剪(防止数值爆炸)
logits_per_image = torch.clamp(logits_per_image, min=-100, max=100)
logits_per_text = torch.clamp(logits_per_text, min=-100, max=100)
# 图像→文本损失
loss_i2t = cross_entropy(logits_per_image, labels)
# 文本→图像损失
loss_t2i = cross_entropy(logits_per_text, labels)
# 对称损失
loss = (loss_i2t + loss_t2i) / 2
return loss
2.3.4 训练循环
def train_epoch(self, epoch):
"""训练一个epoch"""
self.model.train()
total_loss = 0.0
for batch_idx, (images, captions) in enumerate(self.train_loader):
# 1️⃣ 数据准备
images = images.to(self.device)
text_tokens = clip.tokenize(captions).to(self.device)
# 2️⃣ 前向传播
image_features = self.model.encode_image(images)
text_features = self.model.encode_text(text_tokens)
# 3️⃣ 特征归一化
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# 4️⃣ 计算相似度矩阵
logit_scale = self.safe_logit_scale() # 安全获取
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# 5️⃣ 计算损失
loss = self.compute_loss(image_features, text_features,
logits_per_image, logits_per_text)
# 跳过NaN批次
if loss is None:
continue
# 6️⃣ 反向传播
self.optimizer.zero_grad()
loss.backward()
# 7️⃣ 梯度裁剪 + 更新
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
total_loss += loss.item()
return total_loss / len(self.train_loader)
2.3.5 训练参数配置
| 参数 | 默认值 | 说明 |
|---|---|---|
--model_name |
ViT-B/32 | CLIP模型选择 |
--learning_rate |
1e-3 | 学习率 |
--batch_size |
32 | 批次大小 |
--epochs |
12 | 训练轮数 |
--val_interval |
5 | 验证间隔 |
--save_interval |
5 | 保存间隔 |
--freeze_vision |
False | 是否冻结视觉编码器 |
2.3.6 运行训练
# 基本训练
python clip_finetune.py
# 自定义参数
python clip_finetune.py --epochs 20 --batch_size 64 --learning_rate 1e-4
# 冻结视觉编码器
python clip_finetune.py --freeze_vision
2.4 图形界面工具 (clip_gui_final.py)
📁 文件路径:
c:\quick\kaggle\CLIP-main\clip_gui_final.py
2.4.1 界面功能概览
┌─────────────────────────────────────────────────────────────┐
│ CLIP模型工具箱 │
├─────────────────────────────────────────────────────────────┤
│ [📦 模型加载] [🏷️ 零样本分类] [🔗 相似度计算] │
└─────────────────────────────────────────────────────────────┘
2.4.2 界面一:模型加载
┌─────────────────────────────────────────────────────────────┐
│ 🔧 CLIP模型加载 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 模型选择 │
│ ├─ 加载方式: [官方模型 ▼] │
│ ├─ 选择模型: [ViT-B/32 ▼] │
│ └─ 设备: [cuda ▼] │
│ │
│ [🚀 加载模型] │
│ │
│ 状态: ✅ 已加载 ViT-B/32 │
│ │
│ 📖 CLIP模型介绍 │
│ CLIP是由OpenAI开发的多模态模型... │
│ │
└─────────────────────────────────────────────────────────────┘
2.4.3 界面二:零样本分类
┌─────────────────────────────────────────────────────────────┐
│ 🎯 零样本分类 │
├───────────────────────────┬─────────────────────────────────┤
│ 图像输入 │ 类别文本输入 │
│ ┌─────────────────┐ │ │
│ │ │ │ 输入候选类别(每行一个) │
│ │ [图像预览] │ │ ┌─────────────────────────────┐│
│ │ │ │ │a photo of a cat ││
│ └─────────────────┘ │ │a photo of a dog ││
│ [📁 上传图像] │ │a photo of a bird ││
│ │ └─────────────────────────────┘│
├─────────────────────────┴─────────────────────────────────┤
│ [🔍 开始零样本分类] │
├─────────────────────────────────────────────────────────────┤
│ 分类结果 │
│ ┌────┬──────────────┬──────────┐ │
│ │排名│ 类别 │ 置信度 │ │
│ ├────┼──────────────┼──────────┤ │
│ │ 1 │ a photo of.. │ 85.23% │ │
│ │ 2 │ a photo of.. │ 12.45% │ │
│ └────┴──────────────┴──────────┘ │
└─────────────────────────────────────────────────────────────┘
2.4.4 界面三:相似度计算
┌─────────────────────────────────────────────────────────────┐
│ 🔗 相似度计算 │
├───────────────────────────┬─────────────────────────────────┤
│ 左侧内容 │ 右侧内容 │
│ ┌─────────────────┐ │ ┌─────────────────┐ │
│ │ 输入类型: │ │ │ 输入类型: │ │
│ │ [文本输入 ▼] │ │ │ [图像上传 ▼] │ │
│ ├─────────────────┤ │ ├─────────────────┤ │
│ │ │ │ │ │ │
│ │ A cute cat... │ │ │ │ │
│ │ │ │ │ [🐱图像] │ │
│ │ │ │ │ │ │
│ └─────────────────┘ │ └─────────────────┘ │
│ [📁 上传左侧图像] │ [📁 上传右侧图像] │
├─────────────────────────┴─────────────────────────────────┤
│ [📊 计算相似度] │
├─────────────────────────────────────────────────────────────┤
│ │
│ 相似度结果: 72.45 │
│ 🎉 高度相似 | 文本 ↔ 图像 │
│ │
└─────────────────────────────────────────────────────────────┘
2.4.5 核心代码实现
class CLIPModelLoader:
"""CLIP模型加载器"""
def __init__(self):
self.model = None
self.preprocess = None
self.is_loaded = False
def load_official_model(self, model_name, device):
"""加载官方模型"""
self.model, self.preprocess = clip.load(model_name, device=device)
self.is_loaded = True
def zero_shot_classify(self, image, class_names):
"""
零样本分类
参数:
image: PIL图像
class_names: 类别列表
返回:
[(类别, 置信度), ...] 按置信度降序
"""
# 编码图像
image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
with torch.no_grad():
image_features = self.model.encode_image(image_tensor)
# 编码文本
text_tokens = clip.tokenize(class_names).to(self.device)
with torch.no_grad():
text_features = self.model.encode_text(text_tokens)
# 计算相似度
similarity = 100.0 * image_features @ text_features.t()
# softmax得到概率
probs = similarity.softmax(dim=-1).cpu().numpy()[0]
# 排序返回
results = sorted(zip(class_names, probs), key=lambda x: x[1], reverse=True)
return results
def compute_similarity(self, features1, features2):
"""计算两个特征的相似度"""
return (100.0 * features1 @ features2.t()).item()
2.4.6 运行图形界面
# 1. 安装依赖
pip install PyQt5 torch torchvision pillow
# 2. 运行界面
python clip_gui_final.py
📚 总结
学习路径
┌─────────────────────────────────────────────────────────────┐
│ │
│ 理论篇 → 实战篇 │
│ ├── CLIP是什么 ├── clip.py 网络结构 │
│ ├── 论文核心 ├── flickr8k_dataset.py 数据加载 │
│ ├── 对比学习 ├── clip_finetune.py 模型训练 │
│ └── 零样本能力 └── clip_gui_final.py 图形界面 │
│ │
│ ↓ │
│ │
│ 掌握CLIP! │
│ │
└─────────────────────────────────────────────────────────────┘
下一步学习建议
- 深入模型细节:阅读
clip/model.py理解ViT和ResNet的实现 - 尝试自己的数据集:将Flickr8k换成自己的数据
- 模型部署:学习ONNX、TensorRT部署
- 下游任务:图像检索、图文匹配、目标检测等
如果觉得有用,欢迎转发分享!
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)