Skip to content

Instantly share code, notes, and snippets.

@dvruette
dvruette / ada_numpy.py
Last active June 11, 2024 10:49
ada_numpy.py
from contextlib import contextmanager
@contextmanager
def set_numpy(name):
    """Yield a NumPy-compatible array module selected by name.

    Args:
        name: Backend identifier. ``"numpy"`` yields :mod:`numpy` and
            ``"jax"`` yields :mod:`jax.numpy`.

    Yields:
        The imported array module for the requested backend.

    Raises:
        ValueError: If ``name`` is not one of the supported backends.
    """
    # Imports are deferred so only the requested backend has to be installed.
    if name == "numpy":
        import numpy as np
        yield np
    elif name == "jax":
        import jax.numpy as jnp
        yield jnp
    else:
        # Without this branch the generator would finish without yielding and
        # contextlib would raise an unhelpful "generator didn't yield" error.
        raise ValueError(
            f"unsupported numpy backend {name!r}; expected 'numpy' or 'jax'"
        )
@dvruette
dvruette / grid.diff
Last active January 18, 2024 22:12
Faster Grid
diff --git a/grid.py b/grid.py
index f9f1557..8eafb91 100755
--- a/grid.py
+++ b/grid.py
@@ -5,8 +5,10 @@
# Written by Francois Fleuret <francois@fleuret.org>
-import math
-import torch, torchvision
@dvruette
dvruette / decay_to_init.py
Created January 18, 2024 08:54
Weight decay to model initialization
import copy
import torch
import torch.nn as nn
class DecayToInit(nn.Module):
    """Module that keeps a reference copy of a tensor as a registered buffer.

    The constructor stores ``param`` with ``register_buffer`` under the name
    ``"param"``, so it is module state rather than a trainable parameter.

    NOTE(review): the gist title suggests this is used to decay weights toward
    their initial values; the method that does so is not visible here.
    """

    def __init__(self, param: torch.Tensor):
        super().__init__()
        buffer_name = "param"
        self.register_buffer(buffer_name, param)
@dvruette
dvruette / adam_on_lion.py
Last active March 16, 2023 15:01
Adam#Lion
"""
Based on: https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py
"""
from typing import Tuple, Optional, Callable
import torch
from torch.optim.optimizer import Optimizer