MAGIC Attribution
MAGIC (Model-Agnostic Generation-time Influence via Checkpointing) attributes evaluation loss to individual training examples by backpropagating through the entire training process. Unlike influence functions, which rely on a local approximation, MAGIC computes exact counterfactual attribution by differentiating through checkpointed training steps.
We provide a Trainer class that takes differentiable training steps and handles all three phases of MAGIC attribution. We support FSDP training using the bergson.magic.dtensor_patch runtime patch, which makes PyTorch’s DTensor redistribution twice-differentiable (pytorch/pytorch#160509). The patch is applied in memory, so no torch source files are modified.
How it works
MAGIC attribution has three phases:
Forward training with checkpoints: Fine-tune the model, saving intermediate checkpoints at each step.
Evaluate: Compute the evaluation loss and its gradients with respect to the final model parameters.
Backward through training: Backpropagate the evaluation gradients through the checkpointed training steps using reverse-mode autodiff, accumulating attribution scores for each training example (or token).
The Trainer class handles all three phases. It uses torchopt for functional (stateless) differentiable optimization.
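The three phases can be illustrated on a toy problem small enough to differentiate by hand. The sketch below is not the bergson implementation: it assumes a single scalar parameter, plain SGD with one example per step, and a squared-error loss, and every name in it (train, eval_loss_and_grad, backward_through_training) is hypothetical. It only shows how saved per-step states let a reverse sweep accumulate the gradient of the evaluation loss with respect to each training example's weight.

```python
# Toy model: one scalar parameter theta, trained by plain SGD on weighted
# squared errors. Step t uses one example xs[t] with data weight ws[t]:
#   theta[t+1] = theta[t] - lr * ws[t] * (theta[t] - xs[t])


def train(xs, ws, theta0, lr):
    """Phase 1: forward training, saving the initial and every later state."""
    checkpoints = [theta0]
    theta = theta0
    for x, w in zip(xs, ws):
        theta = theta - lr * w * (theta - x)
        checkpoints.append(theta)
    return checkpoints  # checkpoints[t] is the parameter before step t


def eval_loss_and_grad(theta, query):
    """Phase 2: evaluation loss and its gradient w.r.t. the final parameter."""
    return 0.5 * (theta - query) ** 2, theta - query


def backward_through_training(checkpoints, xs, ws, lr, grad_final):
    """Phase 3: walk the steps in reverse, accumulating d(eval loss)/d(ws[t])."""
    scores = [0.0] * len(xs)
    adjoint = grad_final  # d(eval loss)/d(theta after the last step)
    for t in reversed(range(len(xs))):
        theta_t = checkpoints[t]
        # Direct effect of ws[t] on theta[t+1].
        scores[t] = adjoint * (-lr * (theta_t - xs[t]))
        # Propagate the adjoint from theta[t+1] back to theta[t].
        adjoint = adjoint * (1.0 - lr * ws[t])
    return scores


xs = [1.0, -2.0, 0.5, 3.0]
ws = [1.0] * len(xs)
lr, theta0, query = 0.1, 0.0, 2.0

checkpoints = train(xs, ws, theta0, lr)
loss, grad_final = eval_loss_and_grad(checkpoints[-1], query)
scores = backward_through_training(checkpoints, xs, ws, lr, grad_final)
```

Each entry of scores is the derivative of the evaluation loss with respect to one example's weight, so it can be checked against a finite-difference perturbation of that weight. In the real setting the per-step derivatives come from autograd rather than from hand-written formulas.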
Usage
CUDA_VISIBLE_DEVICES="0" bergson magic runs/magic-ckpts \
--data.dataset NeelNanda/pile-10k \
--query.dataset NeelNanda/pile-10k \
--query.split "train[:8]" \
--model EleutherAI/pythia-14m
Core components
Trainer: Functional trainer that supports forward training with checkpoints and backward-through-training.
from bergson.trainer import Trainer, DataStream, BackwardState, TrainerState
import torchopt
# Initialize
opt = torchopt.adam(lr=1e-4)
trainer, state = Trainer.initialize(model, opt)
# Forward training with checkpoints
stream = DataStream(dataset, tokenizer, batch_size=4, device="cuda")
state = trainer.train(state, stream, save_dir="checkpoints/")
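# Compute eval gradients, then backward through training.
# bwd_state must already be a BackwardState holding the evaluation-loss
# gradients; its construction is not shown here.
bwd_state = trainer.backward("checkpoints/", stream, bwd_state)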
scores = bwd_state.weight_grads # attribution scores
DataStream: Wraps a dataset with differentiable per-example (or per-token) weights that receive gradients during the backward pass.
# Per-example attribution
stream = DataStream(dataset, tokenizer, batch_size=4, device="cuda")
# Per-token attribution
stream = DataStream(dataset, tokenizer, batch_size=4, device="cuda", weight_shape=(len(dataset), max_length))
DTensor patch: For multi-GPU runs with FSDP, apply the DTensor patch before any distributed operations:
from bergson.magic.dtensor_patch import apply_dtensor_patch
apply_dtensor_patch()
# Your MAGIC worker call here
Per-token vs per-example attribution
By default, DataStream creates a 1D weight tensor of shape [n_examples] for per-example attribution. Passing a 2D shape (n_examples, max_length) as the weight_shape parameter instead gives each token its own weight, and therefore its own attribution score. The weighted_causal_lm_ce loss function supports both shapes.
To use per-token attribution, set model.loss_function = weighted_causal_lm_ce so the model uses the weighted loss during training.
from bergson.utils.math import weighted_causal_lm_ce
model.loss_function = weighted_causal_lm_ce
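To make the two weight shapes concrete, here is a minimal pure-Python sketch of a weighted mean over per-token losses. It is not the weighted_causal_lm_ce implementation: the function name weighted_mean_loss is hypothetical, the per-token losses are given directly instead of being computed from logits, and dividing by the token count (rather than by the sum of weights) is an assumption.

```python
def weighted_mean_loss(token_losses, weights):
    """Multiply per-token losses by data weights, then average over all tokens.

    token_losses: list of per-example lists, shape [batch][seq_len].
    weights: either [batch] (per-example) or [batch][seq_len] (per-token).
    """
    total, count = 0.0, 0
    for row_losses, row_weight in zip(token_losses, weights):
        per_token = isinstance(row_weight, (list, tuple))
        for t, loss in enumerate(row_losses):
            # A per-example weight is shared by every token in the row.
            w = row_weight[t] if per_token else row_weight
            total += w * loss
            count += 1
    return total / count


token_losses = [[2.0, 1.0, 3.0], [0.5, 0.5, 2.0]]

# With all weights equal to 1, both shapes reduce to the plain mean.
per_example = weighted_mean_loss(token_losses, [1.0, 1.0])
per_token = weighted_mean_loss(token_losses, [[1.0] * 3, [1.0] * 3])

# Zeroing one token's weight removes only that token's contribution.
masked = weighted_mean_loss(token_losses, [[1.0, 1.0, 0.0], [1.0] * 3])
```

Because the loss is linear in each weight, the gradient with respect to a weight is proportional to the loss of the tokens it covers, which is why the weights are a natural place for attribution gradients to accumulate.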
Key implementation details
Functional optimization: torchopt.adam (or similar) provides a pure-function optimizer whose state is a pytree of tensors. This allows torch.autograd.grad to differentiate through optimizer updates.
Checkpoint strategy: By default, checkpoints are saved at sqrt(N) intervals, giving O(sqrt(N)) memory and O(N * sqrt(N)) recomputation cost.
FSDP compatibility: The DTensor runtime patch adds a NestedRedistribute autograd function that makes the FSDP all-gather/reduce-scatter differentiable through second-order backward passes.
Loss weighting: weighted_causal_lm_ce multiplies per-token cross-entropy by the DataStream weights before averaging. During backward-through-training, autograd accumulates gradients into these weights, yielding the attribution scores.
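The memory and recomputation figures quoted for the checkpoint strategy can be reproduced with a small counting sketch. It assumes one particular schedule, which may differ from what bergson actually does: a checkpoint is kept every ceil(sqrt(N)) steps, and each backward step rebuilds the state it needs by replaying forward from the nearest earlier checkpoint. All function names here are hypothetical.

```python
import math


def checkpoint_interval(n_steps):
    """Spacing between kept checkpoints: ceil(sqrt(N)), at least 1."""
    return max(1, math.ceil(math.sqrt(n_steps)))


def checkpoint_steps(n_steps):
    """Steps whose pre-step state is kept on disk or in memory."""
    return list(range(0, n_steps, checkpoint_interval(n_steps)))


def replay_cost(step, n_steps):
    """Forward steps needed to rebuild the state before `step`
    from the nearest earlier checkpoint."""
    return step % checkpoint_interval(n_steps)


n_steps = 16
kept = checkpoint_steps(n_steps)
total_replay = sum(replay_cost(t, n_steps) for t in range(n_steps))
```

Under these assumptions the number of kept checkpoints grows like sqrt(N), and the total replay work is at most N * (ceil(sqrt(N)) - 1) forward steps, which is consistent with the O(sqrt(N)) memory and O(N * sqrt(N)) recomputation stated above.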