MAGIC Attribution¶
MAGIC (Model-Agnostic Generation-time Influence via Checkpointing) attributes evaluation loss to individual training examples by backpropagating through the entire training process. Unlike influence functions, which use a local approximation, MAGIC computes exact counterfactual attribution by differentiating through checkpointed training steps.
We provide a Trainer class that takes differentiable training steps and handles all three phases of MAGIC attribution. We support FSDP training using the bergson.magic.dtensor_patch runtime patch, which makes PyTorch’s DTensor redistribution twice-differentiable (pytorch/pytorch#160509). The patch is applied in memory, so no torch source files are modified.
How it works¶
MAGIC attribution has three phases:
Forward training with checkpoints: Fine-tune the model, saving intermediate checkpoints at each step.
Evaluate: Compute the evaluation loss and its gradients with respect to the final model parameters.
Backward through training: Backpropagate the evaluation gradients through the checkpointed training steps using reverse-mode autodiff, accumulating attribution scores for each training example (or token).
The Trainer class handles all three phases. It uses torchopt for functional (stateless) differentiable optimization.
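The three phases can be illustrated on a toy problem in plain PyTorch. This is a minimal sketch, not bergson's implementation: the linear model, the synthetic data, and the helper name eval_loss_after_training are all assumptions made for illustration, and the sketch keeps the whole training graph in memory instead of using checkpoints.

```python
import torch

torch.manual_seed(0)
dtype = torch.float64

# Toy data (hypothetical): 6 training and 4 evaluation examples, 3 features.
x_train, y_train = torch.randn(6, 3, dtype=dtype), torch.randn(6, dtype=dtype)
x_eval, y_eval = torch.randn(4, 3, dtype=dtype), torch.randn(4, dtype=dtype)


def eval_loss_after_training(weights, steps=5, lr=0.1):
    """Run weighted SGD on a linear model and return the final eval loss."""
    theta = torch.zeros(3, dtype=dtype, requires_grad=True)
    for _ in range(steps):
        per_example = (x_train @ theta - y_train) ** 2
        train_loss = (weights * per_example).mean()
        # create_graph=True keeps the update itself differentiable, so the
        # final parameters remain a function of `weights`.
        (grad,) = torch.autograd.grad(train_loss, theta, create_graph=True)
        theta = theta - lr * grad
    return ((x_eval @ theta - y_eval) ** 2).mean()


# One weight per training example, all equal to 1 (ordinary training).
weights = torch.ones(6, dtype=dtype, requires_grad=True)
eval_loss = eval_loss_after_training(weights)

# Backpropagate through all training steps: d(eval loss) / d(weight_i).
(scores,) = torch.autograd.grad(eval_loss, weights)
print(scores)
```

Each entry of scores is the sensitivity of the final evaluation loss to the weight of one training example, which is the quantity the backward phase accumulates. For real models the unrolled graph is too large to hold in memory, which is why checkpointed recomputation is needed.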
Usage¶
CUDA_VISIBLE_DEVICES="0" bergson magic runs/magic-ckpts \
--data.dataset NeelNanda/pile-10k \
--query.dataset NeelNanda/pile-10k \
--query.split "train[:8]" \
--model EleutherAI/pythia-14m
Output files¶
After a run completes, run_cfg.run_path contains:
scores.pt: attribution scores tensor. Shape depends on the weight parameterization:
Per-example (1D weights): (num_train_docs,), indexed directly by doc_id.
Per-token (2D weights): (num_chunks, seq_len), indexed by (chunk_idx, token_idx) in the post-shuffle order used during training. Pad rows appended to make the dataset divisible by batch_size are trimmed before saving.
doc_ids.pt: written alongside scores.pt for every per-token run, shape (num_chunks, seq_len) matching scores.pt row-for-row. Each entry is the original (pre-shuffle) document id for that token position. Downstream aggregation takes only a few lines:
import torch
scores = torch.load("scores.pt")  # (num_chunks, seq_len)
doc_ids = torch.load("doc_ids.pt")  # (num_chunks, seq_len)
num_docs = int(doc_ids.max()) + 1
per_doc = torch.zeros(num_docs, dtype=scores.dtype)
# Sum token-level scores into one total per original document
per_doc.scatter_add_(0, doc_ids.flatten(), scores.flatten())
When data.chunk_length > 0, the doc_ids column comes from tokenize_and_chunk, and chunks may pack multiple docs or split one across chunks. When chunk_length is 0, each row is one document and doc_ids is broadcast from the row’s pre-shuffle index; tokens past the row’s actual length carry zero MAGIC score and contribute nothing to the scatter-add.
run_config.yaml: serialized MagicConfig used for the run.
validation.csv: leave-subset-out validation results (if validation was run).
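To make the packed-chunk case concrete, here is a small worked example on synthetic tensors. The values are invented stand-ins for scores.pt and doc_ids.pt, chosen so that documents straddle chunk boundaries; the ranking step with torch.topk is an illustrative addition, not something the tool performs.

```python
import torch

# Synthetic stand-ins for scores.pt and doc_ids.pt: 2 chunks of 4 tokens.
# Chunk 0 packs docs 0 and 1; chunk 1 holds the tail of doc 1 and all of doc 2.
scores = torch.tensor([[0.5, 0.25, -1.0, 0.0],
                       [2.0, 1.0, 0.125, 0.125]])
doc_ids = torch.tensor([[0, 0, 1, 1],
                        [1, 2, 2, 2]])

# Sum token-level scores into one total per original document.
per_doc = torch.zeros(int(doc_ids.max()) + 1, dtype=scores.dtype)
per_doc.scatter_add_(0, doc_ids.flatten(), scores.flatten())

# Rank documents by total score (largest first).
top = torch.topk(per_doc, k=2)
print(per_doc.tolist())      # [0.75, 1.0, 1.25]
print(top.indices.tolist())  # [2, 1]
```

Document 1 is split across both chunks, and the scatter-add still collects its three tokens into a single total.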
Meta Smoothness¶
MAGIC is valid when the function you are differentiating through is meta smooth. There are a few heuristics known to encourage meta smoothness:
Increase batch size
Scale model outputs down
Clip gradients
Pre-activation batch norm
QK norm
Tune weight decay
Many of these methods boil down to “Identify and manage spikes in your training loss.”
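As one illustration of the "Clip gradients" heuristic, the sketch below clips by global norm using only tensor operations, so the clipping step stays differentiable when it sits inside an unrolled training step. The function name clip_by_global_norm and the eps default are assumptions for this example; it is not taken from bergson or torchopt.

```python
import torch


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradients so their joint L2 norm is at most max_norm."""
    total_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    # The scale is 1 when the norm is already small enough, and shrinks
    # proportionally when a spike pushes the norm above max_norm.
    scale = torch.clamp(max_norm / (total_norm + eps), max=1.0)
    return [g * scale for g in grads]


# A gradient "spike" with joint norm 5 is rescaled to norm ~1.
spiky = [torch.tensor([3.0]), torch.tensor([4.0])]
clipped = clip_by_global_norm(spiky, max_norm=1.0)

# Gradients already under the threshold pass through unchanged.
small = [torch.tensor([0.3]), torch.tensor([0.4])]
untouched = clip_by_global_norm(small, max_norm=1.0)
```

Bounding the update size in this way limits how much a single loss spike can move the parameters, which is the common thread in the heuristics listed above.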
Core components¶
Trainer: Functional trainer that supports forward training with checkpoints and backward-through-training.
from bergson.trainer import Trainer, DataStream, BackwardState, TrainerState
import torchopt
# Initialize
opt = torchopt.adam(lr=1e-4)
trainer, state = Trainer.initialize(model, opt)
# Forward training with checkpoints
stream = DataStream(dataset, tokenizer, batch_size=4, device="cuda")
state = trainer.train(state, stream, save_dir="checkpoints/")
# Backward through training. bwd_state is assumed to be a BackwardState already
# seeded with the eval-loss gradients; its construction is not shown here.
bwd_state = trainer.backward("checkpoints/", stream, bwd_state)
scores = bwd_state.weight_grads  # attribution scores
DataStream: Wraps a dataset with differentiable per-example (or per-token) weights that receive gradients during the backward pass.
# Per-example attribution
stream = DataStream(dataset, tokenizer, batch_size=4, device="cuda")
# Per-token attribution
stream = DataStream(dataset, tokenizer, batch_size=4, device="cuda", weight_shape=(len(dataset), max_length))
DTensor patch: For multi-GPU runs with FSDP, apply the DTensor patch before any distributed operations:
from bergson.magic.dtensor_patch import apply_dtensor_patch
apply_dtensor_patch()
# Your MAGIC worker call here
Per-token vs per-example attribution¶
By default, DataStream creates a 1D weight tensor [n_examples] for per-example attribution. Passing a 2D shape [n_examples, max_length] as the weight_shape parameter creates a 2D weight tensor instead, so each token receives its own attribution score. The weighted_causal_lm_ce loss function supports both shapes.
To use per-token attribution, set model.loss_function = weighted_causal_lm_ce so the model uses the weighted loss during training.
from bergson.utils.math import weighted_causal_lm_ce
model.loss_function = weighted_causal_lm_ce
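For intuition about what a weighted causal-LM loss does with the two weight shapes, here is a hypothetical stand-in written in plain PyTorch. It is a sketch, not the source of weighted_causal_lm_ce: the function name, the argument layout, and the choice to align the weight at position t with the loss for predicting token t are all assumptions.

```python
import torch
import torch.nn.functional as F


def weighted_ce_sketch(logits, labels, weights):
    """Hypothetical weighted causal-LM cross-entropy (not bergson's code).

    logits: (batch, seq_len, vocab); labels: (batch, seq_len)
    weights: (batch,) for per-example or (batch, seq_len) for per-token.
    """
    # Causal shift: the logits at position t predict the token at t + 1.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    token_loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        reduction="none",
    ).view_as(shift_labels)
    if weights.dim() == 1:
        w = weights[:, None]  # one weight broadcast over every token
    else:
        w = weights[:, 1:]  # assumed alignment: weight t scales the loss on token t
    return (w * token_loss).mean()


torch.manual_seed(0)
logits = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))

token_weights = torch.ones(2, 5, requires_grad=True)
loss = weighted_ce_sketch(logits, labels, token_weights)
loss.backward()  # token_weights.grad now holds one gradient per token
```

With all weights equal to 1, both shapes reduce to the ordinary mean cross-entropy. Under the alignment assumed here, the first token of each row is never a prediction target, so its weight receives zero gradient.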
Key implementation details¶
Functional optimization: torchopt.adam (or similar) provides a pure-function optimizer whose state is a pytree of tensors. This allows torch.autograd.grad to differentiate through optimizer updates.
Checkpoint strategy: By default, checkpoints are saved at sqrt(N) intervals, giving O(sqrt(N)) memory and O(N * sqrt(N)) recomputation cost.
FSDP compatibility: The DTensor runtime patch adds a NestedRedistribute autograd function that makes the FSDP all-gather/reduce-scatter differentiable through second-order backward passes.
Loss weighting: weighted_causal_lm_ce multiplies per-token cross-entropy by the DataStream weights before averaging. During backward-through-training, autograd accumulates gradients into these weights, yielding the attribution scores.
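The checkpoint trade-off can be checked with a few lines of arithmetic. The sketch below assumes N is the number of training steps and that the backward pass rebuilds the state at each step by replaying forward from the nearest earlier checkpoint; that replay scheme is an assumption that reproduces the stated O(N * sqrt(N)) cost, not a description of bergson's internals. Both function names are hypothetical.

```python
import math


def checkpoint_steps(num_steps):
    """Steps whose state is saved, spaced roughly sqrt(num_steps) apart."""
    interval = max(1, math.isqrt(num_steps))
    return list(range(0, num_steps, interval))


def replay_cost(num_steps):
    """Forward steps re-run when each step is rebuilt from its nearest checkpoint."""
    saved = checkpoint_steps(num_steps)
    cost = 0
    for step in reversed(range(num_steps)):
        nearest = max(s for s in saved if s <= step)
        cost += step - nearest  # steps replayed to recover the state at `step`
    return cost


print(len(checkpoint_steps(100)))  # 10 checkpoints for 100 steps
print(replay_cost(100))            # 450 replayed steps, about N * sqrt(N) / 2
```

Storing every step would bring the replay cost to zero at the price of N saved states; the sqrt(N) spacing trades extra recomputation for a much smaller number of checkpoints.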