class bergson.AttentionConfig(num_heads=0, head_size=0, head_dim=0)

Bases: object

Config for splitting an attention module into head matrices.

head_dim: int = 0

Axis index for num_heads in the weight matrix.

head_size: int = 0

Size of each attention head.

num_heads: int = 0

Number of attention heads.
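As a rough illustration of what these three fields describe, the sketch below splits a 2-D weight into num_heads blocks of head_size along axis head_dim. It uses plain nested lists and a hypothetical helper split_heads. It assumes the chosen axis has length num_heads * head_size and that heads are stored contiguously, which may not match every model's fused projection layout.

```python
def split_heads(weight, num_heads, head_size, head_dim):
    """Split a 2-D weight (list of rows) into num_heads blocks along axis head_dim.

    Hypothetical helper: assumes heads are stored contiguously, so head h
    occupies indices [h * head_size, (h + 1) * head_size) on that axis.
    """
    rows, cols = len(weight), len(weight[0])
    axis_len = rows if head_dim == 0 else cols
    if axis_len != num_heads * head_size:
        raise ValueError("axis length must equal num_heads * head_size")

    heads = []
    for h in range(num_heads):
        lo, hi = h * head_size, (h + 1) * head_size
        if head_dim == 0:
            # Take a contiguous band of rows.
            heads.append([row[:] for row in weight[lo:hi]])
        else:
            # Take a contiguous band of columns from every row.
            heads.append([row[lo:hi] for row in weight])
    return heads


# A 4 x 3 weight with 2 heads of size 2 stacked along axis 0.
w = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
per_head = split_heads(w, num_heads=2, head_size=2, head_dim=0)
print(per_head[1])  # [[7, 8, 9], [10, 11, 12]]
```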

class bergson.Attributor(index_path, device='cpu', dtype=torch.float32, unit_norm=False, precondition=False, faiss_cfg=None)

Bases: object

precondition: Literal['one-sided', 'two-sided', 'none']
search(queries, k, modules=None, reverse=False)

Search for the k nearest examples in the index based on the query or queries.

Parameters:
  • queries – The query tensor of shape […, d].

  • k – The number of nearest examples to return for each query.

  • modules – The names of the modules to search. If None, all modules will be searched.

  • reverse – If True, return the lowest influence examples instead of highest.

Returns:

A namedtuple containing the top k indices and inner products for each query. Both have shape […, k].
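The search semantics can be pictured with a brute-force sketch: score every index row by its inner product with each query and keep the k largest, or the k smallest when reverse=True. This is an illustrative stand-in on plain Python lists, not Attributor's implementation (which works on tensors and may use FAISS); brute_force_search and the namedtuple field names are hypothetical.

```python
from collections import namedtuple

# Field names are assumptions for illustration.
TopK = namedtuple("TopK", ["indices", "scores"])


def brute_force_search(index, queries, k, reverse=False):
    """Return, for each query, the k index rows with the highest inner product.

    With reverse=True, return the k rows with the lowest inner product instead.
    """
    all_indices, all_scores = [], []
    for q in queries:
        # Inner product between the query and every index row.
        scores = [sum(a * b for a, b in zip(q, row)) for row in index]
        order = sorted(range(len(index)), key=lambda i: scores[i], reverse=not reverse)
        top = order[:k]
        all_indices.append(top)
        all_scores.append([scores[i] for i in top])
    return TopK(all_indices, all_scores)


index = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
result = brute_force_search(index, queries=[[1.0, 0.5]], k=2)
print(result.indices)  # [[2, 0]]
```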

trace(module, k, *, modules=None, reverse=False)

Context manager that traces the gradients of a module and yields a TraceResult.

Parameters:
  • module – The module to trace.

  • k – The number of nearest examples to return.

  • modules – The modules to trace. If None, all modules will be traced.

  • reverse – If True, return the lowest influence examples instead of highest.

Return type:

Generator[TraceResult, None, None]

class bergson.Builder(data, grad_sizes, dtype, preprocess_cfg, *, attribute_tokens=False, path=None)

Bases: object

Gradient index writer.

Handles all combinations of storage (disk / in-memory) and granularity (per-sequence / per-token), with optional preconditioning and aggregation.

Parameters:
  • data (Dataset) – The dataset being indexed.

  • grad_sizes (dict[str, int]) – Per-module gradient dimensions.

  • dtype (torch.dtype) – Torch dtype for the gradients.

  • preprocess_cfg (PreprocessConfig) – Preconditioning, normalization, and aggregation settings.

  • attribute_tokens (bool) – Per-token gradients instead of per-example.

  • path (Path | None) – When given, write to a memory-mapped file on disk. When None, store in a plain numpy array.

flush()
Return type:

None

grad_buffer: ndarray
teardown()
Return type:

None

class bergson.CollectorComputer(model, data, *, collector, batches=None, cfg)

Bases: object

Orchestrates gradient collection by running forward/backward passes over a dataset.

Iterates through batches of data, computes losses, triggers backpropagation, and delegates gradient processing to the provided collector. Supports distributed training and optional profiling with the PyTorch profiler, enabled by the cfg.profile flag.

run_with_collector_hooks(desc=None)

Run the main computation loop over all batches.

For each batch: computes forward pass, calculates loss, triggers backward pass (which invokes collector hooks), then calls collector.process_batch(). After all batches are processed, calls collector.teardown().

Parameters:

desc – Optional description string for the tqdm progress bar.
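The order of operations described above (forward pass and loss, a backward pass that fires the collector hooks, then process_batch, and teardown once after the last batch) can be mimicked with a toy, framework-free loop. ToyCollector and run_loop are hypothetical names and the "forward pass" is a stand-in; only the call sequence is meant to mirror the description.

```python
class ToyCollector:
    """Hypothetical collector that records the calls it receives."""

    def __init__(self):
        self.calls = []

    def backward_hook(self, batch_idx):
        # A real collector would compute per-sample gradients here.
        self.calls.append(("backward_hook", batch_idx))

    def process_batch(self, indices, **kwargs):
        self.calls.append(("process_batch", tuple(indices), kwargs["losses"]))

    def teardown(self):
        self.calls.append(("teardown",))


def run_loop(data, batches, collector):
    for batch_idx, indices in enumerate(batches):
        # Stand-in for "forward pass + loss": one number per example.
        losses = [float(data[i]) for i in indices]
        # Stand-in for the backward pass, which is where hooks would fire.
        collector.backward_hook(batch_idx)
        collector.process_batch(indices, losses=losses)
    collector.teardown()


collector = ToyCollector()
run_loop(data=[3, 1, 4, 1], batches=[[0, 1], [2, 3]], collector=collector)
print(collector.calls[-1])  # ('teardown',)
```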

class bergson.DataConfig(dataset='NeelNanda/pile-10k', split='train', subset=None, prompt_column='text', completion_column='', conversation_column='', reward_column='', skip_nan_rewards=False, truncation=False, format_template='', data_kwargs='', chunk_length=0)

Bases: Serializable

chunk_length: int = 0

When positive, concatenate and chunk the documents into fixed-length token sequences of this length. Incompatible with truncation and format_template.
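Conceptually, chunking concatenates the tokenized documents into one stream and cuts it into windows of chunk_length tokens. The sketch below works on lists of token ids with a hypothetical chunk_documents helper and, as an assumption, drops the final partial window; Bergson may handle the remainder (or separators between documents) differently.

```python
from itertools import chain


def chunk_documents(token_lists, chunk_length):
    """Concatenate tokenized documents and split them into fixed-length chunks.

    Assumption: a trailing partial chunk is dropped.
    """
    if chunk_length <= 0:
        raise ValueError("chunk_length must be positive")
    stream = list(chain.from_iterable(token_lists))
    n_full = len(stream) // chunk_length
    return [stream[i * chunk_length:(i + 1) * chunk_length] for i in range(n_full)]


docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(chunk_documents(docs, chunk_length=4))  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```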

completion_column: str = ''

Optional column in the dataset that contains the completions.

conversation_column: str = ''

Optional column in the dataset that contains the conversation.

data_kwargs: str = ''

Arguments to pass to the dataset constructor in the format arg1=val1,arg2=val2.
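A string in the arg1=val1,arg2=val2 format maps naturally onto a keyword-argument dict. The parser below is a hypothetical sketch: it keeps every value as a string and does not support commas or "=" inside values; whether Bergson converts types is not stated here.

```python
def parse_data_kwargs(spec):
    """Parse 'arg1=val1,arg2=val2' into a dict of strings (hypothetical helper)."""
    if not spec:
        return {}
    kwargs = {}
    for pair in spec.split(","):
        key, sep, value = pair.partition("=")
        if not sep or not key.strip():
            raise ValueError(f"malformed entry: {pair!r}")
        kwargs[key.strip()] = value.strip()
    return kwargs


print(parse_data_kwargs("num_proc=4,cache_dir=/tmp/hf"))
# {'num_proc': '4', 'cache_dir': '/tmp/hf'}
```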

dataset: str = 'NeelNanda/pile-10k'

Dataset identifier to build the index from.

decode_into_subclasses: ClassVar[bool] = False
format_template: str = ''

Path to a YAML file containing a Jinja2 template that specifies how to format dataset rows into text. The YAML must contain doc_to_text and may optionally contain doc_to_target and doc_to_choice. An MCQA template is available at bergson/templates/mcqa.yaml.

prompt_column: str = 'text'

Column in the dataset that contains the prompts.

reward_column: str = ''

Optional column in the dataset that contains the rewards. When specified, gradients are calculated using the policy gradient loss from Dr. GRPO. https://arxiv.org/abs/2503.20783

skip_nan_rewards: bool = False

Whether to skip examples with NaN rewards.

split: str = 'train'

Split of the dataset to use for building the index.

subset: str | None = None

Subset of the dataset to use for building the index.

truncation: bool = False

Whether to truncate long documents to fit the token budget.

class bergson.FaissConfig(index_factory='Flat', mmap_index=False, max_train_examples=None, batch_size=1024, num_shards=1, nprobe=10)

Bases: object

Configuration for FAISS index.

batch_size: int = 1024

The batch size for pre-processing gradients.

index_factory: str = 'Flat'

The FAISS index factory string (https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index).

Common FAISS factory strings:
  • “IVF1,SQfp16”: exact nearest neighbors with brute-force search and fp16 storage. Valid for CPU or memmapped indices.

  • “IVF1024,SQfp16”: approximate nearest neighbors with 1024 cluster centers and fp16 storage. Fast approximate queries are produced at the cost of a slower initial index build.

  • “PQ6720”: nearest neighbors with product quantization using 6720 sub-quantizers. Reduces memory usage at the cost of accuracy.

max_train_examples: int | None = None

The maximum number of examples to train the index on. If None, all examples will be used.

mmap_index: bool = False

Whether to query the gradients on disk.

nprobe: int = 10

The number of FAISS clusters to probe per query when using approximate nearest-neighbor (ANN) search.

num_shards: int = 1

The number of shards to build for an index. Using more shards reduces peak RAM usage.

class bergson.FiniteDiff(model, *, normalizers=None)

Bases: object

apply()

Context manager to apply finite differences to the model.

clear()

Clear the stored finite differences.

store(step_size)

Compute and store finite differences for the model parameters.

This method assumes that you’ve just called .backward() on some loss function, so the model’s parameters have gradients. We reuse these gradient buffers and set the .grad attributes to None to avoid unnecessary memory usage.

swap()

Swap the original and updated parameters.

class bergson.GradientCollector(model, filter_modules=None, target_modules=None, processor=<factory>, attention_cfgs=<factory>, attribute_tokens=False, lo=-inf, hi=inf, *, data, cfg, skip_hessians=True, mod_grads=<factory>, preprocess_cfg=<factory>, builder=None, scorer=None)

Bases: HookCollectorBase

Collects per-sample gradients from model layers and writes them to disk.

  • For each forward/backward hook, computes the gradient, or a low-rank approximation via random projections if cfg.projection_dim is set.

  • Supports normalization via Adam or Adafactor normalizers.

backward_hook(module, g)

Compute per-sample gradient, accumulate autocorrelation matrix, and store.
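For a linear layer y = W x, the standard way to obtain a per-example weight gradient from hooks is to combine the input activations saved in the forward hook with the output gradients seen in the backward hook: the gradient for one example is the sum over its tokens of the outer products g_t a_tᵀ. The sketch below shows that textbook identity in plain Python with a hypothetical helper; it is not necessarily Bergson's exact code path, which may then project and normalize the result.

```python
def per_example_linear_grad(acts, out_grads):
    """Per-example weight gradient of a linear layer y = W @ x.

    acts[t] is the input activation a_t (length d_in) at token t;
    out_grads[t] is dL/dy_t (length d_out) at token t.
    Returns G with G[i][j] = sum_t out_grads[t][i] * acts[t][j],
    i.e. dL/dW for this one example (shape d_out x d_in).
    """
    d_out, d_in = len(out_grads[0]), len(acts[0])
    grad = [[0.0] * d_in for _ in range(d_out)]
    for a_t, g_t in zip(acts, out_grads):
        for i in range(d_out):
            for j in range(d_in):
                grad[i][j] += g_t[i] * a_t[j]
    return grad


# One example with two tokens, d_in = 2, d_out = 1.
acts = [[1.0, 2.0], [3.0, 4.0]]
out_grads = [[0.5], [-1.0]]
print(per_example_linear_grad(acts, out_grads))  # [[-2.5, -3.0]]
```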

builder: Builder | None = None

Handles writing gradients to disk. Created in setup() if save_index is True.

cfg: IndexConfig

Configuration for gradient index.

data: Dataset

The dataset being processed.

mod_grads: dict

Temporary storage for gradients during a batch, keyed by module name.

preprocess_cfg: PreprocessConfig

Configuration for gradient preprocessing.

process_batch(indices, **kwargs)

Process collected gradients for a batch and update losses.

scorer: Scorer | None = None

Optional scorer for computing scores instead of building an index.

setup()

Initialize collector state.

Sets up a Builder for gradient storage if not using a Scorer.

Return type:

None

skip_hessians: bool = True

Whether to skip estimating autocorrelation hessian statistics.

teardown()

Finalize gradient collection, save results and flush/reduce the Builder.

class bergson.GradientProcessor(normalizers=<factory>, hessians=<factory>, hessians_eigen=<factory>, projection_dim=None, reshape_to_square=False, projection_type='rademacher', projection_target='per_module', include_bias=False)

Bases: object

Configuration for processing and compressing gradients.

hessians: dict[str, Tensor]

Dictionary of hessians for each matrix-valued parameter in the model. These are applied after the normalization and random projection steps.

hessians_eigen: Mapping[str, tuple[Tensor, Tensor]]

Dictionary of eigendecompositions of the hessians for each matrix-valued parameter in the model. Each value is a tuple of (eigenvalues, eigenvectors). These are used to efficiently apply the inverse square root of the hessians to the gradients.
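The reason an eigendecomposition makes this efficient is the identity below: once H = V Λ Vᵀ is known, the inverse square root only requires inverting square roots of the eigenvalues. This assumes every eigenvalue is strictly positive; any damping or clamping Bergson applies is not described here.

```latex
H = V \Lambda V^\top, \qquad \Lambda = \operatorname{diag}(\lambda_1, \dots, \lambda_n), \quad \lambda_k > 0
\quad\Longrightarrow\quad
H^{-1/2} = V \Lambda^{-1/2} V^\top = V \operatorname{diag}\!\left(\lambda_1^{-1/2}, \dots, \lambda_n^{-1/2}\right) V^\top
```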

include_bias: bool = False

Whether to include bias gradients when present on a module.

classmethod load(path, *, map_location=None, skip_hessians=False)

Load the normalizers and hessians from a file.

Return type:

GradientProcessor

normalizers: Mapping[str, Normalizer]

Dictionary of normalizers for each matrix-valued parameter in the model. The keys should match the names of the parameters in the model. If a parameter does not have a normalizer, it will be skipped.

projection_dim: int | None = None

Number of rows and columns to project the gradients to. If None, keep the original shape of the gradients.

projection_target: Literal['per_module', 'global'] = 'per_module'

Projection target. per_module does a double-sided random projection of each module’s gradient independently. global does an independent single-sided right projection of each module’s flattened gradient then sums the results, producing one [proj_dim] vector per example.

projection_type: Literal['normal', 'rademacher'] = 'rademacher'

Type of random projection to use for compressing gradients. Can be either “normal” for Gaussian projections or “rademacher” for Rademacher projections, which use a uniform distribution over {-1, 1}.
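A small sketch of both projection targets with Rademacher (±1) matrices drawn from Python's random module. Shapes follow the descriptions above: per_module maps a d_out × d_in gradient G to P_L G P_Rᵀ of shape k × k, while global multiplies each module's flattened gradient by its own d_out·d_in × k matrix and sums the results. The helper names, the 1/√k scaling per side, and the seeding are assumptions for illustration, not Bergson's implementation.

```python
import math
import random


def rademacher(rows, cols, rng):
    """Random matrix with i.i.d. entries drawn uniformly from {-1, +1}."""
    return [[rng.choice((-1.0, 1.0)) for _ in range(cols)] for _ in range(rows)]


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def project_per_module(grad, k, rng):
    """Double-sided projection: (k x d_out) @ G @ (d_in x k) -> k x k."""
    d_out, d_in = len(grad), len(grad[0])
    left = rademacher(k, d_out, rng)
    right = rademacher(k, d_in, rng)
    out = matmul(matmul(left, grad), transpose(right))
    scale = 1.0 / k  # assumed 1/sqrt(k) on each side
    return [[v * scale for v in row] for row in out]


def project_global(grads, k, rng):
    """Right-project each flattened module gradient to length k and sum."""
    total = [0.0] * k
    for grad in grads:
        flat = [v for row in grad for v in row]
        right = rademacher(len(flat), k, rng)  # independent matrix per module
        for j in range(k):
            total[j] += sum(flat[i] * right[i][j] for i in range(len(flat))) / math.sqrt(k)
    return total


rng = random.Random(0)
g1 = [[0.1, -0.2, 0.3], [0.0, 0.5, -0.1]]  # d_out=2, d_in=3
g2 = [[1.0], [2.0]]                          # d_out=2, d_in=1
pm = project_per_module(g1, k=4, rng=rng)
gl = project_global([g1, g2], k=4, rng=rng)
print(len(pm), len(pm[0]), len(gl))  # 4 4 4
```

Replacing rng.choice((-1.0, 1.0)) with rng.gauss(0.0, 1.0) gives the "normal" variant. Both are linear maps, so scaling a gradient scales its projection by the same factor.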

reshape_to_square: bool = False

Whether to reshape the gradients into a nearly square matrix before projection. This is useful when the matrix-valued parameters are far from square, like in the case of LoRA adapters.
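One simple way to choose a nearly square shape is to take the largest divisor r of the element count N with r ≤ √N and reshape to r × N/r. This heuristic and the nearly_square_shape helper are assumptions for illustration; Bergson's actual rule may differ. For a LoRA-style 8 × 4096 gradient it yields 128 × 256.

```python
import math


def nearly_square_shape(rows, cols):
    """Return (r, c) with r * c == rows * cols and r the largest divisor <= sqrt(N)."""
    n = rows * cols
    for r in range(math.isqrt(n), 0, -1):
        if n % r == 0:
            return r, n // r
    raise ValueError("matrix must have at least one element")


print(nearly_square_shape(8, 4096))  # (128, 256)
```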

save(path)

Save the normalizers and hessians to a file.

class bergson.InMemoryCollector(model, filter_modules=None, target_modules=None, processor=<factory>, attention_cfgs=<factory>, attribute_tokens=False, lo=-inf, hi=inf, *, data, cfg, skip_hessians=True, mod_grads=<factory>, preprocess_cfg=<factory>, builder=None, scorer=None, gradients=<factory>, scores=None)

Bases: HookCollectorBase

Collector that accumulates gradients in memory.

Supports both per-example and per-token gradient collection via cfg.attribute_tokens. Uses an in-memory Builder for flat gradient storage and an optional Scorer for on-the-fly scoring.

After collection, self.gradients is populated from the builder’s buffer in teardown(), providing per-module gradient tensors for downstream use.

backward_hook(module, g)

Compute per-sample gradient, accumulate hessian, and store.

Return type:

None

builder: Builder | None = None

Handles writing gradients. Created in setup().

cfg: IndexConfig

Configuration for gradient collection.

data: Dataset

The dataset being processed.

gradients: dict[str, Tensor]

Per-module gradients, populated from builder in teardown.

mod_grads: dict

Temporary per-batch gradients keyed by module name.

preprocess_cfg: PreprocessConfig

Configuration for gradient preprocessing.

process_batch(indices, **kwargs)

Process collected data for a batch. This is called after each forward/backward pass. See also CollectorComputer.run_with_collector_hooks.

Parameters:
  • indices – List of data indices in the current batch

  • **kwargs – Additional batch-specific data (e.g., losses)

Return type:

None

scorer: Scorer | None = None

Optional scorer for on-the-fly scoring.

scores: list | Tensor | None = None

Scores populated from the scorer’s writer in teardown().

setup()

Initialize collector state and create builder.

Return type:

None

skip_hessians: bool = True

Whether to skip estimating autocorrelation hessian statistics.

teardown()

Called once after all batches have been processed.

Override to perform custom cleanup such as:
  • Saving results to disk

  • Flushing buffers

  • Computing final statistics

  • Freeing resources

Return type:

None

class bergson.IndexConfig(run_path, overwrite=False, model='EleutherAI/pythia-160m', precision='fp32', revision=None, distributed=<factory>, fsdp=False, peft_init_kwargs='', model_kwargs='', data=<factory>, tokenizer='', drop_columns=True, max_tokens=None, use_tf32_matmuls=False, debug=False, projection_dim=16, include_bias=False, reshape_to_square=False, projection_type='rademacher', projection_target='per_module', token_batch_size=2048, max_batch_size=None, auto_batch_size=False, optimizer_state='', skip_index=False, stats_sample_size=10000, loss_fn='ce', loss_reduction='sum', label_smoothing=0.0, stream_shard_size=400000, split_attention_modules=<factory>, attention=<factory>, profile=False, filter_modules=None, force_math_sdp=False, attribute_tokens=False, modules=<factory>)

Bases: AttributionConfig, Serializable

Config for building the index and running the model/dataset pipeline.

attention: AttentionConfig

Configuration for each attention module to be split into head matrices. Used for attention modules specified in split_attention_modules.

attribute_tokens: bool = False

Whether to compute per-token gradients instead of per-example. Incompatible with reduce mode.

auto_batch_size: bool = False

Whether to automatically determine the optimal token batch size. Experimental; only supported by build.

decode_into_subclasses: ClassVar[bool] = False
filter_modules: str | None = None

If provided, a glob pattern to filter out modules from gradient collection. For example, “transformer.h.*.mlp.*” will exclude all MLP layers in a standard transformer architecture.
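Assuming shell-style glob semantics like those of Python's fnmatch module (an assumption; the matcher Bergson uses is not stated here), the example pattern behaves as below. Note that * also matches dots, so a single * can span several name components.

```python
from fnmatch import fnmatchcase

pattern = "transformer.h.*.mlp.*"
module_names = [
    "transformer.h.0.attn.c_attn",
    "transformer.h.0.mlp.c_fc",
    "transformer.h.11.mlp.c_proj",
    "lm_head",
]
# Modules matching the pattern are excluded from gradient collection.
kept = [name for name in module_names if not fnmatchcase(name, pattern)]
print(kept)  # ['transformer.h.0.attn.c_attn', 'lm_head']
```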

force_math_sdp: bool = False

Disable flash and memory-efficient SDPA backends, forcing the math-only kernel. Some models produce inconsistent gradients across different padding lengths when using optimized attention backends. Run bergson test_model_configuration to check whether your model needs this.

include_bias: bool = False

Whether to include linear layers’ bias gradients.

label_smoothing: float = 0.0

Label smoothing coefficient for cross-entropy loss. When > 0, prevents near-zero gradients for high-confidence predictions that can cause numerical instability.
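For reference, the sketch evaluates label-smoothed cross-entropy in the form used by PyTorch's cross_entropy(label_smoothing=ε): the target becomes (1 − ε)·one_hot + ε/K over K classes. Whether Bergson computes its loss exactly this way is an assumption, and smoothed_cross_entropy is a hypothetical helper operating on one token's logits.

```python
import math


def smoothed_cross_entropy(logits, target, eps):
    """Cross-entropy against (1 - eps) * one_hot(target) + eps / K."""
    # Numerically stable log-softmax.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    k = len(logits)
    nll = -log_probs[target]
    uniform_term = -sum(log_probs) / k
    return (1.0 - eps) * nll + eps * uniform_term


logits = [4.0, 0.0, 0.0]
print(smoothed_cross_entropy(logits, target=0, eps=0.0))
print(smoothed_cross_entropy(logits, target=0, eps=0.1))
```

With ε > 0 the smoothed target is no longer one-hot, so the gradient with respect to the logits (predicted probabilities minus target) does not vanish even when the model assigns nearly all mass to the correct token.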

loss_fn: Literal['ce', 'kl'] = 'ce'

Loss function to use.

loss_reduction: Literal['mean', 'sum'] = 'sum'

Reduction method for the loss function.

max_batch_size: int | None = None

Cap the number of documents per batch.

modules: list[str]

Modules to use for the query. If empty, all modules will be used.

optimizer_state: str = ''

Source for optimizer second moments used to normalize gradients. Either a local path (a checkpoint directory containing optimizer.pt, or a path to an optimizer state file directly) or a Hugging Face URI hf://<repo>[@<revision>][/<path>].

property partial_run_path: Path

Temporary path to use while writing build artifacts.

profile: bool = False

Whether to enable profiling during gradient collection. If true, by default the first 4 steps will be profiled.

projection_dim: int = 16

Dimension of the random projection for the index, or 0 to disable it.

projection_target: Literal['per_module', 'global'] = 'per_module'

Projection target. per_module does a double-sided random projection of each module gradient. global projects each module’s flattened gradient with an independent right-side matrix and sums into one vector per example.

projection_type: Literal['normal', 'rademacher'] = 'rademacher'

Type of random projections to use for the gradients.

reshape_to_square: bool = False

Whether to reshape the gradients into a nearly square matrix before projection.

skip_index: bool = False

Whether to skip building the gradient index.

split_attention_modules: list[str]

Modules to split into head matrices.

stats_sample_size: int | None = 10000

Number of examples to use for estimating the autocorrelation Hessian. This feature is experimental and may be removed.

stream_shard_size: int = 400000

Shard size for streaming the dataset into Dataset objects.

token_batch_size: int = 2048

Batch size in tokens for building the index.

class bergson.PreprocessConfig(unit_normalize=False, hessian_path=None, aggregation='none', normalize_aggregated_grad=False)

Bases: Serializable

Config for gradient preprocessing, shared across build, reduce, and score.

aggregation: Literal['mean', 'sum', 'none'] = 'none'

Method for aggregating the gradients. In score, only query gradients will be aggregated.

decode_into_subclasses: ClassVar[bool] = False
hessian_path: str | None = None

Path to a precomputed gradient processor. Set this to apply the Hessian approximation.

normalize_aggregated_grad: bool = False

Whether to unit normalize the aggregated gradient. This does not change the relative ranking of scores, only their magnitudes.

unit_normalize: bool = False

Whether to unit normalize the gradients.
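The interaction of unit_normalize, aggregation, and normalize_aggregated_grad can be sketched on plain vectors. The order used here (normalize each gradient, aggregate, then optionally normalize the aggregate) is an assumption drawn from the field descriptions, and preprocess is a hypothetical helper. The final lines check the ranking claim above: rescaling the aggregated query gradient multiplies every score by the same positive factor.

```python
import math


def unit(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def preprocess(grads, unit_normalize=False, aggregation="none",
               normalize_aggregated_grad=False):
    """Hypothetical sketch of PreprocessConfig's options on a list of vectors."""
    if unit_normalize:
        grads = [unit(g) for g in grads]
    if aggregation == "none":
        return grads
    summed = [sum(col) for col in zip(*grads)]
    agg = summed if aggregation == "sum" else [x / len(grads) for x in summed]
    if normalize_aggregated_grad:
        agg = unit(agg)
    return [agg]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def rank(scores):
    return sorted(range(len(scores)), key=scores.__getitem__, reverse=True)


queries = [[3.0, 4.0], [1.0, 0.0]]
index = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.5]]
raw = preprocess(queries, unit_normalize=True, aggregation="mean")[0]
normed = preprocess(queries, unit_normalize=True, aggregation="mean",
                    normalize_aggregated_grad=True)[0]
scores_raw = [dot(raw, g) for g in index]
scores_normed = [dot(normed, g) for g in index]
print(rank(scores_raw) == rank(scores_normed))  # True
```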

class bergson.QueryConfig(index='', model='', text_field='text', unit_norm=True, device_map_auto=False, faiss=False, top_k=5, record='')

Bases: Serializable

Config for querying an existing gradient index.

decode_into_subclasses: ClassVar[bool] = False
device_map_auto: bool = False

Load the model onto multiple devices if necessary.

faiss: bool = False

Whether to use FAISS for the query.

index: str = ''

Path to the existing index.

model: str = ''

Model to use for the query. When not provided the model used to build the index is used.

record: str = ''

Path to a CSV file for recording query results. Each query appends its top and bottom results as rows with columns: query, direction, result, result_index, score.

text_field: str = 'text'

Field to use for the query.

top_k: int = 5

Number of top (and bottom) results to return per query.

unit_norm: bool = True

Whether to unit normalize the query.

class bergson.ScoreConfig(query_path='', score='individual', batch_size=1024, precision='fp32', modules=<factory>, higher_is_better=True)

Bases: Serializable

Config for querying an index on the fly.

batch_size: int = 1024

Batch size for processing the query dataset.

decode_into_subclasses: ClassVar[bool] = False
higher_is_better: bool = True

True when a positively scoring item is a proponent of the query capability (e.g. in influence functions). False for unrolled differentiation.

modules: list[str]

Modules to use for the query. If empty, all modules will be used.

precision: Literal['auto', 'bf16', 'fp16', 'fp32'] = 'fp32'

Precision (dtype) to convert the query and index gradients to before computing the scores. If “auto”, the model’s gradient dtype is used.

query_path: str = ''

Path to the existing query index.

score: Literal['nearest', 'individual'] = 'individual'

Method for scoring the gradients against the query.
  • nearest: compute each gradient’s similarity to the most similar query gradient (the maximum score).

  • individual: compute a separate score for each query gradient.
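Given a table of inner products with one row per index gradient and one column per query gradient, the two modes differ only in the final reduction. A minimal sketch with a hypothetical score_gradients helper; the exact output layout Bergson writes is not specified here.

```python
def score_gradients(index_grads, query_grads, mode="individual"):
    """Score each index gradient against the query gradients.

    individual: one score per (index, query) pair -> list of lists.
    nearest: the maximum score over queries -> one score per index gradient.
    """
    table = [[sum(a * b for a, b in zip(g, q)) for q in query_grads] for g in index_grads]
    if mode == "individual":
        return table
    if mode == "nearest":
        return [max(row) for row in table]
    raise ValueError(f"unknown mode: {mode}")


index_grads = [[1.0, 0.0], [0.0, 2.0]]
query_grads = [[1.0, 1.0], [2.0, -1.0]]
print(score_gradients(index_grads, query_grads))             # [[1.0, 2.0], [2.0, -2.0]]
print(score_gradients(index_grads, query_grads, "nearest"))  # [2.0, 2.0]
```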

class bergson.Scorer(query_grads, modules, writer, device, dtype, *, unit_normalize=False, score_mode='individual', attribute_tokens=False, index_transform=<function Scorer.<lambda>>)

Bases: object

Scores training gradients against query gradients.

Accepts an optional index_transform callable that is applied to each batch of index gradients before scoring. This can be used for preconditioning, projection, or any other per-batch transformation. When no transform is needed, pass None (identity is used).

Accepts a ScoreWriter for saving the scores (disk or in-memory).

score(index_grads)

Compute scores for a batch of gradients.

Return type:

Tensor

class bergson.TokenGradients(root_dir)

Bases: object

Convenience wrapper around the flat per-token gradient memmap.

Provides __getitem__ to retrieve a single example’s gradients as a contiguous array of shape (num_token_grads[i], grad_dim).

Parameters:

root_dir (Path | str) – Directory produced by create_token_index().

property num_token_grads: ndarray
bergson.collect_gradients(model, data, processor, cfg, *, skip_hessians=True, batches=None, target_modules=None, attention_cfgs=None, scorer=None, preprocess_cfg=None)

Compute gradients using the hooks specified in the GradientCollector.

bergson.load_from_optimizer(model, optimizer_state, include_bias=False, target_modules=None)

Load optimizer second moments from a checkpoint and create normalizer instances for each target linear layer.

Auto-detects the optimizer format:

  • Adam/AdamW: exp_avg_sq -> AdamNormalizer

  • Adafactor: exp_avg_sq_row/exp_avg_sq_col -> AdafactorNormalizer

  • 8-bit Adam (BitsAndBytes): state2 -> AdamNormalizer

Parameters:
  • model – The model whose parameter names are used to map optimizer state indices to layer names.

  • optimizer_state – Local path to an optimizer state file or a checkpoint directory containing optimizer.pt, or a Hugging Face URI hf://<repo>[@<revision>][/<path>] (see load_optimizer()).

  • include_bias – Whether to include bias second moments.

  • target_modules – Optional set of module names to include. If None, all linear layers are included.

Return type:

dict[str, Normalizer]

Returns:

Dictionary mapping layer names to normalizer instances.
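As background for the two normalizer families, the sketch shows how Adam-style and Adafactor-style second moments are commonly turned into an elementwise gradient scaling. The Adafactor reconstruction v[i][j] ≈ row[i]·col[j] / mean(row) is the factored estimator used by common Adafactor implementations that store row and column means of squared gradients. The helper names, the eps value, and whether AdamNormalizer and AdafactorNormalizer apply exactly these formulas are assumptions.

```python
import math


def adam_normalize(grad, exp_avg_sq, eps=1e-8):
    """Divide each entry by the root of its second-moment estimate."""
    return [[g / (math.sqrt(v) + eps) for g, v in zip(g_row, v_row)]
            for g_row, v_row in zip(grad, exp_avg_sq)]


def adafactor_second_moment(row, col):
    """Rebuild a full second-moment matrix from Adafactor's factored state.

    row[i] is the mean squared gradient over row i, col[j] over column j;
    the rank-1 estimate is row[i] * col[j] / mean(row).
    """
    row_mean = sum(row) / len(row)
    return [[r * c / row_mean for c in col] for r in row]


def adafactor_normalize(grad, row, col, eps=1e-8):
    return adam_normalize(grad, adafactor_second_moment(row, col), eps)


# A rank-1 squared-gradient pattern is reconstructed exactly.
sq = [[1.0, 4.0], [4.0, 16.0]]             # squares of [[1, 2], [2, 4]]
row = [sum(r) / len(r) for r in sq]        # [2.5, 10.0]
col = [sum(c) / len(c) for c in zip(*sq)]  # [2.5, 10.0]
print(adafactor_second_moment(row, col))   # [[1.0, 4.0], [4.0, 16.0]]
```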

bergson.load_gradient_dataset(root_dir, structured=True)

Load a dataset of gradients from root_dir.

Return type:

Dataset

bergson.load_gradients(root_dir, structured=True)

Map the structured gradients stored in root_dir into memory.

Return type:

memmap

bergson.load_token_gradients(root_dir)

Load per-token gradients stored by create_token_index().

Returns:

mmap has shape (total_tokens, total_grad_dim). Example i’s gradients are mmap[offsets[i]:offsets[i+1]] with shape (num_token_grads[i], total_grad_dim).

Return type:

tuple[memmap, ndarray, ndarray]
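The slicing convention in the Returns description is the usual prefix-sum layout: offsets[0] = 0 and offsets[i+1] = offsets[i] + num_token_grads[i], so example i owns the half-open row range offsets[i]:offsets[i+1]. The sketch uses a plain list as a stand-in for the memmap and a hypothetical example_grads helper, assuming offsets has one more entry than there are examples, as the slice implies.

```python
from itertools import accumulate

num_token_grads = [3, 1, 2]                  # rows (token gradients) per example
offsets = [0, *accumulate(num_token_grads)]  # [0, 3, 4, 6]

# Stand-in for the (total_tokens, total_grad_dim) memmap: one row per token.
flat_rows = [[float(t), float(t) * 10] for t in range(offsets[-1])]


def example_grads(i):
    """Rows for example i, shape (num_token_grads[i], total_grad_dim)."""
    return flat_rows[offsets[i]:offsets[i + 1]]


print(example_grads(2))  # [[4.0, 40.0], [5.0, 50.0]]
```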

bergson.mix_autocorrelation_matrices(query_path, index_path, output_path, target_downweight_components=1000)

Mix query and index autocorrelation matrices and save the result to disk.

Computes H_mixed = coeff * H_query + (1 - coeff) * H_index for every module’s raw H matrix, then persists a new GradientProcessor at output_path.

A mix_config.yaml file is also written alongside for provenance.

Parameters:
  • query_path (str | Path) – Directory containing the query GradientProcessor.

  • index_path (str | Path) – Directory containing the index GradientProcessor.

  • output_path (str | Path) – Directory where the mixed GradientProcessor will be saved.

  • target_downweight_components (int) – Number of gradient components to downweight via automatic lambda selection.

Returns:

The output_path as a pathlib.Path.

Return type:

Path
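The mixing step itself is an entrywise weighted combination of the two matrices. The sketch takes coeff as given; how mix_autocorrelation_matrices derives it from target_downweight_components (the automatic lambda selection) is not described here, so that step is omitted. mix_matrices is a hypothetical helper.

```python
def mix_matrices(h_query, h_index, coeff):
    """Return coeff * H_query + (1 - coeff) * H_index, entry by entry."""
    return [[coeff * q + (1.0 - coeff) * x for q, x in zip(q_row, x_row)]
            for q_row, x_row in zip(h_query, h_index)]


h_query = [[2.0, 0.0], [0.0, 2.0]]
h_index = [[1.0, 0.5], [0.5, 1.0]]
print(mix_matrices(h_query, h_index, coeff=0.25))  # [[1.25, 0.375], [0.375, 1.25]]
```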