Skip to content

Instantly share code, notes, and snippets.

@lucidrains
Last active July 8, 2021 15:16
Show Gist options
  • Save lucidrains/5c4151acf9fa73ab1ab93a8e0aa920f2 to your computer and use it in GitHub Desktop.
Save lucidrains/5c4151acf9fa73ab1ab93a8e0aa920f2 to your computer and use it in GitHub Desktop.
import torch
from torch import nn
from reformer_pytorch import ReformerLM
class ReformerClassifier(nn.Module):
    """Sequence classifier built on a ``ReformerLM`` encoder.

    The encoder is created with ``return_embeddings=True``; its output is
    summed over dimension 1 and passed through a linear layer that maps
    ``dim`` features to ``num_classes`` outputs.

    Args:
        num_classes: Number of output features of the classification head.
        num_tokens: Forwarded to ``ReformerLM(num_tokens=...)``.
        dim: Forwarded to ``ReformerLM(dim=...)`` and also used as the input
            width of the classification head, so the two always agree.
        depth: Forwarded to ``ReformerLM(depth=...)``.
        max_seq_len: Forwarded to ``ReformerLM(max_seq_len=...)``.
        lsh_dropout: Forwarded to ``ReformerLM(lsh_dropout=...)``.
        full_attn_thres: Forwarded to ``ReformerLM(full_attn_thres=...)``.

    All keyword-only parameters default to the values that were previously
    hard-coded, so ``ReformerClassifier(n)`` builds the same model as before.
    """

    def __init__(
        self,
        num_classes: int,
        *,
        num_tokens: int = 20000,
        dim: int = 1024,
        depth: int = 1,
        max_seq_len: int = 4096,
        lsh_dropout: float = 0.1,
        full_attn_thres: int = 1024,
    ) -> None:
        super().__init__()
        self.net = ReformerLM(
            num_tokens=num_tokens,
            dim=dim,
            depth=depth,
            max_seq_len=max_seq_len,
            lsh_dropout=lsh_dropout,
            full_attn_thres=full_attn_thres,
            return_embeddings=True,
        )
        # Derive the head width from ``dim`` instead of repeating the literal,
        # so changing the encoder width cannot leave the head mismatched.
        self.linear_out = nn.Linear(dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``, sum the result over dimension 1, and classify.

        Args:
            x: Input passed straight to the encoder.
                NOTE(review): presumably integer token ids shaped
                (batch, seq_len) -- confirm against ``ReformerLM``.

        Returns:
            The output of the linear head applied to the summed embeddings.
        """
        embs = self.net(x)
        # Sum pooling over dim 1 (presumably the sequence axis -- confirm).
        return self.linear_out(embs.sum(dim=1))
# Usage example: build a two-class classifier and push one random batch
# of token ids through it.
c = ReformerClassifier(num_classes=2)

# A single sequence of 4096 integer ids in the range [0, 20000).
x = torch.randint(low=0, high=20000, size=(1, 4096))

y = c(x)  # (1, 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment