Skip to content

Instantly share code, notes, and snippets.

View KeAWang's full-sized avatar

Alex Wang KeAWang

View GitHub Profile
@KeAWang
KeAWang / mfu_compute.py
Created April 11, 2024 17:17 — forked from Chillee/mfu_compute.py
Compute Flop Utilization in PyTorch
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
def get_flops_achieved(f):
flop_counter = FlopCounterMode(display=False)
with flop_counter:
f()
total_flops = flop_counter.get_total_flops()
ms_per_iter = do_bench(f)
@KeAWang
KeAWang / tcn_experiment.py
Last active November 20, 2023 17:37
TCN experiment with correct residual connection
# %%
import torch
import numpy as np
def make_adding_dataset(num_seqs, seq_len, num_terms=2, seed=43141):
assert 0 <= num_terms <= seq_len
rng = np.random.default_rng(seed=seed)
numbers = rng.uniform(0, 1, (num_seqs, seq_len)) # B x T
mask = np.zeros_like(numbers) # B x T
non_zero = np.stack([rng.choice(seq_len, num_terms, replace=True) for _ in range(num_seqs)]) # B x 2
mask[np.arange(num_seqs)[:, None], non_zero] = 1 # mask[i, non_zero[i, j]]
@KeAWang
KeAWang / nan_embedder.py
Created November 14, 2023 18:42
PyTorch NaN embedder
import torch
class NanWrapper(torch.nn.Module):
"""Wrapper module around a torch Module that handles incoming nans"""
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, x):
""" Masks the entire last dimension (usually the feature/channel dimension) if any element is NaN. """
@KeAWang
KeAWang / count_torch_params.py
Last active November 14, 2023 18:35
Count the number of PyTorch parameters
import torch
def count_params(model: torch.nn.Module, *, trainable_only: bool = False) -> int:
    """Count the parameters (scalar elements) of a PyTorch model.

    By default every parameter yielded by ``model.parameters()`` is counted,
    whether or not it requires gradients. Pass ``trainable_only=True`` to
    count only parameters whose ``requires_grad`` flag is set.

    Args:
        model: The module whose parameters are counted.
        trainable_only: If True, skip parameters that do not require grad.

    Returns:
        Total number of scalar elements across the selected parameters
        (0 for a module with no parameters).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        # Keep every parameter unless the caller asked for trainable ones only.
        if p.requires_grad or not trainable_only
    )
@KeAWang
KeAWang / constrain.py
Created November 14, 2023 18:32
Constrain and unconstrain
import torch
def constrain(x, min, max, temperature: float = 1.):
    """Map an unconstrained value into the range between ``min`` and ``max``.

    Uses a scaled, shifted sigmoid:
    ``(max - min) * sigmoid(x / temperature) + min``.

    Args:
        x: Unconstrained input tensor.
        min: Lower bound of the output range.
        max: Upper bound of the output range.
        temperature: Divisor applied to ``x`` before the sigmoid; larger
            values make the transition between the bounds more gradual.

    Returns:
        The squashed tensor.
    """
    # NOTE: ``min``/``max`` shadow the builtins, but they are part of the
    # public signature (callers may pass them by keyword), so they stay.
    width = max - min
    squashed = torch.sigmoid(x / temperature)
    return width * squashed + min
def unconstrain(y, min, max, temperature:float=1, EPS:float=1e-8):
assert torch.all(y >= min) and torch.all(y <= max)
# ensure both numerator and denominator are positive
numerator = y - min
@KeAWang
KeAWang / dexcom_palette.py
Created October 23, 2023 03:46
Dexcom Clarity Color Palette
import numpy as np
# Hex colors for the time-in-range (TIR) categories, as in the Dexcom Clarity
# palette; keys run from the lowest category to the highest.
tir_palette = dict(
    very_low="#A61D2A",
    low="#EE1D23",
    in_range="#26B257",
    high="#FAAB1A",
    very_high="#F47D21",
)
def color_bg(bgs):
bins = [0, 54, 70, 180, 250, 1000]
@KeAWang
KeAWang / array_to_dataframe.py
Created May 16, 2023 22:57
Multidimensional array to pandas dataframe
import pandas as pd
from typing import Optional, List
def array_to_dataframe(array, axis_names: Optional[List[str]]=None):
"""Based on https://stackoverflow.com/questions/35525028/how-to-transform-a-3d-arrays-into-a-dataframe-in-python"""
if axis_names is None:
axis_names = list(range(array.ndim))
index = pd.MultiIndex.from_product([range(s) for s in array.shape], names=names)
df = pd.DataFrame({"array": array.flatten()}, index=index)["array"]
@KeAWang
KeAWang / parameterized_array.py
Last active April 10, 2023 18:48
Parameterized Arrays in JAX
# %%
import jax.numpy as jnp
import jax
import equinox as eqx
from typing import Union, Any
from abc import ABC, abstractmethod
# Type alias for values that may be either a plain JAX array or a
# ParameterizedArray. The latter is a string forward reference; presumably the
# class is defined further down in this gist (not visible here) — confirm.
MaybeParameterizedArray = Union[jax.Array, "ParameterizedArray"]
@KeAWang
KeAWang / data_namedtuple.py
Created November 11, 2022 20:41
PyTorch data classes inheriting from namedtuples
from collections import namedtuple
class Data(namedtuple("Data", ("x", "y"))):
def to(self, device, non_blocking=False):
x = self.x.to(device, non_blocking=non_blocking)
y = self.y.to(device, non_blocking=non_blocking)
return Data(x, y)
def contiguous(self):
@KeAWang
KeAWang / wandb_utils.py
Created October 1, 2022 03:31
W&B utils
import os
import pickle
from pathlib import Path
from wandb.errors import CommError
import wandb
def get_history(user="", project="", query={}, **kwargs):
api = wandb.Api()