In this notebook, we use Captum's Layer Integrated Gradients to interpret a DistilBERT model fine-tuned for sentiment analysis on SST-2.
Goal: Identify which words (e.g., “dull”, “brilliant”) contribute most to the sentiment score.
import torch
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

1. Setup Model (DistilBERT SST-2)
We use a lightweight BERT model already fine-tuned for binary sentiment analysis.
Label 1: Positive
Label 0: Negative
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
# Load Pre-trained Model & Tokenizer
model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
model.eval()

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSelfAttention(
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(in_features=768, out_features=3072, bias=True)
            (lin2): Linear(in_features=3072, out_features=768, bias=True)
            (activation): GELUActivation()
          )
          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        )
      )
    )
  )
  (pre_classifier): Linear(in_features=768, out_features=768, bias=True)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

2. Define Captum Helper Functions
Captum needs a forward function that returns one scalar per example, so we wrap the model to return the probability of the Positive class.
def predict(inputs, attention_mask=None):
    """Run a forward pass and return the raw logits, shape (batch, 2)."""
    output = model(inputs, attention_mask=attention_mask)
    return output.logits


def custom_forward(inputs, attention_mask=None):
    """Return the probability of the Positive class (1) for each example."""
    preds = predict(inputs, attention_mask)
    # Column 1 of the softmax output corresponds to Label 1 (Positive)
    return torch.softmax(preds, dim=1)[:, 1]
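To make the wrapper's output concrete, here is a minimal sketch in plain Python (no model required) of what it computes: a softmax over the two logits, followed by selecting the Positive column, so that each example yields a single scalar. The helper names `softmax` and `positive_probability` and the logit values are made up for illustration; they are not part of the notebook or of Captum.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def positive_probability(batch_logits):
    """Return P(Positive) per example, i.e. softmax column 1."""
    return [softmax(row)[1] for row in batch_logits]


# Hypothetical logits for two sentences, ordered as [Negative, Positive]
batch_logits = [[-4.2, 4.5], [3.1, -2.8]]
probs = positive_probability(batch_logits)
print(probs)
```

Returning one number per example is what lets the attribution method treat the output as a single quantity to explain. An alternative would be to return the full logits and tell Captum which class to explain through the `target` argument of `attribute`.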
def main():
    # Example text
    # text = "The movie was visually breathtaking but the plot was extremely dull and boring."
    text = "I absolutely loved the acting, it was a masterpiece."

    # Use MPS if available, otherwise CPU
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")
    model.to(device)
    print(f"Input text: '{text}'")

    # Tokenize
    encoded = tokenizer(text, add_special_tokens=True, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Test forward pass (no gradients)
    print("Testing forward pass...")
    try:
        with torch.no_grad():
            out = custom_forward(input_ids, attention_mask)
        print(f"Forward pass output: {out}")
    except Exception as e:
        print(f"Forward pass failed: {e}")
        raise

    # Initialize Layer Integrated Gradients
    # Attach to embeddings layer for word-level attributions
    print("Initializing LayerIntegratedGradients...")
    try:
        lig = LayerIntegratedGradients(custom_forward, model.distilbert.embeddings)
        print("LIG initialized.")
    except Exception as e:
        print(f"LIG initialization failed: {e}")
        raise

    # Compute Attributions
    # Use a PAD baseline ([PAD] has id 0 here) and very small n_steps to avoid heavy compute
    print("Computing attributions...")
    baseline = torch.zeros_like(input_ids)
    model.zero_grad()
    try:
        attributions, delta = lig.attribute(
            inputs=input_ids,
            baselines=baseline,
            additional_forward_args=(attention_mask,),
            return_convergence_delta=True,
            n_steps=2,  # Very coarse; raise for a smaller convergence delta
            internal_batch_size=1,
        )
        print("Attributions computed.")
        print("Convergence Delta:", delta)
    except Exception as e:
        print(f"Attribution computation failed: {e}")
        raise

    return attributions, delta, input_ids, attention_mask

4. Visualize Results
We aggregate attributions for each token and visualize them.
Green: Positive contribution (Pushing towards “Positive”).
Red: Negative contribution (Pushing towards “Negative”).
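def summarize_attributions(attributions):
    # Sum over the embedding dimension to get one score per token
    attributions = attributions.sum(dim=-1).squeeze(0)
    # Scale the score vector to unit L2 norm
    attributions = attributions / torch.norm(attributions)
    return attributions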
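The aggregation step can be illustrated on a toy input. The sketch below mirrors `summarize_attributions` in plain Python: each token's per-dimension attributions are summed into one score, and the score vector is then divided by its L2 norm. The function name `summarize` and the toy values are assumptions chosen for illustration, with three "tokens" of three dimensions each instead of 768.

```python
import math


def summarize(attributions):
    """Collapse per-dimension attributions to one unit-normalized score per token.

    `attributions` is a list of tokens, each a list of per-dimension values.
    """
    per_token = [sum(dims) for dims in attributions]
    norm = math.sqrt(sum(v * v for v in per_token))
    return [v / norm for v in per_token]


# Hypothetical attributions: 3 tokens x 3 embedding dimensions
toy = [
    [0.2, -0.1, 0.3],
    [-0.5, 0.1, 0.0],
    [0.05, 0.05, -0.1],
]
scores = summarize(toy)
print([round(s, 4) for s in scores])
```

Because of the normalization, scores are only comparable within one sentence: the squared scores always sum to 1, so a token's share depends on the other tokens. Summing over dimensions also lets positive and negative components cancel, which is one reason some analyses use a norm per token instead. If every score were zero, the division would fail here, so a guard may be worth adding in practice.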
if __name__ == "__main__":
    attributions, delta, input_ids, attention_mask = main()
    attributions_sum = summarize_attributions(attributions)

    # Decode tokens for visualization
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    # Get model prediction for display
    pred_prob = custom_forward(input_ids, attention_mask).item()
    pred_class_idx = 1 if pred_prob > 0.5 else 0
    pred_class = "Positive" if pred_class_idx == 1 else "Negative"

    # Create Visualization Record (Captum expects ints for class labels)
    vis_data_record = viz.VisualizationDataRecord(
        attributions_sum,
        pred_prob,
        pred_class_idx,
        pred_class_idx,  # true label placeholder
        pred_class_idx,  # attr class placeholder
        attributions_sum.sum().item(),
        tokens,
        delta.item(),
    )
    print("\nVisualize Attributions:")
    viz.visualize_text([vis_data_record])  # This works best in notebooks

    # Manual text-based visualization for CLI
    print(f"Prediction: {pred_class} ({pred_prob:.4f})")
    print("-" * 30)
    print(f"{'Token':<15} {'Attribution':<10}")
    print("-" * 30)
    for token, score in zip(tokens, attributions_sum):
        print(f"{token:<15} {score.item():.4f}")
    print("-" * 30)

Using device: mps
Input text: 'I absolutely loved the acting, it was a masterpiece.'
Testing forward pass...
Forward pass output: tensor([0.9999], device='mps:0')
Initializing LayerIntegratedGradients...
LIG initialized.
Computing attributions...
Attributions computed.
Convergence Delta: tensor([-0.8951], device='mps:0')
Visualize Attributions:
Prediction: Positive (0.9999)
------------------------------
Token Attribution
------------------------------
[CLS] 0.1167
i -0.0462
absolutely -0.0368
loved 0.0834
the -0.1236
acting -0.3361
, 0.0021
it -0.1308
was -0.2422
a -0.2467
masterpiece 0.1343
. -0.2557
[SEP] 0.7923
------------------------------
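The run above reports a convergence delta of about -0.90 for a predicted probability of 0.9999, which suggests that two integration steps leave a large approximation error. The sketch below illustrates the effect on a toy model in plain Python. It implements Integrated Gradients with a simple midpoint Riemann sum and reports the gap between the summed attributions and f(input) - f(baseline). The toy model, its weights, and the midpoint rule are all assumptions made for illustration; Captum's own integration rule and internals may differ.

```python
import math

# Hypothetical toy model: f(x) = sigmoid(w . x + b)
W = [2.0, -1.0, 0.5]
B = -0.5


def f(x):
    z = sum(w * xi for w, xi in zip(W, x)) + B
    return 1.0 / (1.0 + math.exp(-z))


def grad_f(x):
    """Analytic gradient of f with respect to each input feature."""
    s = f(x)
    return [s * (1.0 - s) * w for w in W]


def integrated_gradients(x, baseline, n_steps):
    """Midpoint-rule approximation of Integrated Gradients.

    Returns (attributions, delta), where delta is the sum of attributions
    minus (f(x) - f(baseline)); it approaches 0 as n_steps grows.
    """
    total = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        # Point on the straight line from baseline to input
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        total = [t + g for t, g in zip(total, grad_f(point))]
    attrs = [(xi - b) * t / n_steps for xi, b, t in zip(x, baseline, total)]
    delta = sum(attrs) - (f(x) - f(baseline))
    return attrs, delta


x = [3.0, 1.0, 2.0]
baseline = [0.0, 0.0, 0.0]
for n in (2, 50):
    attrs, delta = integrated_gradients(x, baseline, n)
    print(f"n_steps={n:<3} delta={delta:+.6f} attrs={[round(a, 4) for a in attrs]}")
```

On this toy problem the delta shrinks markedly when going from 2 to 50 steps. For the notebook, this points toward raising `n_steps` (keeping `internal_batch_size` small to bound memory) before reading much into individual token scores.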
Result and Interpretation
Classification: The model classified the sentence “The movie was visually breathtaking but the plot was extremely dull and boring” (the alternative example, commented out in main()) as Negative (probability of Positive < 0.01).
Attribution (key words):
The words that contributed most negatively (largest negative scores) are “dull” (-0.38) and “boring” (-0.33).
The word “breathtaking” has a positive (or less negative) contribution, but it is outweighed by the end of the sentence.
The word “but” plays a pivotal role: it often signals a reversal of sentiment, which BERT captures through its attention mechanism.
Conclusion: Captum (via Layer Integrated Gradients) lets us “see” what BERT reads. It indicates that the model does not rely only on isolated keywords but picks up the structure of the sentence (the “but” that cancels the preceding compliment).
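The same kind of reading can be done programmatically. As a minimal sketch, the snippet below takes the token scores printed for the “I absolutely loved the acting” run, drops the special tokens, and ranks the remaining tokens by absolute attribution. The helper `top_tokens` and the `SPECIAL` set are hypothetical names introduced here; the scores are copied from the printed table.

```python
# Tokens and normalized attributions copied from the printed table
tokens = ["[CLS]", "i", "absolutely", "loved", "the", "acting", ",",
          "it", "was", "a", "masterpiece", ".", "[SEP]"]
scores = [0.1167, -0.0462, -0.0368, 0.0834, -0.1236, -0.3361, 0.0021,
          -0.1308, -0.2422, -0.2467, 0.1343, -0.2557, 0.7923]

# Special tokens carry no lexical meaning, so exclude them from the ranking
SPECIAL = {"[CLS]", "[SEP]", "[PAD]"}


def top_tokens(tokens, scores, k=3):
    """Return the k (token, score) pairs with the largest absolute score."""
    pairs = [(t, s) for t, s in zip(tokens, scores) if t not in SPECIAL]
    return sorted(pairs, key=lambda p: abs(p[1]), reverse=True)[:k]


top = top_tokens(tokens, scores)
for token, score in top:
    print(f"{token:<15} {score:+.4f}")
```

In that run, [SEP] holds the largest score and several content words come out negative despite a 0.9999 Positive prediction. Together with the large convergence delta, this suggests treating those particular numbers as a rough approximation, not a faithful ranking.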