
Senior Machine Learning Engineer, Canva
Training OpenAI’s CLIP on Google Colab

This is a walkthrough of training CLIP by OpenAI. CLIP was designed to project both images and text into a shared embedding space, so that matching pairs can be found by simply looking at dot products.
Traditionally, training sets like ImageNet only let you map an image to a single class (and hence one word). This method lets you map text to images, and can also be used to map images to text if the need arises.
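To make the idea concrete, here is a toy sketch (not the model itself, just an illustration of the shared-space idea): once both encoders emit unit-length vectors, the dot product is the cosine similarity, and a single matrix multiply scores every caption against every image.
import torch
import torch.nn.functional as F

# Toy illustration: 4 images and 4 captions embedded in the same 512-dim space.
# With unit-length vectors, the dot product equals the cosine similarity.
image_embed = F.normalize(torch.randn(4, 512), dim=-1)
caption_embed = F.normalize(torch.randn(4, 512), dim=-1)

similarity = caption_embed @ image_embed.T  # shape (4, 4): caption i vs image j
best_image_per_caption = similarity.argmax(dim=1)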
This particular post, however, is specifically about how we managed to train this on Colab GPUs using Hugging Face Transformers and PyTorch Lightning.
A working version of this code can be found on Kaggle.
Acknowledgment
Kudos to the following CLIP tutorial in the Keras documentation.
The important thing to notice about the constants below is the embedding dimension: we will project the outputs of both the ResNet and the transformer into a 512-dimensional space.
EMBED_DIM = 512                 # dimension of the shared image/text embedding space
TRANSFORMER_EMBED_DIM = 768     # hidden size of the DistilBERT backbone
MAX_LEN = 32                    # maximum length of text
TEXT_MODEL = "distilbert-base-multilingual-cased"
EPOCHS = 5
BATCH_SIZE = 64
Data
We download the COCO dataset, which contains five captions per image and roughly 82k images. We take 20% of it as our validation set.
Since the image backbone is pretrained on ImageNet, we normalize the images using the ImageNet statistics in the Normalize transform. We also resize the images to 128×128 to make sure training finishes in a reasonable time.
Warning: Downloading the files will take a while (~5-10 minutes).
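The download and dataset-construction code isn’t shown in this excerpt. As a rough sketch of the setup described above (the paths, and the choice to keep a single caption per image, are assumptions on my part):
import torchvision.transforms as T
from torchvision.datasets import CocoCaptions

# ImageNet statistics, since the vision backbone is pretrained on ImageNet.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

tfm = T.Compose([
    T.Resize((128, 128)),
    T.ToTensor(),
    T.Normalize(imagenet_mean, imagenet_std),
])

# Paths are placeholders; point them at the downloaded COCO images/annotations.
cap = CocoCaptions(
    root="train2014",
    annFile="annotations/captions_train2014.json",
    transform=tfm,
    target_transform=lambda texts: texts[0],  # keep one caption per image (an assumption)
)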
from typing import Dict, List, Tuple

import torch
from transformers import AutoTokenizer, BatchEncoding, BertTokenizer

class Tokenizer:
    def __init__(self, tokenizer: BertTokenizer) -> None:
        self.tokenizer = tokenizer

    def __call__(self, x: str) -> BatchEncoding:
        return self.tokenizer(
            x, max_length=MAX_LEN, truncation=True, padding=True, return_tensors="pt"
        )

    def decode(self, x: Dict[str, torch.LongTensor]) -> List[str]:
        return [self.tokenizer.decode(sentence[:sentence_len]) for sentence, sentence_len in
                zip(x["input_ids"], x["attention_mask"].sum(axis=-1))]

tokenizer = Tokenizer(AutoTokenizer.from_pretrained(TEXT_MODEL))
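The display snippet below uses inv_tfm, which isn’t defined in this excerpt. A plausible sketch is simply the inverse of the ImageNet normalization (my assumption, not necessarily the original helper):
import torchvision.transforms as T

# Undo the ImageNet normalization so images display with natural colours:
# first divide out the std, then add back the mean.
inv_tfm = T.Compose([
    T.Normalize(mean=(0.0, 0.0, 0.0), std=(1 / 0.229, 1 / 0.224, 1 / 0.225)),
    T.Normalize(mean=(-0.485, -0.456, -0.406), std=(1.0, 1.0, 1.0)),
])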
# `img` is a dataset image and `target` its tokenized caption; undo the
# normalization and display the image with the decoded caption as the title.
img = inv_tfm(img)
plt.imshow(np.rot90(img.transpose(0, 2), 3))
plt.title(tokenizer.decode(target)[0])
plt.show()

from torch.utils.data import DataLoader, random_split

# 80/20 train/validation split, then standard DataLoaders.
train_len = int(0.8 * len(cap))
train_data, valid_data = random_split(cap, [train_len, len(cap) - train_len])
train_dl = DataLoader(train_data, BATCH_SIZE, pin_memory=True, shuffle=True, num_workers=4, drop_last=True)
valid_dl = DataLoader(valid_data, BATCH_SIZE, pin_memory=True, shuffle=False, num_workers=4, drop_last=False)
Model
There are two main models, the VisionEncoder and the TextEncoder, which use ResNet34 and DistilBERT as backbones. To make the model multilingual, we simply choose the distilbert-base-multilingual-cased checkpoint and that’s it! There is no need to specifically train on non-English captions, as you will soon see.
The Projection module takes the embeddings from the vision and text encoders and projects them into the 512-dimensional space.
Two things to note: (1) both the text and vision encoder backbones are frozen and their weights are never retrained; (2) for both encoders, the final output is normalized to unit length.
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from transformers import AutoModel


class Projection(nn.Module):
    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project, then refine with a small residual MLP block.
        embed1 = self.linear1(x)
        embed2 = self.drop(self.linear2(F.gelu(embed1)))
        embeds = self.layer_norm(embed1 + embed2)
        return embeds


class VisionEncoder(nn.Module):
    def __init__(self, d_out: int) -> None:
        super().__init__()
        base = models.resnet34(pretrained=True)
        d_in = base.fc.in_features
        base.fc = nn.Identity()  # drop the classification head, keep the features
        self.base = base
        self.projection = Projection(d_in, d_out)
        # Freeze the backbone; only the projection head is trained.
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x):
        projected_vec = self.projection(self.base(x))
        projection_len = torch.norm(projected_vec, dim=-1, keepdim=True)
        return projected_vec / projection_len  # unit-length embeddings


class TextEncoder(nn.Module):
    def __init__(self, d_out: int) -> None:
        super().__init__()
        self.base = AutoModel.from_pretrained(TEXT_MODEL)
        self.projection = Projection(TRANSFORMER_EMBED_DIM, d_out)
        # Freeze the backbone; only the projection head is trained.
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x):
        out = self.base(**x)[0]
        out = out[:, 0, :]  # take the [CLS] token output
        projected_vec = self.projection(out)
        projection_len = torch.norm(projected_vec, dim=-1, keepdim=True)
        return projected_vec / projection_len  # unit-length embeddings
CLIP loss function
For someone like me who hasn’t played around with contrastive loss, this was the most interesting part.
We want the vectors of a matching image and caption to line up, meaning their dot product should be as close to 1 as possible; for every non-matching pair, we want to push the dot product towards 0.
Therefore, for a given caption, we take the softmax of its dot products across all images in the batch and apply a cross-entropy loss. Similarly, for a given image, we repeat the process across all captions. We then average these two losses.
As for which element is the true positive within a batch: remember that we feed in image-caption pairs that are already aligned, so the diagonal elements of the similarity matrix are the matches, while all off-diagonal elements should be pushed towards zero.
def contrastive_loss(logits, dim):
    neg_ce = torch.diag(F.log_softmax(logits, dim=dim))
    return -neg_ce.mean()


def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity, dim=0)
    image_loss = contrastive_loss(similarity, dim=1)
    return (caption_loss + image_loss) / 2.0


def metrics(similarity: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    y = torch.arange(len(similarity)).to(similarity.device)
    img2cap_match_idx = similarity.argmax(dim=1)
    cap2img_match_idx = similarity.argmax(dim=0)
    img_acc = (img2cap_match_idx == y).float().mean()
    cap_acc = (cap2img_match_idx == y).float().mean()
    return img_acc, cap_acc
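As a quick sanity check (my own sketch, not part of the original notebook), clip_loss above is equivalent to a symmetric cross-entropy loss where the diagonal entries are the target classes:
import torch
import torch.nn.functional as F

torch.manual_seed(0)
sim = torch.randn(8, 8)    # toy similarity matrix for a batch of 8 aligned pairs
targets = torch.arange(8)  # the diagonal is the true match

expected = (F.cross_entropy(sim, targets) + F.cross_entropy(sim.T, targets)) / 2
assert torch.isclose(clip_loss(sim), expected)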
Model
If you haven’t used PyTorch Lightning before, the benefit is that you don’t need to stress about which device to put tensors on, remembering to zero the optimizer, and so on; all of that is taken care of. Simply specify the training and validation steps, along with the optimizer, and you are good to go.
The other benefit I really like is logging. You just write self.log("name", metric_to_track) and it will log to TensorBoard by default, or to any other kind of logger for that matter.
import pytorch_lightning as pl


class Model(pl.LightningModule):
    def __init__(self, lr: float = 1e-3) -> None:
        super().__init__()
        self.vision_encoder = VisionEncoder(EMBED_DIM)
        self.caption_encoder = TextEncoder(EMBED_DIM)
        self.tokenizer = Tokenizer(AutoTokenizer.from_pretrained(TEXT_MODEL))
        self.lr = lr

    def common_step(self, batch: Tuple[torch.Tensor, List[str]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        images, text = batch
        # Tokenize on the fly and move the token tensors to the model's device.
        text_dev = {k: v.to(self.device) for k, v in self.tokenizer(text).items()}

        image_embed = self.vision_encoder(images)
        caption_embed = self.caption_encoder(text_dev)
        similarity = caption_embed @ image_embed.T

        loss = clip_loss(similarity)
        img_acc, cap_acc = metrics(similarity)
        return loss, img_acc, cap_acc

    def training_step(
        self, batch: Tuple[torch.Tensor, List[str]], *args: list
    ) -> torch.Tensor:
        loss, img_acc, cap_acc = self.common_step(batch)
        self.log("training_loss", loss, on_step=True)
        self.log("training_img_acc", img_acc, on_step=True, prog_bar=True)
        self.log("training_cap_acc", cap_acc, on_step=True, prog_bar=True)
        return loss

    def validation_step(
        self, batch: Tuple[torch.Tensor, List[str]], *args: list
    ) -> torch.Tensor:
        loss, img_acc, cap_acc = self.common_step(batch)
        self.log("validation_loss", loss, on_step=True)
        self.log("validation_img_acc", img_acc, on_step=True, prog_bar=True)
        self.log("validation_cap_acc", cap_acc, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        # Only the projection heads are trainable; the backbones are frozen.
        vision_params = {"params": self.vision_encoder.projection.parameters(), "lr": self.lr}
        caption_params = {"params": self.caption_encoder.projection.parameters(), "lr": self.lr}
        return torch.optim.Adam([vision_params, caption_params])
Train
Training is straightforward, as shown in the few lines below. Using 16-bit precision almost halved the training time, from 16 minutes to 9 minutes per epoch. Notice how easy it is to add half-precision training and gradient clipping.
One thing to note is that I could not get this working on TPUs, so if anyone knows what I need to adjust, please let me know. Setting tpu_cores=8 simply did not work.
model = Model(1e-3)
trainer = pl.Trainer(
    max_epochs=1,
    gpus=torch.cuda.device_count(),
    # tpu_cores=1,
    gradient_clip_val=1.0,
    precision=16,
)
trainer.fit(model, train_dl, valid_dl)
Run the following cell if you wish to see the logs in TensorBoard.
# Load the TensorBoard notebook extension
%reload_ext tensorboard
%tensorboard --logdir ./lightning_logs/
Results
I will compare the text embeddings of the first validation batch against all the images of the validation set by taking the dot product between them.
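The cell that builds these embeddings isn’t included in this excerpt; the sketch below is my assumption of how they could be gathered (including the names device, caption_encoder and image_embed used in later cells), assuming each validation batch yields (images, captions) with one caption string per image:
# Embed every validation image, plus the captions of the first validation batch.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()
vision_encoder = model.vision_encoder
caption_encoder = model.caption_encoder
tokenizer = model.tokenizer

with torch.no_grad():
    image_embed = torch.cat(
        [vision_encoder(images.to(device)) for images, _ in valid_dl]
    )
    _, text = next(iter(valid_dl))
    text_dev = {k: v.to(device) for k, v in tokenizer(text).items()}
    caption_embed = caption_encoder(text_dev)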
similarity = caption_embed @ image_embed.T
val, closest = similarity.topk(5, dim=-1)
similarity.shape
torch.Size([64, 16557])
draw_result(i, similarity_matrix)
is a convenience function that takes the i-th caption and the similarity matrix, and plots the five closest images alongside the true image, with the caption-image similarity shown in each title. The caption itself is printed first.
A histogram shows how similar that caption is to every image in the validation set.
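The draw_result helper itself isn’t included in this excerpt. Here is a rough sketch of what such a helper could look like (the histogram is omitted, and I assume valid_data[i] returns the (image, caption) pair matching row i of the similarity matrix):
import matplotlib.pyplot as plt
import numpy as np

def draw_result(i: int, similarity: torch.Tensor, k: int = 5) -> None:
    # Scores and indices of the k most similar validation images for caption i.
    scores, closest = similarity[i].topk(k)

    true_img, caption = valid_data[i]
    print(caption)

    # True image first, then the k closest matches with their similarity scores.
    fig, axes = plt.subplots(1, k + 1, figsize=(3 * (k + 1), 3))
    axes[0].imshow(np.rot90(inv_tfm(true_img).transpose(0, 2), 3))
    axes[0].set_title("true image")
    for ax, idx, score in zip(axes[1:], closest, scores):
        img, _ = valid_data[idx]
        ax.imshow(np.rot90(inv_tfm(img).transpose(0, 2), 3))
        ax.set_title(f"{score:.2f}")
    for ax in axes:
        ax.axis("off")
    plt.show()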
draw_result(2, similarity)
A baseball player in the outfield with his hands up, standing next to a team mascot.


draw_result(1, similarity)
A watch and clock repair shop window with clocks on display.


draw_result(10, similarity)
A person on a skateboard on the ground.


Below is the Google-translated version of one of the captions.
English caption: “A zebra standing up with its head down and eating grass on the dirt ground.”, translated into Spanish:
text = "Una cebra de pie con la cabeza gacha y comiendo hierba en el suelo de tierra."
text_dev = {k: v.to(device) for k, v in tokenizer(text).items()}
with torch.no_grad():
caption_embed_text = caption_encoder(text_dev)
similarity_text = caption_embed_text @ image_embed.T
draw_result_single_query(10, similarity_text)

Skateboarder conducting a trick with bicycles in the background.

Again a translated version, this time in French. English caption: “A laptop is displayed on a small wooden platform.”
text = "Un ordinateur portable est affiché sur une petite plate-forme en bois."
text_dev = {k: v.to(device) for k, v in tokenizer(text).items()}
with torch.no_grad():
caption_embed_text = caption_encoder(text_dev)
similarity_text = caption_embed_text @ image_embed.T
draw_result_single_query(3, similarity_text)

Laptop computer on a small table on the side of a bed

The Russian translation below is doing terribly though, so it’s clearly not bulletproof. Or perhaps I just need to train for a bit longer. English caption: “A shop filled with different kinds of clocks.”
text = "Магазин с разными часами"
text_dev = {k: v.to(device) for k, v in tokenizer(text).items()}
with torch.no_grad():
caption_embed_text = caption_encoder(text_dev)
similarity_text = caption_embed_text @ image_embed.T
draw_result_single_query(1, similarity_text)

A room filled with clocks through a window.

And lastly, I check a single-word query. Notice how the dog does kind of look like a bear. Maybe its name is Bear?
text = "bear"
text_dev = {k: v.to(device) for k, v in tokenizer(text).items()}
with torch.no_grad():
caption_embed_text = caption_encoder(text_dev)
similarity_text = caption_embed_text @ image_embed.T
draw_result_single_query(1, similarity_text)

Large collection of digital and analog clocks on display.

Would love to hear any thoughts and comments on the above.
Shameless Self Promotion
See here for my course on Machine Learning and Deep Learning (Use code DEEPSCHOOL-MARCH for 85% off).
Tags: Huggingface, Machine learning, MLops, Multilingual CLIP, PyTorch Lightning