Skip to content

Instantly share code, notes, and snippets.

View lucidrains's full-sized avatar

Phil Wang lucidrains

  • San Francisco
View GitHub Profile
@lucidrains
lucidrains / tree_attn_decode.py
Created August 12, 2024 17:48
Tree Attention Decoding
import torch
from torch import einsum
import torch.distributed as dist
def tree_attn_decode(q, k, v):
"""
Algorithm 3 proposed in Tree Attention
https://arxiv.org/abs/2408.04093
"""
@lucidrains
lucidrains / vit_with_mask.py
Created December 8, 2022 20:14
ViT, but you can pass in images with patches masked out
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
@lucidrains
lucidrains / uniprot_mapping.py
Created January 7, 2022 05:28
uniprot mapping for python3
import urllib
import urllib.parse
from urllib.request import urlopen
def uniprot_mapping(fromtype, totype, identifier):
base = 'http://www.uniprot.org'
tool = 'mapping'
params = {
'from': fromtype,
# Schedules take a normalized t in the range 0-1, e.g. call as lr_sch(step / steps)
def lr_sch(t):
    """Evaluate the learning-rate schedule at normalized progress ``t``.

    The result is ``10 ** e``, where the exponent ``e`` is a blend of two
    linear branches in ``t``. Each branch is divided by a logistic
    denominator centered at ``t = 0.015``, so ``20 * t - 5`` carries the
    weight for small ``t`` and ``-(1.45 * t + 2.08)`` for large ``t``.

    Args:
        t: Progress value; anything ``jnp.exp`` arithmetic accepts
            (a Python float or an array).

    Returns:
        The schedule value, computed with ``jax.numpy``.
    """
    left_br = 20 * t - 5
    right_br = -(1.45 * t + 2.08)

    def denom(sign):
        # Logistic denominator 1 + exp(-sign * 19 * (t - 0.015)).
        # sign=-1 grows with t (fades the left branch out);
        # sign=+1 shrinks towards 1 as t grows (fades the right branch in).
        return 1 + jnp.exp(-sign * (19 * (t - 0.015)))

    return 10 ** ((left_br / denom(-1)) + (right_br / denom(+1)))
def wd_sch(t):
    """Evaluate the ``wd`` schedule at normalized progress ``t``.

    Computes ``10 ** (-log(exp(10.7 * t - 2.7) + 1) - 2)``, i.e. ten raised
    to minus a softplus of a linear function of ``t``, shifted down by two
    decades. The value decreases monotonically as ``t`` grows.

    NOTE(review): ``wd`` presumably stands for weight decay -- confirm
    against the caller.

    Args:
        t: Progress value; a Python float or a NumPy array.

    Returns:
        The schedule value, computed with NumPy.
    """
    # softplus(x) = log(exp(x) + 1) with x = 10.7 * t - 2.7
    softplus = np.log(np.exp(10.7 * t - 2.7) + 1)
    return 10 ** (-softplus - 2)
@lucidrains
lucidrains / faster_rng.py
Created June 2, 2021 17:08
faster rng for jax
def hardware_uniform(rng_key: PRNGKey,
                     shape: Shape,
                     dtype: Dtype = np.float32,
                     minval: Array = np.float32(0),
                     maxval: Array = np.float32(1)) -> Array:
    """Draw uniform samples through ``lax.rng_uniform``, ignoring the key.

    Args:
        rng_key: Accepted for signature compatibility only; it is deleted
            immediately and never influences the output.
        shape: Shape of the array to generate.
        dtype: Element type the bounds are converted to before sampling.
        minval: Lower bound passed to ``lax.rng_uniform``.
        maxval: Upper bound passed to ``lax.rng_uniform``.

    Returns:
        The array produced by ``lax.rng_uniform(minval, maxval, shape)``.
    """
    del rng_key  # non-deterministic prng.
    # Bring both bounds to the requested dtype before handing them to lax.
    minval = lax.convert_element_type(minval, dtype)
    maxval = lax.convert_element_type(maxval, dtype)
    return lax.rng_uniform(minval, maxval, shape)
def split(arr: torch.Tensor, splits, dim=0):
    """Yield views of ``arr`` cut into near-equal chunks along ``dim``.

    The requested number of chunks is clamped to the range
    ``[1, arr.shape[dim]]``. When the axis length is not divisible by the
    chunk count, the leading chunks each receive one extra element so that
    chunk sizes differ by at most one.

    Args:
        arr: Tensor to split.
        splits: Desired number of chunks; values below 1 are treated as 1
            and values above the axis length are reduced to it.
        dim: Axis along which to split.

    Yields:
        ``torch.narrow`` views of ``arr``, in order along ``dim``. Nothing
        is yielded when the axis has length 0.
    """
    axis_len = arr.shape[dim]
    if axis_len == 0:
        # An empty axis has no chunks; returning here also avoids the
        # division by zero that a clamped chunk count of 0 would cause.
        return

    splits = min(axis_len, max(splits, 1))
    chunk_size, remainder = divmod(axis_len, splits)

    start = 0
    for index in range(splits):
        # The first `remainder` chunks absorb the leftover elements.
        size = chunk_size + (1 if index < remainder else 0)
        yield torch.narrow(arr, dim, start, size)
        start += size
import torch
import torch.nn.functional as F
from torch.optim import Adam
from einops import rearrange, repeat
import sidechainnet as scn
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
torch.set_default_dtype(torch.float64)
import torch
from torch import nn, einsum
from einops import rearrange, repeat
class FixedPositionalEmbedding(nn.Module):
    """Module holding a fixed table of inverse frequencies.

    The constructor precomputes ``1 / 10000 ** (i / dim)`` for every even
    ``i`` in ``[0, dim)`` and stores the result as the buffer ``inv_freq``.

    NOTE(review): only the constructor is visible here; how ``inv_freq`` is
    consumed is not shown.
    """

    def __init__(self, dim):
        """Build the inverse-frequency buffer.

        Args:
            dim: Embedding dimension; one frequency is generated per even
                index below ``dim``.
        """
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # Registered as a buffer, not a parameter, so it is part of the
        # module state without being trainable.
        self.register_buffer('inv_freq', inv_freq)
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
from einops import rearrange, repeat
import sidechainnet as scn
from en_transformer.en_transformer import EnTransformer
torch.set_default_dtype(torch.float64)
import torch
import torch.nn.functional as F
from torch.optim import Adam
from einops import rearrange, repeat
import sidechainnet as scn
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
torch.set_default_dtype(torch.float64)