MAGIC Attribution

MAGIC (Model-Agnostic Generation-time Influence via Checkpointing) attributes evaluation loss to individual training examples by backpropagating through the entire training process. Unlike influence functions which use a local approximation, MAGIC computes exact counterfactual attribution by differentiating through checkpointed training steps.

We provide a Trainer class that takes differentiable training steps and handles all three phases of MAGIC attribution. We support FSDP training using the bergson.magic.dtensor_patch runtime patch, which makes PyTorch’s DTensor redistribution twice-differentiable (pytorch/pytorch#160509). The patch is applied in memory, so no torch source files are modified.

How it works

MAGIC attribution has three phases:

  1. Forward training with checkpoints: Fine-tune the model, saving intermediate checkpoints at each step.

  2. Evaluate: Compute the evaluation loss and its gradients with respect to the final model parameters.

  3. Backward through training: Backpropagate the evaluation gradients through the checkpointed training steps using reverse-mode autodiff, accumulating attribution scores for each training example (or token).

The Trainer class handles all three phases. It uses torchopt for functional (stateless) differentiable optimization.

Usage

CUDA_VISIBLE_DEVICES="0" bergson magic runs/magic-ckpts \
    --data.dataset NeelNanda/pile-10k \
    --query.dataset NeelNanda/pile-10k \
    --query.split "train[:8]" \
    --model EleutherAI/pythia-14m

Core components

Trainer: Functional trainer that supports forward training with checkpoints and backward-through-training.

from bergson.trainer import Trainer, DataStream, BackwardState, TrainerState
import torchopt

# Initialize
opt = torchopt.adam(lr=1e-4)
trainer, state = Trainer.initialize(model, opt)

# Forward training with checkpoints
stream = DataStream(dataset, tokenizer, batch_size=4, device="cuda")
state = trainer.train(state, stream, save_dir="checkpoints/")

# Compute eval gradients, then backward through training
bwd_state = trainer.backward("checkpoints/", stream, bwd_state)
scores = bwd_state.weight_grads  # attribution scores

DataStream: Wraps a dataset with differentiable per-example (or per-token) weights that receive gradients during the backward pass.

# Per-example attribution
stream = DataStream(dataset, tokenizer, batch_size=4, device="cuda")

# Per-token attribution
stream = DataStream(dataset, tokenizer, batch_size=4, device="cuda", weight_shape=(len(dataset), max_length))

DTensor patch: For multi-GPU runs with FSDP, apply the DTensor patch before any distributed operations:

from bergson.magic.dtensor_patch import apply_dtensor_patch
apply_dtensor_patch()

# Your MAGIC worker call here

Per-token vs per-example attribution

By default, DataStream creates a 1D weight tensor [n_examples] for per-example attribution. By passing a 2D tensor [n_examples, max_length] as the weight_shape parameter, each token receives its own attribution score. The weighted_causal_lm_ce loss function supports both shapes.

To use per-token attribution, set model.loss_function = weighted_causal_lm_ce so the model uses the weighted loss during training.

from bergson.utils.math import weighted_causal_lm_ce
model.loss_function = weighted_causal_lm_ce

Key implementation details

  • Functional optimization: torchopt.adam (or similar) provides a pure-function optimizer whose state is a pytree of tensors. This allows torch.autograd.grad to differentiate through optimizer updates.

  • Checkpoint strategy: By default, checkpoints are saved at sqrt(N) intervals, giving O(sqrt(N)) memory and O(N * sqrt(N)) recomputation cost.

  • FSDP compatibility: The DTensor runtime patch adds a NestedRedistribute autograd function that makes the FSDP all-gather/reduce-scatter differentiable through second-order backward passes.

  • Loss weighting: weighted_causal_lm_ce multiplies per-token cross-entropy by the DataStream weights before averaging. During backward-through-training, autograd accumulates gradients into these weights, yielding the attribution scores.