Codebase for the paper "Interpreting CLIP with Sparse Linear Concept Embeddings (SpLiCE)" by U. Bhalla*, A. Oesterling*, S. Srinivas, F. P. Calmon^, H. Lakkaraju^.
SpLiCE decomposes dense CLIP embeddings into sparse, nonnegative combinations of human-interpretable, semantic concepts that can be used for post-hoc concept based explanations, concept bottleneck models, dataset bias and spurious correlation detection, and distribution shift monitoring.
To install SpLiCE as a package, run the following:
git clone [email protected]:AI4LIFE-GROUP/SpLiCE.git
cd SpLiCE
pip install .
To get the concept decomposition of a single image (we include an example image 308175 from MSCOCO), run the decompose_image.py
script. The -l1_penalty
argument can be used to control the sparsity of the decompositions.
python decompose_image.py -path 000000308175.jpg -out_path image308175_decomp.txt -l1_penalty 0.25 --verbose
The following will be printed to the specified out path:
Concept Decomposition of 000000308175.jpg:
motorbike 0.093
police 0.0824
britain 0.0814
cops 0.0471
enforcement 0.027
officers 0.0174
lime 0.0111
passed 0.0068
policeman 0.0036
Decomposition L0 Norm: 9.0
CLIP, SpLiCE Cosine Sim: 0.5363
To generate a concept distribution for a given class, run the script concept_distribution.py
, which generates a bar plot of the concepts given by decompose_data.py
. The number of concepts to be plotted can be controlled with the -plot_topk
argument. An example distribution of the top 10 concepts from the 'jack-o-lantern' class from ImageNetVal is given below.
python concept_distribution.py -dataset ImageNetVal -out_folder imagenet_jackolantern_decomp --verbose -l1_penalty 0.3 -class_label 607 -device "cuda"
To decompose an entire dataset, run the decompose_data.py
script. An example decomposition for CIFAR-10 is shown below.
python decompose_data.py -dataset CIFAR10 -out_path cifar10_decomp.txt -l1_penalty 0.35 --verbose
Concept Decomposition:
frog 0.0042
deer 0.004
cat 0.0036
aircraft 0.0034
banner 0.0032
horse 0.003
truck 0.003
lorry 0.0022
icon 0.0022
hatchback 0.0019
insignia 0.0018
dog 0.0016
...
Average Decomposition L0 Norm: 3.3622
Average CLIP, SpLiCE Cosine Sim: 0.4345
Note that the poor image quality of CIFAR results in unintuitive tokens that are related to image size/quality: {icon, insignia, banner}. This is observed for other datasets with highly downsampled images but not for images of ImageNet like quality (200x200).
This can also be done for a single class with the -class_label
tag. Here, we show the "deer" class of CIFAR-10:
python decompose_data.py -dataset CIFAR10 -out_path cifar10_decomp.txt -l1_penalty 0.35 -class_label 4 --verbose
Concept Decomposition:
deer 0.0397
elk 0.0111
fawn 0.008
banner 0.0059
buck 0.0056
moose 0.0055
foal 0.0041
reindeer 0.0033
...
Average Decomposition L0 Norm: 3.246
Average CLIP, SpLiCE Cosine Sim: 0.4484
To integrate SpLiCE in your own codebase, the splice
module has the following functions:
-
splice.load(name, vocabulary, vocabulary_size, device, download_root=None, **kwargs) -> splicemodel nn.Module
:Returns the SpLiCE model for a supported CLIP backbone (e.g.
"open_clip:ViT-B-32"
) and concept vocabulary (e.g."laion"
). Note the syntax for CLIP backbone is"[library]:[model]"
. The download root can be specified if desired, and arguments such asreturn_weights
andl1_penalty
are passed to the SpLiCE model configuration. Specifyingvocabulary_size = k
will result in decompositions over only the last k concepts in the concept set (assuming the set is sorted from least to most important).-
Use
splice.available_models()
to view supported CLIP backbones that can be directly loaded and used with SpLiCE. Note that additional models can easily be added as explained below. -
splice.get_tokenizer(name)
andsplice.get_preprocess(name)
return the associated tokenizer and image preprocessing transform for supported clip backbones. -
splice.get_vocabulary(vocabulary_name)
returns a list of strings associated with a supported vocabulary.
-
-
splice.decompose_image(image, splicemodel, device) -> weights Tensor, l0_norm float, cosine_sim float
:Returns the splicemodel's sparse concept decomposition of the input image. Also returns the additional metrics of the l0 norm of the decomposition and the cosine similarity between the splice embedding and the original dense clip embedding.
-
splice.decompose_classes(dataloader, target_label, splicemodel, device) -> weights Tensor, l0_norm float, cosine_sim float
:Returns the splicemodel's average sparse concept decomposition of all images in the dataloader with label
target_label
. Also returns the additional metrics of the average l0 norm of the decompositions and the aveerage cosine similarity between the splice embeddings and the original dense clip embeddings. -
splice.decompose_dataset(dataloader, splicemodel, device) -> weights Tensor, l0_norm float, cosine_sim float
:Returns the splicemodel's average sparse concept decomposition of all images in the dataloader. Also returns the additional metrics of the average l0 norm of the decompositions and the aveerage cosine similarity between the splice embeddings and the original dense clip embeddings.
Currently, SpLiCE supports the following models,
Library | Model |
---|---|
clip | ViT-B/32 |
clip | ViT-B/16 |
clip | RN50 |
open_clip | ViT-B-32 |
and the following vocabularies:
Vocabulary | Description |
---|---|
laion | ~33000 of the most frequent tokens present in the LAION-400m dataset. We recommend you use the top 10000 (vocab_size=10000) for the best results. |
mscoco | ~12000 of the most frequent tokens present in the MSCOCO dataset. |
laion_bigrams | 10000 of the most frequent tokens present in the LAION-400m dataset and 5000 of the most frequent bigrams present in the LAION-400m dataset. |
A SpLiCE model has the following functions:
-
encode_image(image)
:Returns the concept decomposition of the preprocessed input image batch if
return_weights
is set toTrue
, otherwise it returns the interpretable dense SpLiCE embedding (z_hat) that can be used as a replacement for the original dense CLIP embeddings (z). -
decompose(dense_embedding) -> sparse_embedding
:Takes in batch of dense CLIP embeddings and outputs the sparse weight vector representations of the semantic decompositions.
-
recompose_image(weights) -> dense_embedding
:Takes in a batch of sparse decompositions of embeddings and outputs dense reconstructions.
-
intervene_image(image, intervention_indices) -> dense embedding
Decomposes the preprocessed input image batch and suppresses any weights on concepts specified by the intervention_indices. Returns the recomposed dense embeddings after intervention.
-
forward(image, text)
Runs the forward pass through the splicemodels image and text encoders and returns the corresponding dense or sparse representations depending on model initialization.
Here is a basic example of loading a SpLiCE model from our supported CLIP backbones and vocabularies. We set our l1_penalty
here to control the sparsity of the decompositions in downstream use cases.
import splice
splicemodel = splice.load("open_clip:ViT-B-32", vocabulary="laion", vocabulary_size=10000, l1_penalty=0.25, device="cuda")
Here is a toy example of how to use a model initialized with return_weights=True
, to decompose and recompose an image with encode_image
and recompose_image
.
import splice
splicemodel = splice.load("open_clip:ViT-B-32", vocabulary="laion", vocabulary_size=10000, l1_penalty=0.25, return_weights=True, device="cuda")
preprocess = splice.get_preprocess("open_clip:ViT-B-32")
image = preprocess(Image.open("00000308175.jpg")).unsqueeze(0).to("cuda")
sparse_weights = splicemodel.encode_image(image) # shape = [1, 10000], l0 norm = 9
reconstruction = splicemodel.recompose_image(sparse_weights) # shape = [1, 512]
The sparse weights can be mapped back to text using the splice.get_vocabulary
function. For example, the following lines will print the semantic decomposition of the image shown above in the example "Sample Concept Decomposition".
import splice
splicemodel = splice.load("open_clip:ViT-B-32", vocabulary="laion", vocabulary_size=10000, l1_penalty=0.15, return_weights=True, device="cuda")
preprocess = splice.get_preprocess("open_clip:ViT-B-32")
vocabulary = splice.get_vocabulary("laion", 10000)
image = preprocess(Image.open("00000308175.jpg")).unsqueeze(0).to("cuda")
sparse_weights = splicemodel.encode_image(image) # shape = [1, 10000], l0 norm = 9
reconstruction = splicemodel.recompose_image(sparse_weights) # shape = [1, 512]
sparse_weights = sparse_weights.squeeze()
for weight_idx in torch.sort(sparse_weights, descending=True)[1]:
print(f"{vocabulary[weight_idx]}: {sparse_weights[weight_idx]}")
If you want to construct your own splice model built on your own VLM backbone or vocabulary, you can directly construct your own splice model. All you need to provide is a module that implements encode_image
, and encode_text
, an estimated image_mean
over that module, and a dictionary
of text embeddings.
import splice
vlm_backbone = VLMBackbone()
image_mean = torch.load('path/to/image_mean.pt')
vocab_path = 'path/to/vocab.txt'
## Compute embedded dictionary, mean-center and normalize
concepts = []
with open(vocab_path, "r") as f:
lines = f.readlines()
for line in lines:
concepts.append(vlm_backbone.encode_text(line))
concepts = torch.nn.functional.normalize(torch.stack(concepts), dim=1)
concepts = torch.nn.functional.normalize(concepts-torch.mean(concepts, dim=0), dim=1)
splicemodel = splice.SPLICE(image_mean, concepts, clip=vlm_backbone, device="cuda")