MAGIC Attribution

MAGIC (Model-Agnostic Generation-time Influence via Checkpointing) attributes evaluation loss to individual training examples by backpropagating through the entire training process. Unlike influence functions which use a local approximation, MAGIC computes exact counterfactual attribution by differentiating through checkpointed training steps.

We provide a Trainer class that takes differentiable training steps and handles all three phases of MAGIC attribution. We support FSDP training using the bergson.magic.dtensor_patch runtime patch, which makes PyTorch’s DTensor redistribution twice-differentiable (pytorch/pytorch#160509). The patch is applied in memory, so no torch source files are modified.

How it works

MAGIC attribution has three phases:

  1. Forward training with checkpoints: Fine-tune the model, saving intermediate checkpoints at each step.

  2. Evaluate: Compute the evaluation loss and its gradients with respect to the final model parameters.

  3. Backward through training: Backpropagate the evaluation gradients through the checkpointed training steps using reverse-mode autodiff, accumulating attribution scores for each training example (or token).

The Trainer class handles all three phases. It uses torchopt for functional (stateless) differentiable optimization.

Usage

CUDA_VISIBLE_DEVICES="0" bergson magic runs/magic-ckpts \
    --data.dataset NeelNanda/pile-10k \
    --query.dataset NeelNanda/pile-10k \
    --query.split "train[:8]" \
    --model EleutherAI/pythia-14m

Output files

After a run completes, run_cfg.run_path contains:

  • scores.pt — attribution scores tensor. Shape depends on the weight parameterization:

    • Per-example (1D weights): (num_train_docs,), indexed directly by doc_id.

    • Per-token (2D weights): (num_chunks, seq_len), indexed by (chunk_idx, token_idx) in the post-shuffle order used during training. Pad rows appended to make the dataset divisible by batch_size are trimmed before saving.

  • doc_ids.pt — written alongside scores.pt for every per-token run, shape (num_chunks, seq_len) matching scores.pt row-for-row. Each entry is the original (pre-shuffle) document id for that token position. Downstream aggregation is one line:

    scores = torch.load("scores.pt")      # (num_chunks, seq_len)
    doc_ids = torch.load("doc_ids.pt")    # (num_chunks, seq_len)
    num_docs = int(doc_ids.max()) + 1
    per_doc = torch.zeros(num_docs, dtype=scores.dtype)
    per_doc.scatter_add_(0, doc_ids.flatten(), scores.flatten())
    

    When data.chunk_length > 0 the doc_ids column comes from tokenize_and_chunk and chunks may pack multiple docs or split one across chunks. When chunk_length is 0, each row is one document and doc_ids is broadcast from the row’s pre-shuffle index; tokens past the row’s actual length carry zero MAGIC score and contribute nothing to the scatter-add.

  • run_config.yaml — serialized MagicConfig used for the run.

  • validation.csv — leave-subset-out validation results (if validation was run).

Meta Smoothness

MAGIC is valid when the function you are differentiating through is meta smooth. There a few heuristics known to encourage meta smoothness:

  • Increase batch size

  • Scale model outputs down

  • Clip gradients

  • Pre-activation batch norm

  • QK norm

  • Tune weight decay

Many of these methods boil down to “Identify and manage spikes in your training loss.”

Core components

Trainer: Functional trainer that supports forward training with checkpoints and backward-through-training.

from bergson.trainer import Trainer, DataStream, BackwardState, TrainerState
import torchopt

# Initialize
opt = torchopt.adam(lr=1e-4)
trainer, state = Trainer.initialize(model, opt)

# Forward training with checkpoints
stream = DataStream(dataset, tokenizer, batch_size=4, device="cuda")
state = trainer.train(state, stream, save_dir="checkpoints/")

# Compute eval gradients, then backward through training
bwd_state = trainer.backward("checkpoints/", stream, bwd_state)
scores = bwd_state.weight_grads  # attribution scores

DataStream: Wraps a dataset with differentiable per-example (or per-token) weights that receive gradients during the backward pass.

# Per-example attribution
stream = DataStream(dataset, tokenizer, batch_size=4, device="cuda")

# Per-token attribution
stream = DataStream(dataset, tokenizer, batch_size=4, device="cuda", weight_shape=(len(dataset), max_length))

DTensor patch: For multi-GPU runs with FSDP, apply the DTensor patch before any distributed operations:

from bergson.magic.dtensor_patch import apply_dtensor_patch
apply_dtensor_patch()

# Your MAGIC worker call here

Per-token vs per-example attribution

By default, DataStream creates a 1D weight tensor [n_examples] for per-example attribution. By passing a 2D tensor [n_examples, max_length] as the weight_shape parameter, each token receives its own attribution score. The weighted_causal_lm_ce loss function supports both shapes.

To use per-token attribution, set model.loss_function = weighted_causal_lm_ce so the model uses the weighted loss during training.

from bergson.utils.math import weighted_causal_lm_ce
model.loss_function = weighted_causal_lm_ce

Key implementation details

  • Functional optimization: torchopt.adam (or similar) provides a pure-function optimizer whose state is a pytree of tensors. This allows torch.autograd.grad to differentiate through optimizer updates.

  • Checkpoint strategy: By default, checkpoints are saved at sqrt(N) intervals, giving O(sqrt(N)) memory and O(N * sqrt(N)) recomputation cost.

  • FSDP compatibility: The DTensor runtime patch adds a NestedRedistribute autograd function that makes the FSDP all-gather/reduce-scatter differentiable through second-order backward passes.

  • Loss weighting: weighted_causal_lm_ce multiplies per-token cross-entropy by the DataStream weights before averaging. During backward-through-training, autograd accumulates gradients into these weights, yielding the attribution scores.