This is the second part of a series on enhancing Zero-Shot CLIP performance. In the first part, I provided a detailed explanation of how the CLIP model operates and described a straightforward method to improve its performance: extending standard prompts like "A picture of {class}" with customized prompts generated by a large language model (LLM). If you haven’t already, you can find part 1 here. In this article we will present a closely related method that improves zero-shot CLIP performance and, in addition, is highly explainable.
Introduction
The CLIP model is an impressive zero-shot predictor, enabling predictions on tasks it hasn’t explicitly been trained for. Despite its inherent capabilities, there are several strategies that notably improve its performance. In the first article we saw one of these strategies; however, while better performance is valuable, there are instances where we might be willing to make trade-offs to prioritize explainability. In this second article of our series, we will explore a method that not only enhances the performance of the zero-shot CLIP model but also ensures that its predictions are easily understandable and interpretable.
Explainability in Deep Neural Networks
Various explainability techniques are available for Deep Learning models today. In a previous article, I delved into Integrated Gradients, a method that attributes the output of a machine learning model, especially a deep neural network, to the individual features of its input. Another popular approach relies on SHAP values, where the contribution of each feature to the model’s output is assigned based on concepts from cooperative game theory. While these methods are versatile and can be applied to any deep learning model, they can be somewhat challenging to implement and interpret. CLIP, which has been trained to map image and text features into the same embedding space, offers an alternative, text-based explainability method. This approach is more user-friendly and easy to interpret, providing a different perspective on model explanation.
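For context, below is a minimal sketch of what pixel-level attribution with Integrated Gradients looks like, using Captum and a torchvision ResNet purely as illustrative stand-ins (not the CLIP setup of this article). The output is a tensor of per-pixel attributions that still needs to be visualized and interpreted, which is exactly the extra effort the text-based approach avoids.
# a hedged example of gradient-based attribution with Captum's IntegratedGradients
# the torchvision model and the target index are placeholders for illustration
import torch
from captum.attr import IntegratedGradients
from torchvision.models import resnet18, ResNet18_Weights
model_ig = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
ig = IntegratedGradients(model_ig)
# a random image stands in for a real input; the attributions have the same shape
# as the input and indicate how much each pixel pushed the score of the target class
x = torch.rand(1, 3, 224, 224)
attributions, delta = ig.attribute(x, target=0, return_convergence_delta=True)
print(attributions.shape)  # torch.Size([1, 3, 224, 224])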
Quick refresher on the problem
As a quick refresher from the first part of this series, the problem we are tackling is to predict the class of the image displayed below:

A standard approach using a simple prompt such as "a photo of a {class}" gives the wrong answer, predicting "tailed frog" with a probability score of 0.68:
from transformers import CLIPProcessor, CLIPModel
import torch
import requests
from PIL import Image
# load the pre-trained CLIP model and its processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# download the test image
url = "https://images.freeimages.com/images/large-previews/342/green-tree-frog2-1616738.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# score the image against one prompt per candidate class
inputs = processor(text=["a photo of a tree frog", "a photo of a tailed frog"], images=image, return_tensors="pt", padding=True)
outputs = model(**inputs)
logits_per_image = outputs.logits_per_image # this is the image-text similarity score
probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
print(probs)
"""
Output:
tensor([[0.3164, 0.6836]], grad_fn=<SoftmaxBackward0>)
"""
Let’s now see how we can improve it.
Visual classification via Description from LLMs
To improve the predictive accuracy of zero-shot CLIP, we are going to implement an idea similar to the one discussed in the first article. This time, however, rather than providing a generic prompt for a class like "tree frog", such as "The identifying characteristics of a tree frog vary depending on the species, but some common features include large adhesive toes, protruding eyes, and bright colors", we will break the description down into specific descriptive features. For example, for the "tree frog" and "tailed frog" classes, the descriptive features are:
Tree frog:
- "Protruding eyes"
- "Large mouth"
- "Without a tail"
- "Bright green color"
Tailed frog:
- "Tiny eyes"
- "Small mouth"
- "Dark color"
- "Has long tail"
These features can again be generated using an LLM with a prompt like: Q: What are useful features for distinguishing a {class} in a photo? A: There are several useful visual features to tell there is a {class} in a photo: –
The trailing "-" is important, as it nudges the model to continue with a bulleted list of features.
Next, similarly to what we did in the first article, to classify the image of the frog we take the average embedding of these textual feature descriptions for each class in the multi-modal space and check which average vector is closest to the embedding of the test image we want to classify. In code, we have:
# define feature descriptions for each class
features = {
    "tree frog": [
        "protruding eyes", "large mouth", "without a tail", "bright green colour"
    ],
    "tailed frog": [
        "tiny eyes", "small mouth", "has long tail", "dark colour"
    ]
}
# image embedding: run the vision encoder and project into the shared multi-modal space
image_features = model.visual_projection(model.vision_model(inputs['pixel_values']).pooler_output)
# embed each "tree frog" feature description with the text encoder
tree_frog_vector = model.text_model(processor(features['tree frog'], return_tensors="pt", padding=True)['input_ids']).pooler_output
# take the mean prompt embedding
tree_frog_vector = tree_frog_vector.mean(dim=0, keepdim=True)
# final projection
tree_frog_vector = model.text_projection(tree_frog_vector)
# embed each "tailed frog" feature description with the text encoder
tailed_frog_vector = model.text_model(processor(features['tailed frog'], return_tensors="pt", padding=True)['input_ids']).pooler_output
# take the mean prompt embedding
tailed_frog_vector = tailed_frog_vector.mean(dim=0, keepdim=True)
# final projection
tailed_frog_vector = model.text_projection(tailed_frog_vector)
# concatenate
text_features = torch.cat([tree_frog_vector, tailed_frog_vector], dim=0)
# normalize features
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# cosine similarity as logits
logit_scale = model.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_image.softmax(dim=1)
"""
Output:
tensor([[0.8901, 0.1099]], grad_fn=<SoftmaxBackward0>)
"""
First of all, we observe that the prediction is now correct: the model identifies the class as "tree frog". Although we also achieved the right classification with the method in part 1 of this series, the notable distinction of this method is that it offers high explainability. Rather than simply taking the average of the feature descriptions, we can examine the raw (pre-softmax) score S(feature) of each individual feature description. This allows us to understand why the model predicted a particular class:
# here we don't average the textual features as we want to see
# the score for each feature separately
tree_frog_vector = model.text_model(processor(features['tree frog'], return_tensors="pt", padding=True)['input_ids']).pooler_output
tree_frog_vector = model.text_projection(tree_frog_vector)
# normalize each feature embedding separately
text_features_tree_frog = tree_frog_vector / tree_frog_vector.norm(dim=-1, keepdim=True)
logit_scale = model.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features_tree_frog.t()
logits_per_image
"""
Output:
tensor([[25.5400, 22.6840, 21.3895, 25.9017]], grad_fn=<MmBackward0>)
"""
tailed_frog_vector = model.text_model(processor(features['tailed frog'], return_tensors="pt", padding=True)['input_ids']).pooler_output
tailed_frog_vector = model.text_projection(tailed_frog_vector)
# normalize each feature embedding separately
text_features_tailed_frog = tailed_frog_vector / tailed_frog_vector.norm(dim=-1, keepdim=True)
logit_scale = model.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features_tailed_frog.t()
logits_per_image
"""
Output:
tensor([[24.0911, 22.3996, 21.2813, 21.0066]], grad_fn=<MmBackward0>)
"""
S("protruding eyes") = 25.5400 > S("tiny eyes") = 24.0911; S("large mouth") = 22.6840 > S("small mouth") = 22.3996; S("without a tail") ~ S("has no tail") are similar probably because the tail is not visible in the picture; S("bright green colour") = 25.9017 > S("dark colour")= 21.0066;
The scores for the features of the "tree frog" class are overall higher than those for the features describing the "tailed frog" class. Analysing these feature scores helps us understand why the model predicted a certain class: in this example, the highest scores were given to "protruding eyes" and "bright green colour", providing a clear explanation for the predicted class. This level of explainability was not available with the method described in the first part, because the generated prompts were generic sentences that mixed several concepts. Replacing them with simple feature descriptions gives us the best of both worlds – high accuracy and great explainability.
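Finally, because the per-class steps above are identical, they can be wrapped in a small helper so that the approach scales to many classes. The snippet below is a minimal sketch that reuses the model, processor, image_features and features objects defined earlier; the helper name is purely illustrative:
# hypothetical helper: average the feature-description embeddings of one class
# and project the result into the shared multi-modal space
def class_embedding(descriptions):
    token_ids = processor(descriptions, return_tensors="pt", padding=True)['input_ids']
    pooled = model.text_model(token_ids).pooler_output   # one embedding per description
    return model.text_projection(pooled.mean(dim=0, keepdim=True))
# classify the image against every class in the features dictionary
text_features = torch.cat([class_embedding(d) for d in features.values()], dim=0)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
probs = (model.logit_scale.exp() * image_features @ text_features.t()).softmax(dim=1)
print(dict(zip(features.keys(), probs[0].tolist())))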
Conclusions
In the second part of our series we have seen how to replace the standard prompt "Picture of a {class}" with LLM-generated feature descriptions, boosting performance. This solution is not only scalable, as an LLM can generate descriptive features for any number of classes and datasets, but it is also highly explainable. In the upcoming articles, we will explore few-shot learning methods that leverage a few image examples per class to achieve higher accuracy than zero-shot methods.
References
[1] CLIP, Hugging Face Transformers documentation (huggingface.co)
[2] Visual Classification via Description from Large Language Models, https://openreview.net/pdf?id=jlAjNL8z5cs