Birch-san

@Birch-san
Birch-san / bench_repro.py
Created November 17, 2024 18:09
Enabling --count-flops-early (running the model under FlopCounterMode before benchmarking it) regresses the performance of the compiled model
import argparse
import math
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
import torch
from einops import rearrange
from torch import (
    BoolTensor,
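For reference, a minimal sketch (my own placeholder model and sizes, not the gist's benchmark harness) of the scenario the description refers to: run the eager model under FlopCounterMode first, then time the compiled model. A CUDA device is assumed.

import time
import torch
from torch.utils.flop_counter import FlopCounterMode

model = torch.nn.Linear(4096, 4096, device='cuda', dtype=torch.bfloat16)  # placeholder model
x = torch.randn(8, 4096, device='cuda', dtype=torch.bfloat16)

# --count-flops-early: run the (eager) model under FlopCounterMode before benchmarking
flop_counter = FlopCounterMode(display=False)
with flop_counter:
    model(x)
print(f'total FLOPs: {flop_counter.get_total_flops()}')

compiled = torch.compile(model)
compiled(x)  # warmup: trigger compilation before timing
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
    compiled(x)
torch.cuda.synchronize()
print(f'compiled: {(time.perf_counter() - start) / 100 * 1e3:.3f} ms/iter')
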
@Birch-san
Birch-san / t5_enc_attn_bench.py
Last active October 23, 2024 20:05
Benchmark various ways of doing T5 Encoder flex_attention against SDPA
from enum import Enum
from typing import Callable, Optional, Any
from einops import rearrange
from dataclasses import dataclass
import math
import torch
from torch import FloatTensor, LongTensor, IntTensor, BoolTensor, ByteTensor, no_grad, inference_mode
from torch.nn import Embedding, Linear, Module
from torch.nn.attention.flex_attention import BlockMask, flex_attention, create_block_mask, _score_mod_signature, _mask_mod_signature
from torch.nn.functional import scaled_dot_product_attention
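Not the gist's code: a minimal sketch, under my own assumptions about shapes and masking, of the comparison being benchmarked, i.e. T5-style encoder attention (bidirectional, unscaled) over a key-padding mask, computed via SDPA with a boolean attn_mask vs flex_attention with a BlockMask.

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention

B, H, L, D = 4, 8, 512, 64
dev = 'cuda'
q, k, v = (torch.randn(B, H, L, D, device=dev, dtype=torch.bfloat16) for _ in range(3))
lengths = torch.randint(1, L + 1, (B,), device=dev)  # valid tokens per sequence; the rest is padding

# SDPA: boolean mask (True = attend), broadcast over heads and query positions
bool_mask = (torch.arange(L, device=dev)[None, :] < lengths[:, None])[:, None, None, :]
sdpa_out = scaled_dot_product_attention(q, k, v, attn_mask=bool_mask, scale=1.0)  # T5 skips 1/sqrt(d) scaling

# flex_attention: the same key-padding mask expressed as a mask_mod / BlockMask
def padding_mask_mod(b, h, q_idx, kv_idx):
    return kv_idx < lengths[b]

block_mask = create_block_mask(padding_mask_mod, B, None, L, L, device=dev)
flex = torch.compile(flex_attention)  # flex_attention is intended to be compiled
flex_out = flex(q, k, v, block_mask=block_mask, scale=1.0)
# for non-padding query positions, flex_out should match sdpa_out to within bf16 tolerance

The gist additionally imports _score_mod_signature, presumably for T5's relative position bias; this sketch omits that part.
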
@Birch-san
Birch-san / segfault.txt
Created August 21, 2024 22:13
stable-fast torch.jit.trace segfault
Caught signal 11 (Segmentation fault: address not mapped to object at address 0x20)
==== backtrace (tid: 63632) ====
0 0x0000000000042520 __sigaction() ???:0
1 0x0000000006e9fe76 torch::jit::InterpreterStateImpl::callstack() interpreter.cpp:0
2 0x0000000006ea0172 torch::jit::InterpreterStateImpl::handleError() interpreter.cpp:0
3 0x0000000006eac9fb torch::jit::InterpreterStateImpl::runTemplate<false>() interpreter.cpp:0
4 0x0000000006eb0585 torch::jit::InterpreterStateImpl::run() interpreter.cpp:0
5 0x0000000006e897b3 torch::jit::GraphExecutorImplBase::run() graph_executor.cpp:0
6 0x0000000000d3d859 torch::jit::runAndInsertCall() python_custom_class.cpp:0
7 0x0000000000e4208b torch::jit::invokeScriptMethodFromPython() script_init.cpp:0
@Birch-san
Birch-san / model_watch.py
Last active May 21, 2024 14:36
Watch your activation norms fly into the sunset
# Contains MIT-licensed code from wandb
# https://github.com/wandb/wandb/blob/main/LICENSE
# This gist is MIT-licensed (Copyright Alex Birch)
from torch import Tensor, FloatTensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
import torch
from typing import List, Callable, Dict, Sequence, Optional, Tuple, Any
from wandb.wandb_torch import log_track_init, log_track_update
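Not the gist's implementation (which builds on wandb's log_track helpers): a minimal sketch of the same idea, using plain forward hooks to log each module's activation norm to wandb.

import wandb
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from typing import List

def watch_activation_norms(model: Module) -> List[RemovableHandle]:
    # register a forward hook on every submodule; in practice you'd throttle
    # logging to every N calls rather than logging every forward pass
    handles: List[RemovableHandle] = []

    def make_hook(name: str):
        def hook(module: Module, args, output):
            if isinstance(output, Tensor):
                wandb.log({f'act_norm/{name}': output.norm().item()}, commit=False)
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            handles.append(module.register_forward_hook(make_hook(name)))
    return handles  # call .remove() on each handle to stop watching
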
@Birch-san
Birch-san / .gitconfig
Created April 2, 2024 15:10
Using a fine-grained access token to access your organisation's private GitHub repositories
[url "https://oauth2:github_pat_REDACTED@github.com/"]
insteadOf = https://github.com/
[url "https://oauth2:github_pat_REDACTED@github.com/MYCOOLORG/"]
insteadOf = git@github.com:MYCOOLORG/
@Birch-san
Birch-san / img-folder-chunking.md
Last active April 24, 2024 12:07
Chunking a folder of pngs into .tar files

Uploading a folder of many files to HF by chunking it into .tars

So you generated 50000 images for computing FID or whatever, and now you want to upload those samples to HF.
You try, but one of the file transfers fails, and you lose all your progress.
I mean it'd be nice if HF could just… fix this… like, put retries into huggingface-cli upload instead of just discarding tens of gigabytes of progress… but we live in the world in which we live.

So let's make it easier. Instead of 50k small files, let's upload 50 big files. Collate 'em into .tars.

I'm not sure this makes a valid WDS, but it's close; I think you would need to rename the files to 000000.img.png if you wanted that.
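Under my own assumptions (a flat folder of .png samples, 1000 files per shard), the collation step could look like this minimal Python sketch; the actual commands may differ:

import tarfile
from pathlib import Path

src = Path('samples')                    # hypothetical folder of generated .png files
out = Path('shards'); out.mkdir(exist_ok=True)
files = sorted(src.glob('*.png'))
chunk_size = 1000                        # 50k files -> 50 tars

for start in range(0, len(files), chunk_size):
    shard_path = out / f'{start // chunk_size:05d}.tar'
    with tarfile.open(shard_path, 'w') as tar:
        for f in files[start:start + chunk_size]:
            tar.add(str(f), arcname=f.name)  # arcname keeps the archive flat (no 'samples/' prefix)

The resulting shards/ folder can then be uploaded with huggingface-cli upload, and if a transfer fails you only retry the affected .tars.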

@Birch-san
Birch-san / installing-python-proxy.md
Last active January 29, 2024 10:36
Installing Python when behind a corporate proxy

Behind a corporate proxy? Can't add PPAs to your apt listings?

A typical HTTP proxy URL may look like:
http://proxy.mycoolproxy.com:8080

Let's configure all our tools to use this proxy.

apt

sudo nano /etc/apt/apt.conf.d/00proxy.conf
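That file would hold the standard apt proxy directives; assuming the example proxy URL above, something like:

Acquire::http::Proxy "http://proxy.mycoolproxy.com:8080/";
Acquire::https::Proxy "http://proxy.mycoolproxy.com:8080/";
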
@Birch-san
Birch-san / install-bpftrace-on-wsl2.md
Last active January 29, 2024 10:02
Installing bpftrace on WSL2
# from Windows (PowerShell / cmd):
wsl --update --web-download
wsl --install -d Ubuntu-22.04 --web-download
wsl --setdefault Ubuntu-22.04
# then, inside the Ubuntu 22.04 guest:
sudo apt-get install -y bpftrace bpftrace-dbgsym linux-headers-generic libc6-dev
@Birch-san
Birch-san / 8bit_adam_memory_usage.md
Last active October 3, 2023 18:20
Unexplained memory usage of 8-bit AdamW (paged vs unpaged)

Some weird memory usage (VRAM) is reported (by torch and by NVML) when using 8-bit AdamW, paged or unpaged.

Here we train llama 2 on 4096-token sequences, using either --optim adamw_8bit or --optim paged_adamw_8bit.
We do a full finetune using qlora.py --full-finetune, with our qlora.py fork, stepwise branch, commit 9a1045d.
We print the memory usage using HF transformers trainer's on_step_end callback. This is after optimizer.step(); model.zero_grad().

One would expect the memory usage at the end of step 1 to be the same as at the end of step 2.
Yet for the unpaged optimizer, memory usage leaps by 13.2GiB: end of step 1 = 70.4GiB, end of step 2 = 81.6GiB.
This appears to be a leap in PyTorch reserved memory only (32.6GiB -> 43.9GiB).
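Roughly how such a measurement can be taken (a sketch, not the fork's actual callback): log torch-allocated, torch-reserved, and NVML-reported memory at the end of every step.

import torch
import pynvml
from transformers import TrainerCallback

class MemoryUsageCallback(TrainerCallback):
    def __init__(self, device_index: int = 0):
        pynvml.nvmlInit()
        self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)

    def on_step_end(self, args, state, control, **kwargs):
        gib = 1024 ** 3
        nvml_used = pynvml.nvmlDeviceGetMemoryInfo(self.handle).used
        print(
            f"step {state.global_step}: "
            f"allocated={torch.cuda.memory_allocated() / gib:.1f}GiB "
            f"reserved={torch.cuda.memory_reserved() / gib:.1f}GiB "
            f"nvml={nvml_used / gib:.1f}GiB"
        )

Pass an instance via Trainer(callbacks=[MemoryUsageCallback()]) to get one line per optimizer step.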

@Birch-san
Birch-san / t5-small-weight-inits.py
Created October 1, 2023 15:04
google/t5-v1_1-small t5-small weight initializations
import torch
from transformers import T5ForConditionalGeneration
model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained('google/t5-v1_1-small')
# enter inference mode for the whole session (this is what torch.inference_mode() wraps)
_inference_mode_context = torch._C._InferenceMode(True)
_inference_mode_context.__enter__()
# std of the shared token-embedding weights; the line below is the REPL output
model.shared.weight.std()
tensor(11.6375)