Skip to content

Instantly share code, notes, and snippets.

View KeAWang's full-sized avatar

Alex Wang KeAWang

View GitHub Profile
@KeAWang
KeAWang / UCI Dataset
Created December 9, 2021 01:04
Code for loading UCI datasets used in the Deep Kernel Learning paper
from typing import Union, List
from pathlib import Path
import torch
from torch.utils.data import Dataset
from scipy.io import loadmat
import itertools
import os
def znormalize(train_x, train_y, test_x, test_y):
@KeAWang
KeAWang / tcn.py
Last active October 23, 2023 23:52
Temporal Convolutional Network in PyTorch (https://arxiv.org/abs/1803.01271)
import torch
from typing import List
import torch.nn.functional as F
def receptive_field(kernel_size: int, dilation: int):
    """Return the number of input positions covered by one dilated kernel.

    A kernel with ``kernel_size`` taps spaced ``dilation`` apart spans
    ``kernel_size - 1`` gaps of width ``dilation``, plus the first tap.

    Args:
        kernel_size: number of taps in the kernel.
        dilation: spacing between consecutive taps.

    Returns:
        ``1 + (kernel_size - 1) * dilation``.
    """
    gap_count = kernel_size - 1
    covered_span = gap_count * dilation
    return 1 + covered_span
class Seq2SeqConv1d(torch.nn.Module):
@KeAWang
KeAWang / tree_stack.py
Last active June 5, 2023 16:05 — forked from willwhitney/tree_stack.py
Utilities for stacking and unstacking JAX pytrees for use with vmap
import jax
import jax.numpy as jnp
def tree_stack(trees):
"""Takes a list of trees and stacks every corresponding leaf.
For example, given two trees ((a, b), c) and ((a', b'), c'), returns
((stack(a, a'), stack(b, b')), stack(c, c')).
Useful for turning a list of objects into something you can feed to a
vmapped function.
@KeAWang
KeAWang / ts_dataset.py
Last active September 2, 2022 05:53
PyTorch time-series Dataset and rolling-window Dataset
import torch
from torch.utils.data import Dataset
class TimeSeriesDataset(Dataset):
def __init__(
self,
ts: torch.Tensor,
x_ts: torch.Tensor,
normalize=True,
@KeAWang
KeAWang / mlp.py
Created September 8, 2022 20:47
PyTorch MLP
from collections import OrderedDict
import torch
from torch import Tensor, Size
from torch.nn import Linear
class MLP(torch.nn.Sequential):
"""Multi-layered perception, i.e. fully-connected neural network
Args:
@KeAWang
KeAWang / wandb_utils.py
Created October 1, 2022 03:31
W&B utils
import os
import pickle
from pathlib import Path
from wandb.errors import CommError
import wandb
def get_history(user="", project="", query={}, **kwargs):
api = wandb.Api()
@KeAWang
KeAWang / data_namedtuple.py
Created November 11, 2022 20:41
PyTorch data containers inheriting from namedtuple
from collections import namedtuple
class Data(namedtuple("Data", ("x", "y"))):
def to(self, device, non_blocking=False):
x = self.x.to(device, non_blocking=non_blocking)
y = self.y.to(device, non_blocking=non_blocking)
return Data(x, y)
def contiguous(self):
@KeAWang
KeAWang / parameterized_array.py
Last active April 10, 2023 18:48
Parameterized Arrays in JAX
# %%
import jax.numpy as jnp
import jax
import equinox as eqx
from typing import Union, Any
from abc import ABC, abstractmethod
# Type alias for a value that is either a plain JAX array or a
# ParameterizedArray. The latter is given as a string forward reference,
# presumably because that class is defined further down in this file — confirm.
MaybeParameterizedArray = Union[jax.Array, "ParameterizedArray"]
@KeAWang
KeAWang / array_to_dataframe.py
Created May 16, 2023 22:57
Multidimensional array to pandas DataFrame
import pandas as pd
from typing import Optional, List
def array_to_dataframe(array, axis_names: Optional[List[str]]=None):
"""Based on https://stackoverflow.com/questions/35525028/how-to-transform-a-3d-arrays-into-a-dataframe-in-python"""
if axis_names is None:
axis_names = list(range(array.ndim))
index = pd.MultiIndex.from_product([range(s) for s in array.shape], names=names)
df = pd.DataFrame({"array": array.flatten()}, index=index)["array"]
@KeAWang
KeAWang / dexcom_palette.py
Created October 23, 2023 03:46
Dexcom Clarity Color Palette
import numpy as np
# Dexcom Clarity colour palette: one hex colour per range category,
# listed in order from the lowest category to the highest.
tir_palette = dict(
    very_low="#A61D2A",
    low="#EE1D23",
    in_range="#26B257",
    high="#FAAB1A",
    very_high="#F47D21",
)
def color_bg(bgs):
bins = [0, 54, 70, 180, 250, 1000]