Skip to content

Instantly share code, notes, and snippets.

@bfpill
Created March 4, 2025 05:20
Show Gist options
  • Select an option

  • Save bfpill/16d8fe380879e6b2808c88be6c3993a1 to your computer and use it in GitHub Desktop.

Select an option

Save bfpill/16d8fe380879e6b2808c88be6c3993a1 to your computer and use it in GitHub Desktop.
This code is for use with the Devinterp repo. Research done using this code can be found here: https://max.v3rv.com/in-progress/circuits_and_memorization
import torch
from torch.utils.data import DataLoader, SequentialSampler
from devinterp.slt.callback import SamplerCallback
import warnings
class AtomicCallback(SamplerCallback):
    """
    Efficiently collect per-input losses and LLCs at each step of SGLD.

    Extends SamplerCallback to compute and store initial per-input losses,
    per-input learning loss coefficient (LLC) traces, and optionally
    per-input probabilities for every input in the evaluation dataset.

    Parameters
    ----------
    evaluation_loader : DataLoader
        Provides batches of data for evaluation. Must NOT be shuffled:
        traces are keyed by dataset position, so iteration order has to be
        identical on every call.
    evaluate_fn : callable
        Takes (model, batch) and returns a dict containing at least
        'per_input_losses'. The loss function should return a loss for each
        input; for a standard torch loss this can be done with
        reduction='none'. Example::

            def evaluate_fn(model, batch):
                x, y = batch['data'][0], batch['data'][1]
                x, y = x.to(device), y.to(device)
                with torch.no_grad():
                    y_pred = model(x)
                per_input_losses = criterion(y_pred, y)
                return {"per_input_losses": per_input_losses}

    nbeta : float
        Scaling factor used in the computation of LLCs.
    input_info_extractor : callable, optional
        Extracts additional per-input information from each batch, such as
        indices or labels. Example::

            def input_info_extractor(batch):
                mislabeled = [bool(x) for x in batch['mislabeled']]
                label = [int(x) for x in batch['data'][1]]
                index = [int(k) for k in batch['index']]
                return [{"mislabeled": m, "label": l, "index": k}
                        for m, l, k in zip(mislabeled, label, index)]

    device : str, optional
        Device on which computations are performed (default "cpu").

    Usage
    -----
    ::

        callback = AtomicCallback(
            evaluation_loader=aLLC_collect_loader,
            evaluate_fn=lambda model, data: single_input_evaluate(
                torch.nn.CrossEntropyLoss(reduction="none"), model, data),
            nbeta=NBETA,
            device=device,
            input_info_extractor=mislabeled_extractor,
        )

    The callback is then passed to a sampling function, such as
    estimate_learning_coeff_with_summary, which uses it to record the
    learning process. After sampling, extract results with get_results()
    and save them, e.g.::

        get_llc(model, aLLC_sampling_loader, callback=callback)
        llc_results = callback.get_results()
        save_llc_results_to_csv(llc_results=llc_results, file_path="./test.csv")
    """

    def __init__(
        self,
        evaluation_loader: DataLoader,
        evaluate_fn,
        nbeta,
        input_info_extractor=None,
        device="cpu",
    ):
        super().__init__(device)
        self.evaluation_loader = evaluation_loader
        self.evaluate_fn = evaluate_fn
        self.input_info_extractor = input_info_extractor
        self.nbeta = nbeta
        # Filled lazily by _compute_init_losses on the first __call__.
        self.init_losses_tensor = None   # 1-D tensor of losses at step 0
        self.input_info_list = []        # per-input metadata dicts (optional)
        self.per_input_llcs = []         # one LLC trace (list) per input
        self.per_input_probs = None      # one prob trace per input, if provided
        # A shuffled loader would scramble the per-input position mapping
        # across draws, silently corrupting every trace — warn loudly.
        if hasattr(self.evaluation_loader, "sampler"):
            if not isinstance(self.evaluation_loader.sampler, SequentialSampler):
                warnings.warn(
                    "THE EVAL LOADER IS SHUFFLED, ANY TRACES WILL BE INVALID"
                )

    def _compute_init_losses(self, model):
        """Record the reference (step-0) per-input losses and size all trace buffers."""
        model.eval()
        all_losses = []
        all_infos = []
        has_probs = False
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.evaluation_loader):
                results = self.evaluate_fn(model, batch)
                if batch_idx == 0:
                    # Detect optional 'probs' output from the first batch's
                    # results instead of running a wasted extra forward pass.
                    has_probs = "probs" in results
                all_losses.extend(results["per_input_losses"].cpu().numpy())
                if self.input_info_extractor:
                    all_infos.extend(self.input_info_extractor(batch))
        num_samples = len(all_losses)
        self.init_losses_tensor = torch.tensor(all_losses, device=self.device)
        self.per_input_llcs = [[] for _ in range(num_samples)]
        self.input_info_list = all_infos if self.input_info_extractor else []
        if has_probs:
            self.per_input_probs = [[] for _ in range(num_samples)]

    def __call__(self, chain, draw, model, **kwargs):
        """Append one LLC (and optional prob) sample per input for this draw."""
        # Initialize reference losses exactly once, at the very first draw
        # of the first chain.
        if chain == 0 and draw == 0 and self.init_losses_tensor is None:
            self._compute_init_losses(model)
        model.eval()
        with torch.no_grad():
            pos_counter = 0  # running dataset position across batches
            for batch in self.evaluation_loader:
                results = self.evaluate_fn(model, batch)
                current_losses = results["per_input_losses"]
                batch_size = current_losses.shape[0]
                # Plain slice of the stored init losses; unlike indexing with
                # a freshly-built index tensor, this cannot hit a cross-device
                # indexing error when self.device differs from the loss device.
                init_losses = self.init_losses_tensor[
                    pos_counter:pos_counter + batch_size
                ].to(current_losses.device)
                llc_values = self.nbeta * (current_losses - init_losses)
                for i, llc in enumerate(llc_values.cpu().numpy()):
                    self.per_input_llcs[pos_counter + i].append(llc)
                if 'probs' in results and self.per_input_probs is not None:
                    probs = results['probs'].cpu().numpy()
                    for i, prob in enumerate(probs):
                        self.per_input_probs[pos_counter + i].append(prob)
                pos_counter += batch_size

    def get_results(self):
        """Return collected traces as plain Python containers (CSV/JSON friendly)."""
        results = {
            "init_losses": self.init_losses_tensor.cpu().numpy().tolist(),
            "per_input_llcs": self.per_input_llcs,
            "input_info": self.input_info_list,
        }
        if self.per_input_probs is not None:
            results["per_input_probs"] = self.per_input_probs
        return results
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment