
BERT Interpretability: Sentiment Analysis

M2 MIASHS - Université de Lyon

In this notebook, we use Captum (Layer Integrated Gradients) to interpret a DistilBERT model (a distilled version of BERT) fine-tuned for sentiment analysis on SST-2.

Goal: Identify which words (e.g., “dull”, “brilliant”) contribute most to the sentiment score.

import torch
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

1. Set Up the Model (DistilBERT SST-2)

We use a lightweight BERT model already fine-tuned for binary sentiment analysis.

  • Label 1: Positive

  • Label 0: Negative

MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"

# Load Pre-trained Model & Tokenizer
model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
model.eval()
DistilBertForSequenceClassification( (distilbert): DistilBertModel( (embeddings): Embeddings( (word_embeddings): Embedding(30522, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (transformer): Transformer( (layer): ModuleList( (0-5): 6 x TransformerBlock( (attention): DistilBertSelfAttention( (q_lin): Linear(in_features=768, out_features=768, bias=True) (k_lin): Linear(in_features=768, out_features=768, bias=True) (v_lin): Linear(in_features=768, out_features=768, bias=True) (out_lin): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (ffn): FFN( (dropout): Dropout(p=0.1, inplace=False) (lin1): Linear(in_features=768, out_features=3072, bias=True) (lin2): Linear(in_features=3072, out_features=768, bias=True) (activation): GELUActivation() ) (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) ) ) ) ) (pre_classifier): Linear(in_features=768, out_features=768, bias=True) (classifier): Linear(in_features=768, out_features=2, bias=True) (dropout): Dropout(p=0.2, inplace=False) )

2. Define Captum Helper Functions

Captum attributes one scalar output per example, so we wrap the model to return the probability of the Positive class (label 1) instead of the raw logits.
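With two classes, the softmax probability of class 1 equals the logistic sigmoid of the logit difference, so thresholding P(Positive) at 0.5 (as done later for the displayed prediction) selects the same class as an argmax over the logits. The pure-Python sketch below checks this identity; the logits are made-up values, not outputs of the model, and the helper names are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def positive_probability(logits):
    """P(Positive) from [negative_logit, positive_logit], mirroring custom_forward."""
    return softmax(logits)[1]


def predicted_class(logits):
    """Threshold rule used for the displayed prediction: 1 if P(Positive) > 0.5."""
    return 1 if positive_probability(logits) > 0.5 else 0


# Hypothetical logits [negative, positive]; not taken from the model
for logits in ([-4.2, 4.6], [1.3, -0.7], [0.2, 0.9]):
    p_pos = positive_probability(logits)
    # Two-class softmax is the sigmoid of the logit difference
    assert math.isclose(p_pos, sigmoid(logits[1] - logits[0]))
    # Thresholding at 0.5 agrees with argmax when the logits differ
    argmax = max(range(len(logits)), key=lambda i: logits[i])
    assert predicted_class(logits) == argmax
    print(logits, round(p_pos, 4), predicted_class(logits))
```

This is why returning a probability rather than a logit is a design choice, not a requirement: attributing the logit difference would rank tokens similarly, but the probability saturates near 0 and 1.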

def predict(inputs, attention_mask=None):
    # Forward pass: raw logits of shape (batch, 2)
    output = model(inputs, attention_mask=attention_mask)
    return output.logits


def custom_forward(inputs, attention_mask=None):
    # Captum needs one scalar per example: return P(Positive), the softmax of class 1
    preds = predict(inputs, attention_mask)
    return torch.softmax(preds, dim=1)[:, 1]


3. Compute Attributions

We attach LayerIntegratedGradients to the embedding layer, so each token receives an attribution vector for its embedding, and compute attributions for one example sentence.

def main():
    # Alternative example with mixed sentiment:
    # text = "The movie was visually breathtaking but the plot was extremely dull and boring."
    text = "I absolutely loved the acting, it was a masterpiece."

    # Use Apple's MPS backend if available, otherwise the CPU
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")
    model.to(device)

    print(f"Input text: '{text}'")

    # Tokenize ([CLS] and [SEP] are added automatically)
    encoded = tokenizer(text, add_special_tokens=True, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Sanity check: forward pass without gradients
    print("Testing forward pass...")
    with torch.no_grad():
        out = custom_forward(input_ids, attention_mask)
    print(f"Forward pass output: {out}")

    # Attach LIG to the embedding layer for word-level attributions
    print("Initializing LayerIntegratedGradients...")
    lig = LayerIntegratedGradients(custom_forward, model.distilbert.embeddings)
    print("LIG initialized.")

    # Baseline: every position set to [PAD] (id 0 for this tokenizer), including [CLS] and [SEP].
    # A common alternative keeps [CLS]/[SEP] and pads only the content tokens.
    print("Computing attributions...")
    baseline = torch.full_like(input_ids, tokenizer.pad_token_id)
    model.zero_grad()

    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=baseline,
        additional_forward_args=(attention_mask,),
        return_convergence_delta=True,
        # n_steps=2 keeps compute minimal but is a very coarse approximation of
        # the path integral (Captum's default is 50), so the delta may be large.
        n_steps=2,
        internal_batch_size=1,
    )
    print("Attributions computed.")
    print("Convergence Delta:", delta)

    return attributions, delta, input_ids, attention_mask
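Integrated Gradients assigns each input feature i the attribution IG_i = (x_i − x′_i) · ∫₀¹ ∂F/∂x_i(x′ + α(x − x′)) dα, where x′ is the baseline, and these attributions should sum to F(x) − F(x′). Captum approximates the integral with n_steps points (by default 50, using Gauss-Legendre quadrature) and, with return_convergence_delta=True, reports the gap between the two sides. The sketch below is a toy illustration under stated assumptions: a hypothetical logistic model F(x) = sigmoid(w·x) with an analytic gradient stands in for DistilBERT, the input vector stands in for a token embedding, and the integral uses a midpoint Riemann sum (comparable to Captum's "riemann_middle" option, not its default).

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def model_f(x, w):
    """Hypothetical toy model: F(x) = sigmoid(w . x)."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)))


def grad_f(x, w):
    """Analytic gradient of F: sigmoid'(w . x) * w."""
    s = model_f(x, w)
    return [s * (1.0 - s) * wi for wi in w]


def integrated_gradients(x, baseline, w, n_steps):
    """Midpoint Riemann approximation of IG along the straight-line path."""
    avg_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_f(point, w)):
            avg_grad[i] += g / n_steps
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


def convergence_delta(attributions, x, baseline, w):
    """Completeness gap: sum of attributions minus (F(x) - F(baseline))."""
    return sum(attributions) - (model_f(x, w) - model_f(baseline, w))


w = [2.0, 1.0, -1.0]         # hypothetical weights
x = [3.0, 5.0, 1.0]          # hypothetical "embedding" (w . x = 10)
baseline = [0.0, 0.0, 0.0]   # all-zero baseline

for n in (2, 10, 50, 200):
    attr = integrated_gradients(x, baseline, w, n)
    delta = convergence_delta(attr, x, baseline, w)
    print(f"n_steps={n:<4d} attributions={[round(a, 4) for a in attr]} delta={delta:+.2e}")
```

In this toy case the gap is roughly -0.15 with n_steps=2 and becomes negligible well before n_steps=200. By analogy, the delta of -0.8951 reported for the real model with n_steps=2 suggests that the approximation is too coarse, and raising n_steps toward Captum's default is the usual first remedy.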

4. Visualize Results

We sum each token's attributions over the embedding dimension, L2-normalize the resulting scores, and visualize them.

  • Green: Positive contribution (Pushing towards “Positive”).

  • Red: Negative contribution (Pushing towards “Negative”).
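The aggregation performed by summarize_attributions can be sketched without tensors: sum each token's attribution vector over the embedding dimension, then divide by the L2 norm of the resulting scores. This puts every score in [-1, 1] while keeping its sign (green vs. red). The nested lists below are hypothetical attributions for three tokens with four embedding dimensions each, and the zero-norm guard is an addition that the notebook's version does not have.

```python
import math


def summarize_token_scores(per_token_dims):
    """Sum attributions over the embedding dimension, then L2-normalize.

    Mirrors attributions.sum(dim=-1) / torch.norm(...) for a single sequence,
    using plain lists instead of tensors.
    """
    scores = [sum(dims) for dims in per_token_dims]
    norm = math.sqrt(sum(s * s for s in scores))
    if norm == 0.0:
        # Guard against an all-zero vector (the torch version would produce NaN)
        return scores
    return [s / norm for s in scores]


# Hypothetical attributions: 3 tokens x 4 embedding dimensions
toy = [
    [0.10, -0.05, 0.20, 0.05],   # token A: net positive (0.30)
    [-0.30, 0.10, -0.20, 0.00],  # token B: net negative (-0.40)
    [0.00, 0.00, 0.00, 0.00],    # token C: no contribution
]
summary = summarize_token_scores(toy)
print([round(s, 4) for s in summary])  # [0.6, -0.8, 0.0]
```

Note that summing over dimensions lets positive and negative components within one embedding cancel, so a token with a near-zero score is not necessarily irrelevant to the model.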

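def summarize_attributions(attributions):
    # Sum over the embedding dimension -> one score per token; drop the batch axis
    attributions = attributions.sum(dim=-1).squeeze(0)
    # L2-normalize so every score lies in [-1, 1] for coloring
    attributions = attributions / torch.norm(attributions)
    return attributions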


if __name__ == "__main__":
    attributions, delta, input_ids, attention_mask = main()

    # One normalized score per token, moved to the CPU for display
    attributions_sum = summarize_attributions(attributions).detach().cpu()

    # Decode tokens for visualization
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    # Get model prediction for display (no gradients needed)
    with torch.no_grad():
        pred_prob = custom_forward(input_ids, attention_mask).item()
    pred_class_idx = 1 if pred_prob > 0.5 else 0
    pred_class = "Positive" if pred_class_idx == 1 else "Negative"

    # Create the visualization record. No ground-truth label is available,
    # so the predicted class is reused for the true and attributed class fields.
    vis_data_record = viz.VisualizationDataRecord(
        attributions_sum,               # word_attributions
        pred_prob,                      # pred_prob
        pred_class_idx,                 # pred_class
        pred_class_idx,                 # true_class (placeholder)
        pred_class_idx,                 # attr_class (placeholder)
        attributions_sum.sum().item(),  # attr_score
        tokens,                         # raw_input_ids (tokens to display)
        delta.item(),                   # convergence_score
    )

    print("\nVisualize Attributions:")
    viz.visualize_text([vis_data_record])  # Renders as HTML in notebooks

    # Manual text-based visualization for the command line
    print(f"Prediction: {pred_class} ({pred_prob:.4f})")
    print("-" * 30)
    print(f"{'Token':<15} {'Attribution':<10}")
    print("-" * 30)
    for token, score in zip(tokens, attributions_sum):
        print(f"{token:<15} {score.item():.4f}")
    print("-" * 30)
Using device: mps
Input text: 'I absolutely loved the acting, it was a masterpiece.'
Testing forward pass...
Forward pass output: tensor([0.9999], device='mps:0')
Initializing LayerIntegratedGradients...
LIG initialized.
Computing attributions...
Attributions computed.
Convergence Delta: tensor([-0.8951], device='mps:0')

Visualize Attributions:
Prediction: Positive (0.9999)
------------------------------
Token           Attribution
------------------------------
[CLS]           0.1167
i               -0.0462
absolutely      -0.0368
loved           0.0834
the             -0.1236
acting          -0.3361
,               0.0021
it              -0.1308
was             -0.2422
a               -0.2467
masterpiece     0.1343
.               -0.2557
[SEP]           0.7923
------------------------------
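The printed table can be post-processed in plain Python. The (token, score) pairs below are copied from the output above; because summarize_attributions L2-normalizes, their squares should sum to about 1, up to rounding to four decimals. The helper top_tokens is hypothetical: it ranks content tokens by |score| and skips [CLS] and [SEP].

```python
import math

# (token, score) pairs copied from the printed attribution table
scores = [
    ("[CLS]", 0.1167), ("i", -0.0462), ("absolutely", -0.0368),
    ("loved", 0.0834), ("the", -0.1236), ("acting", -0.3361),
    (",", 0.0021), ("it", -0.1308), ("was", -0.2422),
    ("a", -0.2467), ("masterpiece", 0.1343), (".", -0.2557),
    ("[SEP]", 0.7923),
]

SPECIAL_TOKENS = {"[CLS]", "[SEP]"}


def top_tokens(pairs, k=3, sign=None):
    """Return the k content tokens with the largest |score|, optionally of one sign."""
    content = [(t, s) for t, s in pairs if t not in SPECIAL_TOKENS]
    if sign == "positive":
        content = [(t, s) for t, s in content if s > 0]
    elif sign == "negative":
        content = [(t, s) for t, s in content if s < 0]
    return sorted(content, key=lambda p: abs(p[1]), reverse=True)[:k]


norm = math.sqrt(sum(s * s for _, s in scores))
print(f"L2 norm of printed scores: {norm:.4f}")
print("Strongest overall:", top_tokens(scores))
print("Strongest positive:", top_tokens(scores, sign="positive"))
print("Strongest negative:", top_tokens(scores, sign="negative"))
```

For this run the three largest content-token scores are all negative ("acting", the final period, "a"), "masterpiece" and "loved" are the only clearly positive words, and [SEP] carries the largest score overall. Together with the convergence delta of -0.8951, this suggests reading the n_steps=2 attributions with caution.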

Results and Interpretation

  1. Classification: for the alternative sentence in main(), "The movie was visually breathtaking but the plot was extremely dull and boring", the model predicts Negative (P(Positive) < 0.01).

  2. Attribution (key words):

    • The words with the largest negative contributions are "dull" (-0.38) and "boring" (-0.33).

    • The word "breathtaking" has a positive (or less negative) contribution, but it is outweighed by the end of the sentence.

    • The word "but" likely plays a pivotal role, since it often signals a reversal of sentiment; embedding-level attributions alone do not show how BERT's attention layers use it.

  3. Conclusion: Captum (via Layer Integrated Gradients) lets us "see" what BERT reads. These attributions suggest that the model does not rely only on isolated keywords but also takes sentence structure into account (the "but" that cancels the preceding compliment).