Skip to content

Instantly share code, notes, and snippets.

View norabelrose's full-sized avatar

Nora Belrose norabelrose

View GitHub Profile
@norabelrose
norabelrose / train.py
Last active December 8, 2023 22:19
Features across time training script
from argparse import ArgumentParser
from dataclasses import dataclass
import torch
import torchvision.transforms as T
from concept_erasure import QuadraticEditor, QuadraticFitter
from datasets import (
ClassLabel, Dataset, DatasetDict, Features, Image, load_dataset
)
from einops import rearrange
@norabelrose
norabelrose / extract.py
Created November 15, 2023 05:59
Hidden state extraction
from argparse import ArgumentParser
from pathlib import Path
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
@norabelrose
norabelrose / dpo.py
Created November 8, 2023 07:04
Training quirky models with DPO
from argparse import ArgumentParser
from datasets import load_dataset
from peft import LoraConfig
from trl import DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
if __name__ == "__main__":
parser = ArgumentParser()
@norabelrose
norabelrose / training-code.py
Created October 24, 2023 00:36
Training code
from itertools import pairwise
from typing import Literal
import pytorch_lightning as pl
import torch
import torchmetrics as tm
import torchvision as tv
from torch import nn
from torch.optim import RAdam
from torch.optim.lr_scheduler import CosineAnnealingLR
@norabelrose
norabelrose / moments.py
Last active October 22, 2023 05:18
Blocked moment generator
from itertools import (
combinations_with_replacement as pyramid
)
from typing import Iterable
import math
from opt_einsum import get_symbol
from torch import Tensor
import torch
@norabelrose
norabelrose / triton-covariance.py
Last active October 20, 2023 10:30
Compute covariance matrix in Triton
from itertools import product
import torch
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_N': n, 'BLOCK_D': d, 'GROUP_SIZE_D': 8}, num_stages=4, num_warps=4)
@norabelrose
norabelrose / cumulants.py
Last active October 19, 2023 03:56
Ryan Greenblatt's cumulant estimation code
from typing import Optional
import torch
def get_all_the_cumulants(
x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, w: torch.Tensor, weights_in: Optional[torch.Tensor] = None
):
if weights_in is not None:
weights = weights_in
weights = weights / weights.sum()
@norabelrose
norabelrose / cdf-erasure.py
Created October 3, 2023 07:55
Erasing CIFAR-10 classes with componentwise probability integral transform
from argparse import ArgumentParser
from itertools import pairwise
from pathlib import Path
from typing import Callable, Sized
import random
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics as tm
@norabelrose
norabelrose / classifier.py
Last active October 2, 2023 14:15
CUDA-enabled logistic regression with cross-validation (CV)
from dataclasses import dataclass, field
import torch
from torch import Tensor
from torch.nn.functional import (
binary_cross_entropy_with_logits as bce_with_logits,
)
from torch.nn.functional import (
cross_entropy,
)
@norabelrose
norabelrose / cifar-leace.py
Last active September 30, 2023 07:36
Messy CIFAR LEACE testing
from argparse import ArgumentParser
from typing import Any, Callable, Protocol, Sized, Type
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics as tm
import torchvision as tv
from concept_erasure import LeaceFitter, OracleFitter, QuadraticFitter
from pytorch_lightning.loggers import WandbLogger