Mamba GPT
# -*- coding: utf-8 -*-
"""SimplerMambaSSM.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1g9qpeVcFa0ca0cnhmqusO4RZtQdh9umY
"""

#!pip install mamba-ssm causal-conv1d
#!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
#!mkdir differentattention
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from tqdm import tqdm
from mamba_ssm import Mamba
import nltk
import pandas as pd

nltk.download("brown")
from nltk.corpus import brown

import random
from sklearn.model_selection import train_test_split
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"
# hyperparams
epochs = 1
lr = 1e-3
batch_size = 48
# 2048 chars @ batch 48 / ~7 chars per word (avg) = ~292 words of context (~30 min, ~15 GB VRAM)
block_size = 2048
stride = block_size // 2  # Example stride
# max_iters = 740
# max_iters = 10
print_iters = 100
eval_iters = 10
# eval_interval = 300
n_embed = 384
n_heads = 6
n_layers = 6
dropout = 0.2
# ---------
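# With these settings, each optimizer step processes batch_size * block_size =
# 48 * 2048 = 98,304 characters, and a single 2048-character window at ~7
# characters per word corresponds to roughly 292 words of context, which is the
# estimate in the comment above.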
# train and test splits
# Unique characters - updated to include BOS and EOS tokens
bos_token = "<BOS>"
eos_token = "<EOS>"

chars = sorted(
    list(
        set(
            "".join([" ".join(brown.words(fileid)) for fileid in brown.fileids()])
            + bos_token
            + eos_token
        )
    )
)
print("".join(chars))

vocab_size = len(chars)  # updated below once the BOS/EOS tokens are appended
print(vocab_size)
# Update the tokenizers to include BOS and EOS
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
# Append the special tokens after the character vocabulary so that no existing
# character index is overwritten
stoi[bos_token] = len(chars)      # Assign unique index for BOS
stoi[eos_token] = len(chars) + 1  # Assign unique index for EOS
itos[len(chars)] = bos_token
itos[len(chars) + 1] = eos_token
vocab_size = len(itos)  # account for the two special tokens

# Update the encode and decode functions
# Note: encode() works character by character, so the "<BOS>"/"<EOS>" strings in
# brown_text below are encoded as their individual characters rather than as the
# special indices assigned above.
encode = lambda xx: [stoi[x] for x in xx]
decode = lambda xx: "".join([itos[x] for x in xx])

# Concatenate documents from the Brown Corpus with BOS and EOS tokens
brown_text = "".join(
    [
        bos_token + " ".join(brown.words(fileid)) + eos_token
        for fileid in brown.fileids()
    ]
)

# Encode the Brown Corpus text
data = torch.tensor(encode(brown_text), dtype=torch.long)
# Batch sampling (the train/validation split itself is built further below)
def get_batch(split):
    # generate context and target batches from random offsets
    if split == "train":
        data = train_data
    else:
        data = val_data
    index = torch.randint(0, len(data) - block_size, (batch_size,))
    x = torch.stack([data[ind : ind + block_size] for ind in index])
    y = torch.stack([data[ind + 1 : ind + block_size + 1] for ind in index])
    return x.to(device), y.to(device)
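# get_batch returns x and y of shape (batch_size, block_size); y is x shifted
# one character to the right, so each position is trained to predict the next
# character.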
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ["train", "test"]:  # "test" maps to val_data inside get_batch
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
class SelfAttentionHead(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.keys = nn.Linear(n_embed, head_size)
        self.queries = nn.Linear(n_embed, head_size)
        self.values = nn.Linear(n_embed, head_size)
        self.head_size = head_size
        self.n_embed = n_embed
        self.register_buffer(
            "tril", torch.tril(torch.ones((block_size, block_size))).to(device)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.keys(x)  # (B,T,C_h)
        q = self.queries(x)  # (B,T,C_h)
        v = self.values(x)  # (B,T,C_h)
        wei = k @ q.transpose(-1, -2) * C ** (-0.5)  # (B,T,T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        # wei = F.softmax(wei, dim=-1)  # (B,T,T)
        # softplus in place of softmax: log(1 + exp(wei)) keeps the masked
        # weights at 0 and the rest positive, but does not normalize them to sum to 1
        wei = torch.log(torch.exp(wei) + 1)  # (B,T,T)
        wei = self.dropout(wei)
        out = wei @ v  # (B,T,C_h)
        return out
class LayerNorm(nn.Module):
    def __init__(self, dim) -> None:
        super().__init__()
        self.eps = 1e-5
        # params
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        xmean = x.mean(dim=1, keepdim=True)
        xvar = ((x - xmean) ** 2).mean(dim=1, keepdim=True)
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        self.out = self.gamma * xhat + self.beta
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, head_size) -> None:
        super().__init__()
        self.heads = nn.ModuleList(
            [SelfAttentionHead(head_size) for _ in range(n_heads)]
        )
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out
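# Note: the hand-rolled LayerNorm above and the SelfAttentionHead /
# MultiHeadAttention classes are kept from the attention baseline but are not
# exercised in this run: Block below swaps the attention sub-layer for Mamba and
# uses nn.LayerNorm, and the MultiHeadAttention instance created in
# BigramNeuralNetwork is never called in its forward().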
class FeedForward(nn.Module):
    def __init__(self, n_embed) -> None:
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.ffn(x)
class Block(nn.Module):
    def __init__(self, n_embed, n_heads) -> None:
        super().__init__()
        self.head_size = n_embed // n_heads
        # self.sa_head = MultiHeadAttention(n_heads, self.head_size)
        self.sa_head = Mamba(
            # This module uses roughly 3 * expand * d_model^2 parameters
            d_model=n_embed,  # Model dimension d_model
            d_state=16,  # SSM state expansion factor
            d_conv=4,  # Local convolution width
            expand=2,  # Block expansion factor
        ).to(device)
        self.ffn = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa_head(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
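# Each Block is a pre-norm residual block: the Mamba layer fills the slot that the
# multi-head attention sub-layer (the commented-out line above) would otherwise
# occupy, followed by the position-wise feed-forward, each wrapped as
# x = x + sublayer(ln(x)).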
class BigramNeuralNetwork(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.sa_head = MultiHeadAttention(4, int(n_embed / 4))  # unused in forward()
        self.lm_head = nn.Linear(n_embed, vocab_size)
        self.ffn = FeedForward(n_embed)  # unused in forward()
        self.blocks = nn.Sequential(
            *[Block(n_embed, n_heads=n_heads) for _ in range(n_layers)]
        )

    def forward(self, idx, targets=None):
        # idx = idx[:,-block_size:]
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B,T,C_e)
        pos_emb = self.position_embedding_table(
            torch.arange(T, device=device)
        )  # (T,C_e)
        x = tok_emb + pos_emb  # (B,T,C_e)
        x = self.blocks(x)  # (B,T,C_e)
        logits = self.lm_head(x)  # (B,T,vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
            logits = logits.view(B, T, C)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B,T); sample one character at a time, conditioning on at most
        # the last block_size characters
        for i in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            last_timestep = logits[:, -1, :]
            probs = F.softmax(last_timestep, dim=1)
            next_index = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_index), dim=1)
        for arr in idx:
            print(decode(arr.cpu().detach().numpy()))
        return idx
def chunk_data_with_stride(data, block_size, stride):
    # Create chunks using strides for overlapping sequences
    return [data[i : i + block_size] for i in range(0, len(data) - block_size, stride)]
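# With stride = block_size // 2 (1024), consecutive chunks overlap by half:
# chunk 0 covers characters [0, 2048), chunk 1 covers [1024, 3072), and so on.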
model = BigramNeuralNetwork(vocab_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
losses_data = {"train": [], "test": []}

# checkpoint = torch.load('model.pt')
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# epoch = checkpoint['epoch']
checkpoint_path = None  # "./differentattention/model_40.pt"
epoch = 0
if checkpoint_path:
    checkpoint = torch.load(checkpoint_path)
    print(checkpoint)
    if checkpoint["model_state_dict"]:
        # a state_dict has no .to(); the model itself is moved to device below
        model.load_state_dict(checkpoint["model_state_dict"])
    if checkpoint["optimizer_state_dict"]:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch = checkpoint["epoch"]

# device = "cuda"  # redundant: device is already selected near the top of the script
m = model.to(device)
print("Uses device " + device)

MODEL_CHECKPOINT = "./differentattention/model_{iter}.pt"
# Create strided sequences
strided_sequences = chunk_data_with_stride(data, block_size, stride)

# strided_sequences is a list of tensors; split it into train and validation sets
train_sequences, val_sequences = train_test_split(strided_sequences, train_size=0.9)

# Concatenate the tensors in each list to form a single tensor for train and validation
train_data = torch.cat(train_sequences, dim=0)
val_data = torch.cat(val_sequences, dim=0)

print("# strided sequences:", len(strided_sequences))
print("# train sequences:", len(train_sequences))
print("batch size:", batch_size)
print("epochs:", epochs)
print("iterations per epoch:", len(train_sequences) / batch_size)
print("total iterations:", (len(train_sequences) / batch_size) * epochs)

max_iters = int(np.round(len(train_sequences) / batch_size) * epochs)
for iter in tqdm(range(epoch, max_iters)):
    if iter % eval_iters == 0:
        losses = estimate_loss()
        losses_data["train"].append(losses["train"].cpu().numpy())
        losses_data["test"].append(losses["test"].cpu().numpy())
        print(
            f"Step {iter}, train loss:{losses['train']:.4f}, test loss:{losses['test']:.4f}"
        )

    if iter % print_iters == 0:
        losses = estimate_loss()
        torch.save(
            {
                "epoch": iter,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": losses,
            },
            MODEL_CHECKPOINT.format(iter=iter),
        )
        # (losses were already appended in the eval block above, so they are not
        # appended a second time here)
        model.eval()
        with torch.no_grad():
            # Generate from the model:
            output = m.generate(
                torch.zeros((1, 2), dtype=torch.long).to(device).contiguous(), 1000
            )[0].tolist()
            print(
                f"Step {iter}, train loss:{losses['train']:.4f}, test loss:{losses['test']:.4f}"
            )
        model.train()

    # Get data
    xb, yb = get_batch("train")

    # Evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
torch.save(model.state_dict(), "./differentattention/model.pt")

# Generate from the model:
output = m.generate(torch.zeros((1, 2), dtype=torch.long).to(device), 1000)[0].tolist()

#!pip3 install seaborn
# import seaborn as sns
# import matplotlib.pyplot as plt

losses_df = pd.DataFrame(losses_data)
losses_df = losses_df.applymap(lambda x: float(x))
print(losses_df)
# print(losses_df.head())

# Line graph of loss train and loss test
# ax = sns.lineplot(data=losses_df[losses_df['train'] < 10])
# ax.set(xlabel='Iteration (in 10s)', ylabel='Loss')
# plt.show()
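# Optional follow-up (a minimal sketch, not part of the original gist): reloading the
# weights saved above for later sampling. It assumes the classes, tokenizer, device,
# and "./differentattention/model.pt" path defined in this script.
# reloaded = BigramNeuralNetwork(vocab_size)
# reloaded.load_state_dict(torch.load("./differentattention/model.pt"))
# reloaded = reloaded.to(device)
# reloaded.eval()
# with torch.no_grad():
#     prompt = torch.zeros((1, 2), dtype=torch.long).to(device)
#     reloaded.generate(prompt, 500)  # generate() prints the decoded sample itself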
Gist:
https://gist.github.com/thistleknot/0d2bbced6264cd2ac508145797989638
Paper:
https://arxiv.org/abs/2312.00752
Reddit post where I discovered this magic sauce:
https://www.reddit.com/r/MachineLearning/comments/18d65bz/d_thoughts_on_mamba/
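For reference, the Mamba layer dropped into each Block above is a plain sequence-to-sequence module: input and output both have shape (batch, length, d_model). A minimal standalone sketch of that interface, following the usage shown in the mamba-ssm README (it requires a CUDA device):

import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
block = Mamba(
    d_model=dim,  # Model dimension d_model
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = block(x)
assert y.shape == x.shape  # Mamba preserves the (batch, length, d_model) shape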