参考资料:
[CLIP huggingface源码:CLIPModel]
这篇文章首先展示CLIP损失函数的两种底层实现代码,然后聊一聊自己的理解。
说实话念硕士的时候没有接触过CLIP这个东西,来实习之后发现这个多模态的模型使用非常广泛,设计理念也是看后惊为天人。加上最近有探究任务研究CLIP,BLIP这些,遂决心把这个模型弄懂。参考资料1已经把CLIP的设计思想,原理,甚至是底层实现给讲清楚了,但是当我读到训练的损失函数那一段的时候还是产生了很大的疑问:作者说有两种方式来计算损失函数,一种较为简单,一种较为复杂。较为复杂的损失函数实现如下:
def forward(self, batch):
# Getting Image and Text Features
image_features = self.image_encoder(batch["image"])
text_features = self.text_encoder(
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
)
# Getting Image and Text Embeddings (with same dimension)
image_embeddings = self.image_projection(image_features)
text_embeddings = self.text_projection(text_features)
# Calculating the Loss
logits = (text_embeddings @ image_embeddings.T) / self.temperature
images_similarity = image_embeddings @ image_embeddings.T
texts_similarity = text_embeddings @ text_embeddings.T
targets = F.softmax(
(images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
)
texts_loss = cross_entropy(logits, targets, reduction='none')
images_loss = cross_entropy(logits.T, targets.T, reduction='none')
loss = (images_loss + texts_loss) / 2.0 # shape: (batch_size)
return loss.mean()
其中Cross_entropy也是作者自己实现的,看上去就是logsoftmax加上NLLloss:
def cross_entropy(preds, targets, reduction='none'):
log_softmax = nn.LogSoftmax(dim=-1)
loss = (-targets * log_softmax(preds)).sum(1)
if reduction == "none":
return loss
elif reduction == "mean":
return loss.mean()
较为简单的损失函数的实现则是这样:nn.CrossEntropyLoss()(logits, torch.arange(batch_size))
作者在下面进行了分析,我看完分析之后觉得... ... 作者的语气好像是在说这种较为简单的损失函数是有误的,在数据集中有同一张图片的多个相似caption的时候会明显犯错。那么,较为复杂的损失函数就是正确的了。以上是Tutorial里作者的实现,较为权威的另一种实现是huggingface团队Transformer库里的源码。由于CLIP模型的高度可定制性,huggingface团队实现了一个基类,也就是CLIPModel部分。并在需要训练的时候把loss设置为forward函数的第一个返回值,我们来看一下他们的实现:
image_embeds = vision_outputs[1]
image_embeds = self.visual_projection(image_embeds)
text_embeds = text_outputs[1]
text_embeds = self.text_projection(text_embeds)
# normalized features
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_image = logits_per_text.t()
loss = None
if return_loss:
loss = clip_loss(logits_per_text)
其中,clip_loss的实现如下:
# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
caption_loss = contrastive_loss(similarity)
image_loss = contrastive_loss(similarity.t())
return (caption_loss + image_loss) / 2.0
一开始的归一化比较好理解,logit_scale是一个超参数也好理解。最难理解的就是logits_per_text和logits_per_image这两个互为转置的矩阵。写这篇文章的时候我只能说自己弄懂了7分,原论文中有这么一段话:While standard image models jointly train an image feature extractor and a linear classifier to predict some label, CLIP jointly trains an image encoder and a text encoder to predict the correct pairings of a batch of (image, text) training examples. 即CLIP是学习(image, text)图文对之间的正确匹配的。这个正确匹配有两个对称的方面:1)对于每一个caption,和它吻合的图片得到label 1,和它不吻合的图片得到label 0。(这个对应于caption_loss)2)对于每一个image,和它吻合的caption得到label 1,和它不吻合的caption得到label 0。(这个对应于image_loss)而将两个loss相加除以2,得到的损失函数就同时考虑了两个方面了。如果一个模型在这两个方面都做得好,那么大概率是能够成功学习到correct pairings of a batch of (image, text) 的。
所有评论(0)