Created
March 4, 2025 05:20
-
-
Save bfpill/16d8fe380879e6b2808c88be6c3993a1 to your computer and use it in GitHub Desktop.
This code is for use with the Devinterp repo. Research done using this code can be found here: https://max.v3rv.com/in-progress/circuits_and_memorization
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch.utils.data import DataLoader, SequentialSampler | |
| from devinterp.slt.callback import SamplerCallback | |
| import warnings | |
class AtomicCallback(SamplerCallback):
    """
    This callback allows efficient collection of per-input losses and LLCs at each step of SGLD.
    It extends the SamplerCallback and is used to compute and store initial losses,
    local learning coefficients (LLCs), and optionally, probabilities for each input
    in the evaluation dataset.

    Parameters:
    -----------
    evaluation_loader : DataLoader
        A DataLoader object that provides batches of data for evaluation.
        It must iterate in a fixed order (e.g. SequentialSampler) so that the
        position of each input is stable across sampling steps.
    evaluate_fn : function
        A function that takes a model and a batch of data, and returns a dictionary
        containing at least 'per_input_losses'.
        The loss function should return a loss for each input; for standard torch
        losses this can be done with reduction='none'.
        Example:
            def evaluate_fn(model, batch):
                x, y = batch['data'][0], batch['data'][1]
                x, y = x.to(device), y.to(device)
                with torch.no_grad():
                    y_pred = model(x)
                per_input_losses = criterion(y_pred, y)
                return {"per_input_losses": per_input_losses}
    nbeta : float
        A scaling factor used in the computation of LLCs.
    input_info_extractor : function, optional
        A function that extracts additional information from each batch, such as
        indices or labels.
        Example:
            def input_info_extractor(batch):
                mislabeled = [bool(x) for x in batch['mislabeled']]
                label = [int(x) for x in batch['data'][1]]
                index = [int(k) for k in batch['index']]
                return [{"mislabeled": m, "label": l, "index": k} for m, l, k in zip(mislabeled, label, index)]
    device : str, optional
        The device on which computations will be performed (default is "cpu").

    Usage:
    ------
        callback = AtomicCallback(
            evaluation_loader=aLLC_collect_loader,
            evaluate_fn=lambda model, data: single_input_evaluate(torch.nn.CrossEntropyLoss(reduction="none"), model, data),
            nbeta=NBETA,
            device=device,
            input_info_extractor=mislabeled_extractor
        )

    The callback is then passed to a sampling function, such as
    estimate_learning_coeff_with_summary, which uses it to record the learning process.
    After the sampling process, the results are extracted using get_results() and
    can be saved to a CSV file for further analysis.
    Example:
    --------
        get_llc(model, aLLC_sampling_loader, callback=callback)
        llc_results = callback.get_results()
        save_llc_results_to_csv(llc_results=llc_results, file_path="./test.csv")
    """

    def __init__(
        self,
        evaluation_loader: DataLoader,
        evaluate_fn,
        nbeta,
        input_info_extractor=None,
        device="cpu",
    ):
        super().__init__(device)
        self.evaluation_loader = evaluation_loader
        self.evaluate_fn = evaluate_fn
        self.input_info_extractor = input_info_extractor
        self.nbeta = nbeta
        # Baseline per-input losses, recorded lazily on the first sampler step.
        self.init_losses_tensor = None
        self.input_info_list = []
        # One trace (list of values, one per sampler step) per evaluation input.
        self.per_input_llcs = []
        # Allocated only if evaluate_fn reports a 'probs' key.
        self.per_input_probs = None
        # Traces are keyed by position in the loader, so a shuffled loader would
        # silently associate losses with the wrong inputs.
        if hasattr(self.evaluation_loader, "sampler"):
            if not isinstance(self.evaluation_loader.sampler, SequentialSampler):
                warnings.warn(
                    "THE EVAL LOADER IS SHUFFLED, ANY TRACES WILL BE INVALID"
                )

    def _compute_init_losses(self, model):
        """Record the pre-sampling ("initial") per-input losses and optional per-input info.

        Also allocates the per-input trace lists sized to the dataset, and the
        probability traces if evaluate_fn reports a 'probs' key.
        """
        model.eval()
        all_losses = []
        all_infos = []
        first_results = None
        with torch.no_grad():
            for batch in self.evaluation_loader:
                results = self.evaluate_fn(model, batch)
                # Keep the first batch's results so we can probe for 'probs'
                # below without running evaluate_fn an extra time.
                if first_results is None:
                    first_results = results
                all_losses.extend(results["per_input_losses"].cpu().numpy())
                if self.input_info_extractor:
                    all_infos.extend(self.input_info_extractor(batch))
        num_samples = len(all_losses)
        self.init_losses_tensor = torch.tensor(all_losses, device=self.device)
        self.per_input_llcs = [[] for _ in range(num_samples)]
        self.input_info_list = all_infos if self.input_info_extractor else []
        if first_results is not None and 'probs' in first_results:
            self.per_input_probs = [[] for _ in range(num_samples)]

    def __call__(self, chain, draw, model, **kwargs):
        """Record one sampler step: per-input LLC estimates (and probs, if reported).

        Called by the devinterp sampler with the chain index, draw index, and the
        current model state.
        """
        # Lazily capture the baseline on the first invocation, regardless of
        # which chain/draw the sampler reports first (guarding on chain==0 and
        # draw==0 would leave init_losses_tensor None — and crash below — if
        # sampling ever starts elsewhere).
        if self.init_losses_tensor is None:
            self._compute_init_losses(model)
        model.eval()
        with torch.no_grad():
            pos_counter = 0
            for batch in self.evaluation_loader:
                results = self.evaluate_fn(model, batch)
                current_losses = results["per_input_losses"]
                batch_size = current_losses.shape[0]
                # Slice the baseline for this batch's positions; move it to the
                # losses' device so the subtraction is well-defined.
                init_losses = self.init_losses_tensor[pos_counter:pos_counter + batch_size].to(current_losses.device)
                # Per-input LLC estimate at this step: nbeta * (L_w - L_w0).
                llc_values = self.nbeta * (current_losses - init_losses)
                for i, llc in enumerate(llc_values.cpu().numpy()):
                    self.per_input_llcs[pos_counter + i].append(llc)
                if self.per_input_probs is not None and 'probs' in results:
                    probs = results['probs'].cpu().numpy()
                    for i, prob in enumerate(probs):
                        self.per_input_probs[pos_counter + i].append(prob)
                pos_counter += batch_size

    def get_results(self):
        """Return the recorded traces as a plain dict.

        Keys: 'init_losses' (list of floats), 'per_input_llcs' (list of
        per-step traces, one per input), 'input_info' (extractor output, or
        empty list), and 'per_input_probs' when probabilities were recorded.

        Raises:
            RuntimeError: if called before any sampling step was recorded.
        """
        if self.init_losses_tensor is None:
            raise RuntimeError(
                "AtomicCallback.get_results() called before any sampling step was recorded."
            )
        results = {
            "init_losses": self.init_losses_tensor.cpu().numpy().tolist(),
            "per_input_llcs": self.per_input_llcs,
            "input_info": self.input_info_list
        }
        if self.per_input_probs is not None:
            results["per_input_probs"] = self.per_input_probs
        return results
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment