mamba dynamic lr scheduler lion fff
#https://gist.githubusercontent.com/thistleknot/raw/mamba_trainer.py
#SimplerMambaSSM
#https://colab.research.google.com/drive/1g9qpeVcFa0ca0cnhmqusO4RZtQdh9umY#scrollTo=2lECw6S4N7cn
"""
NanoGPT initial base
Mamba

Implemented papers
* Symbolic Discovery of Optimization Algorithms
* Exponentially Faster Language Modeling

Custom algorithms
* Custom LR scheduler (warmup peak using left and right EMAs)
* 100% efficient batching

Attempted papers
* Simplifying Transformer Blocks (attempted, but all of its simplifications depend on attention, which would make the model a transformer block)
* Gaussian Adaptive Attention is All You Need: Robust Contextual Representations Across Multiple Modalities
* Agent Attention: On the Integration of Softmax and Linear Attention

Note: Mamba doesn't use attention, nor, it appears, can it.

To benchmark which LR multiple works best, run:
for m in 0.9 1 1.1; do python simplermambassm-fff-revised-v2-agent-attn-ga-halving.py --lr_multiple $m; done
"""
#!pip install mamba-ssm causal-conv1d

# Resources
# Token length thresholds:
# https://gist.github.com/thistleknot/1d25d5b65f3d7255ddc8ffbd3981a0bb
# Efficient batching:
# https://gist.github.com/thistleknot/97617f91538ad075b9a44437f88e8680
#!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# https://github.com/havenhq/mamba-chat/blob/main/trainer/mamba_trainer.py
# https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
# https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py
# https://huggingface.co/clibrain/mamba-2.8b-instruct-openhermes
from collections import defaultdict
from mamba_ssm import Mamba
from nltk.corpus import brown
from sklearn.model_selection import train_test_split
from torch.optim.optimizer import Optimizer
from tqdm import tqdm
from transformers import AutoTokenizer
import argparse
import json
import math
import numpy as np
import os
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb

# Note: the duplicate math/functional imports were removed, as was the
# mamba_ssm Block import; the local Block class defined below shadows it.

device = "cuda" if torch.cuda.is_available() else "cpu"
def clear_model(model):
    del model
    torch.cuda.empty_cache()
    torch.cuda.synchronize()


# Tokenize a list of strings and keep only sequences within the length bounds
def tokenize_and_filter(dataset, min_size, max_size):
    tokenized = tokenizer.batch_encode_plus(dataset)['input_ids']
    filtered = [tokens for tokens in tokenized if min_size <= len(tokens) <= max_size]
    return filtered


def rank_ecdf(values):
    n = len(values)
    sorted_indices = np.argsort(values)
    ranks = np.argsort(sorted_indices)
    return (ranks + 0.5) / n


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def encode(text):
    return tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").squeeze()


def decode(token_ids):
    return tokenizer.decode(token_ids, skip_special_tokens=True)
class Lion(Optimizer):
    r"""Implements the Lion algorithm."""

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
        """Initialize the hyperparameters.

        Args:
            params (iterable): iterable of parameters to optimize or dicts defining
                parameter groups
            lr (float, optional): learning rate (default: 1e-4)
            betas (Tuple[float, float], optional): coefficients used for the update
                direction and the momentum running average (default: (0.9, 0.99))
            weight_decay (float, optional): weight decay coefficient (default: 0)
        """
        if not 0.0 <= lr:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # Perform decoupled (stepwise) weight decay
                p.data.mul_(1 - group['lr'] * group['weight_decay'])
                grad = p.grad
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                exp_avg = state['exp_avg']
                beta1, beta2 = group['betas']
                # Weight update: sign of the beta1-interpolated momentum and gradient
                update = exp_avg * beta1 + grad * (1 - beta1)
                p.add_(update.sign_(), alpha=-group['lr'])
                # Decay the momentum running average coefficient
                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
        return loss
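# The update rule above follows the Lion paper ("Symbolic Discovery of
# Optimization Algorithms"): the step direction is sign(beta1*m + (1-beta1)*g),
# and the momentum m is refreshed afterwards with beta2. A tiny usage sketch
# on a hypothetical toy problem (not part of this training script):
#
#   w = torch.nn.Parameter(torch.randn(10))
#   opt = Lion([w], lr=1e-4, weight_decay=0.01)
#   loss = (w ** 2).sum()
#   loss.backward()
#   opt.step(); opt.zero_grad()
#
# Because the update is a pure sign vector scaled by lr, Lion typically wants
# a smaller lr (and larger weight decay) than AdamW.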
# Function to get current VRAM usage
def get_vram_usage():
    torch.cuda.synchronize()
    return torch.cuda.memory_allocated()


# Function to get batches; note block_size and batch_size are globals set below
def get_batch(data, args):
    index = torch.randint(0, len(data) - block_size, (batch_size,))
    x = torch.stack([data[ind: ind + block_size] for ind in index])
    y = torch.stack([data[ind + 1: ind + block_size + 1] for ind in index])
    return x.to(device), y.to(device)


# Function to estimate loss; uses the global model
@torch.inference_mode()
def estimate_loss(X, Y):
    model.eval()
    logits, loss = model(X, Y)
    perplexity = torch.exp(loss).item()
    model.train()
    return [loss.item(), perplexity]


def precalculate_lengths(tokenized_sequences):
    return [len(seq) for seq in tokenized_sequences]
# FFF-backed Mamba language model
class BigramNeuralNetwork(nn.Module):
    # depth is for FFF
    def __init__(self, vocab_size, hidden_neurons, n_heads=34, n_embed=768, depth=11, n_layers=21, dropout=0.2):
        super().__init__()
        self.hidden_neurons = hidden_neurons
        self.token_embedding_table = nn.Embedding(vocab_size, hidden_neurons[0])
        #self.position_embedding_table = nn.Embedding(block_size, hidden_neurons[0])
        self.lm_head = nn.Linear(n_embed, vocab_size)
        # Initialize the stack of Mamba + FFF blocks
        self.blocks = nn.Sequential(*[Block(n_embed=n_embed, n_heads=n_heads, depth=depth, input_width=None, dropout=dropout) for _ in range(n_layers)])

    # Handles variable-length inputs
    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B,T,C_e)
        #pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T,C_e)
        x = tok_emb  # + pos_emb  # (B,T,C_e)
        x = self.blocks(x)
        logits = self.lm_head(x)  # (B,T,vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
            logits = logits.view(B, T, C)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B,T)
        for i in range(max_new_tokens):
            # Crop the context to the last block_size tokens (the original
            # cropped to max_new_tokens, which shrinks the context whenever
            # max_new_tokens is smaller than T)
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            last_timestep = logits[:, -1, :]
            probs = F.softmax(last_timestep, dim=1)
            next_index = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_index), dim=1)
        return idx
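# Shape walkthrough for the model above (a hypothetical smoke test; it needs a
# CUDA device because each Block moves its Mamba mixer to "cuda"):
#
#   m = BigramNeuralNetwork(vocab_size=100, hidden_neurons=[64], n_heads=4,
#                           n_embed=64, depth=3, n_layers=2).to("cuda")
#   idx = torch.randint(0, 100, (2, 16), device="cuda")   # (B, T)
#   logits, loss = m(idx, idx)                            # logits: (B, T, 100)
#
# Note hidden_neurons[0] must equal n_embed: the embedding output feeds Mamba
# blocks built with d_model=n_embed, which is how this script calls it.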
# Mamba mixer + FFF feedforward block
class Block(nn.Module):
    def __init__(self, n_embed, n_heads, depth, embed_width=None, input_width=None, dropout=0.2):
        super().__init__()
        if embed_width is None:
            embed_width = n_embed
        if input_width is None:
            input_width = embed_width
        self.dropout = dropout
        self.head_size = embed_width // n_heads
        self.sa_head = Mamba(
            d_model=n_embed,
            d_state=16,
            d_conv=4,
            expand=2,
        ).to("cuda")
        self.num_heads = n_heads
        self.head_dim = int(input_width / n_heads)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)
        # The original first assigned self.fff = FFF(input_width, input_width - 16, depth);
        # that was dead code, immediately overwritten by this assignment.
        self.fff = FFF(input_width, n_embed, depth)

    def forward(self, x):
        x = x + self.sa_head(self.ln1(x))
        x = x + self.fff(self.ln2(x))  # FFF in place of a standard MLP
        return x
# "FFF" here is, in effect, a two-layer MLP with a (depth + 1)-wide GELU
# bottleneck; see the note below for how it differs from the paper's
# conditional tree.
class FFF(nn.Module):
    def __init__(self, input_width, output_width, depth=11):
        super().__init__()
        self.input_width = input_width
        self.output_width = output_width
        self.depth = depth
        self.linear_in = nn.Linear(input_width, depth + 1, bias=True)
        self.linear_out = nn.Linear(depth + 1, output_width, bias=False)

    def forward(self, oldx: torch.Tensor) -> torch.Tensor:
        x = oldx.reshape(-1, self.input_width)
        logits = self.linear_in(x)  # (batch_size * seq_length, depth + 1)
        activations = F.gelu(logits)
        new_logits = self.linear_out(activations)
        batch_size, seq_length, _ = oldx.shape
        return new_logits.view(batch_size, seq_length, -1)
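# Despite the name, the FFF above executes every neuron for every token; it is
# not the conditional tree of "Exponentially Faster Language Modeling". A
# minimal hard-routing sketch of the paper's idea follows (inference-time form,
# simplified to single-neuron leaves; every name here is an assumption, not the
# paper's reference code, and nothing in this script uses it):
class TreeFFF(nn.Module):
    def __init__(self, width, depth):
        super().__init__()
        self.depth = depth
        # one decision vector per internal node, one in/out vector per leaf
        self.node_w = nn.Parameter(torch.randn(2 ** depth - 1, width) / width ** 0.5)
        self.leaf_in = nn.Parameter(torch.randn(2 ** depth, width) / width ** 0.5)
        self.leaf_out = nn.Parameter(torch.randn(2 ** depth, width) / width ** 0.5)

    def forward(self, x):  # x: (batch, width)
        node = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
        for _ in range(self.depth):
            go_right = (torch.einsum('bw,bw->b', x, self.node_w[node]) > 0).long()
            node = 2 * node + 1 + go_right  # descend the balanced binary tree
        leaf = node - (2 ** self.depth - 1)
        h = F.gelu(torch.einsum('bw,bw->b', x, self.leaf_in[leaf]))
        return h.unsqueeze(-1) * self.leaf_out[leaf]
# Only `depth` dot products plus one leaf neuron fire per token, which is where
# the paper's speedup comes from; training uses a soft, differentiable mixture
# over leaves instead of the hard routing shown here.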
# Hand-rolled LayerNorm (unused; the blocks above use nn.LayerNorm)
class LayerNorm(nn.Module):
    def __init__(self, dim) -> None:
        super().__init__()
        self.eps = 1e-5
        # params
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        # Normalize over the feature (last) dimension; the original used dim=1,
        # which is only correct for 2-D inputs. The original also overrode
        # parameters(), which is unnecessary (nn.Module tracks nn.Parameter
        # attributes) and broke the base-class signature.
        xmean = x.mean(dim=-1, keepdim=True)
        xvar = ((x - xmean) ** 2).mean(dim=-1, keepdim=True)
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        return self.gamma * xhat + self.beta
def flatten_weights(weights):
    """
    Flatten the weights of a neural network layer into a single tensor.

    :param weights: A list of weight tensors from a neural network layer.
    :return: A single flattened tensor containing all weights.
    """
    return torch.cat([w.flatten() for w in weights])


def get_row_count(file_path):
    # Returns the total line count including the header: the header is skipped
    # while counting, then added back with the + 1
    with open(file_path) as file:
        next(file)  # skip the header line
        line_count = sum(1 for line in file)
    return line_count + 1
# Find a combination of record lengths that adds up exactly to the target sum
# (classic 0/1 subset-sum dynamic program)
def find_combination_to_sum(counts, target):
    values = []
    for val, count in counts.items():
        # Cap each length's multiplicity at the most copies that could fit
        values.extend([val] * min(count, target // val))
    # Initialize the DP table
    n = len(values)
    dp = [[False] * (target + 1) for _ in range(n + 1)]
    # Base case: target sum 0 is always achievable (by choosing nothing)
    for i in range(n + 1):
        dp[i][0] = True
    # Build the DP table
    for i in range(1, n + 1):
        for j in range(1, target + 1):
            dp[i][j] = dp[i - 1][j]
            if values[i - 1] <= j:
                dp[i][j] |= dp[i - 1][j - values[i - 1]]
    # Check if the target sum is possible
    if not dp[n][target]:
        return None
    # Trace back the solution
    result = []
    i, j = n, target
    while i > 0 and j > 0:
        if dp[i][j] != dp[i - 1][j]:
            result.append(values[i - 1])
            j -= values[i - 1]
        i -= 1
    return result
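# find_combination_to_sum is what backs the "100% efficient batching" claim:
# each packed row sums to exactly block_size tokens, so no padding is wasted.
# A hypothetical smoke test (values chosen purely for illustration):
#
#   find_combination_to_sum({3: 2, 5: 1, 7: 1}, 10)   # -> [7, 3]
#   find_combination_to_sum({4: 2}, 10)               # -> None (10 unreachable)
#
# Runtime is O(len(values) * target), so large block sizes combined with many
# short records make this DP the slow path of batch construction.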
def sample_and_remove(combination, records):
    # Group records by their length
    grouped_records = defaultdict(list)
    for record in records:
        grouped_records[len(record)].append(record)
    sampled_records = []
    if combination:
        for lens_size in combination:
            # Check if there are any records of this length left
            if grouped_records[lens_size]:
                # Sample one record of this length
                sample = random.sample(grouped_records[lens_size], 1)[0]
                # Add to sampled records
                sampled_records.append(sample)
                # Remove this record from the grouped records
                grouped_records[lens_size].remove(sample)
        # Flatten the grouped records back to a single list
        modified_records = [item for sublist in grouped_records.values() for item in sublist]
        return sampled_records, modified_records
    else:
        return [], records
def create_batches_v2(records, block_size, num_batches):
    samples = []
    modified_records = records.copy()
    for r in range(0, num_batches):
        sample, modified_records = retrieve_sample(modified_records, block_size, num_batches)
        if len(sample) == 0:
            return [], records
        else:
            samples.append(sample)
    if len(samples) < num_batches:
        return [], records
    else:
        return samples, modified_records


def retrieve_sample(records, block_size, num_batches):
    lens = [len(s) for s in records]
    # Count how many records exist at each length
    grouped = pd.DataFrame(lens, columns=['lens']).groupby('lens').size()
    # Convert to dictionary
    counts_dict = grouped.to_dict()
    combination = find_combination_to_sum(counts_dict, block_size)
    sample, records = sample_and_remove(combination, records)
    return sample, records
# Sample indices; fall back to sampling with replacement when the requested
# sample is larger than the dataset
def sample_indices(dataset_size, sample_size):
    try:
        return sorted(random.sample(range(dataset_size), sample_size))
    except ValueError:
        return sorted(random.choices(range(dataset_size), k=sample_size))
# Load and process the primary dataset (quotes_tokens)
quotes_filename = 'quotes_tokens.csv'
quotes_df = pd.read_csv(quotes_filename)
quotes_data = quotes_df.iloc[:, 0].apply(lambda x: json.loads(x)).tolist()
num_quotes = len(quotes_data)
print('num_quotes', num_quotes)

# Datasets and their proportions relative to the size of the quotes dataset
# (only used by the disabled sampling branch below)
datasets_proportions = {
    'idioms': 1,  # 1x the number of quotes
    'exs': 2,     # 2x the number of quotes
    'defs': 2,    # 2x the number of quotes
}

# Initialize tokenized_data with the quotes dataset
tokenized_data = quotes_data

# Process each additional dataset (idioms, exs, defs)
for dataset, proportion in datasets_proportions.items():
    filename = f"{dataset}_tokens.csv"
    df = pd.read_csv(filename)
    if False:  # proportional sampling path, currently disabled
        dataset_size = df.shape[0]
        sample_size = int(num_quotes * proportion)
        sampled_indices = sample_indices(dataset_size, sample_size)
        sampled_data = df.iloc[sampled_indices, 0].apply(lambda x: json.loads(x)).tolist()
        tokenized_data.extend(sampled_data)
    else:  # use every de-duplicated row instead
        nonsampled_data = df.drop_duplicates(subset=[df.columns[0]])
        print(dataset, len(nonsampled_data))
        # Apply the transformation to the entire column and convert to a list
        nonsampled_data = nonsampled_data.iloc[:, 0].apply(lambda x: json.loads(x)).tolist()
        tokenized_data.extend(nonsampled_data)

print('tokenized_data len after idioms/exs/defs', len(tokenized_data))

# Fold in the larger datasets (brown, wiki, essays, convos, dolly), filtering
# out overly long records
aggregate_datasets = ['brown_tokens', 'wiki_tokens', 'essays_tokens', 'convos_tokens', 'dolly_tokens']
aggregated = pd.DataFrame()  # unused leftover from an earlier Dask-based sampling path
for dataset in aggregate_datasets:
    filename = f"{dataset}.csv"
    print(filename)
    df = pd.read_csv(filename)
    df = df.drop_duplicates(subset=[df.columns[0]])
    nonsampled_data = df.iloc[:, 0].apply(lambda x: json.loads(x)).tolist()
    print(len(nonsampled_data))
    filtered_data = [t for t in nonsampled_data if len(t) <= 384 * 2]
    print(dataset, 'filtered', len(filtered_data))
    tokenized_data.extend(filtered_data)

# tokenized_data now contains quotes, idioms, exs, defs, and the filtered
# aggregate datasets
print('Total data length:', len(tokenized_data))
print('max tokenized_data len', np.max([len(t) for t in tokenized_data]))
target_tokens = 6144  # actual was 96 x 43 = 4128; chosen empirically so that (fixed param footprint) * block_size * batch_size stays within max VRAM
#if __name__ == "__main__":
parser = argparse.ArgumentParser()
#parser.add_argument("--model", type=str, default=None)
parser.add_argument("--tokenizer", type=str, default="HuggingFaceH4/zephyr-7b-beta")
parser.add_argument("--custom_tokenizer", type=bool, default=False)  # note: argparse type=bool treats any non-empty string as True
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--initial_lr", type=float, default=5e-05)
parser.add_argument("--peak_lr", type=float, default=0.0001265625)  # earlier values tried: 3.75e-05, 2.252541e-09
parser.add_argument("--lr_multiple", type=float, default=1)
parser.add_argument("--desired_lr", type=float, default=1.28725e-4)
parser.add_argument("--n_embed", type=int, default=1536)
parser.add_argument("--n_heads", type=int, default=21)
parser.add_argument("--n_layers", type=int, default=21)
parser.add_argument("--train_ratio", type=float, default=0.9)
parser.add_argument("--save_path", type=str, default="models")
parser.add_argument("--model_checkpoint", type=str, default=None)
parser.add_argument("--dropout", type=float, default=0.2)  # earlier value tried: 4.000000e-02
parser.add_argument("--weight_decay", type=float, default=0.15)  # earlier values tried: 0.1, 1.601807e-02
parser.add_argument("--patience", type=int, default=1)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--sample_size", type=float, default=1)
parser.add_argument("--dropout_multiple", type=float, default=1)
parser.add_argument("--weight_decay_multiple", type=float, default=1)
parser.add_argument("--grad_clip_multiple", type=float, default=1)
parser.add_argument("--eval_ratio", type=float, default=0.25)
parser.add_argument("--grad_steps", type=int, default=4)  # was type=float; it is used as a modulo divisor below
args, _ = parser.parse_known_args()
parser.add_argument("--depth", type=int, default=args.n_layers)
args = parser.parse_args()

# Apply modifications to args based on other args
args.peak_lr *= args.lr_multiple
args.dropout *= args.dropout_multiple
args.weight_decay *= args.weight_decay_multiple
args.grad_clip *= args.grad_clip_multiple
losses_data = {"train": [], "test": []}
best_model_path = "./best_model.pt"

# The original resume block referenced model/optimizer before they were created
# (and a nonexistent args.args attribute). Load the checkpoint dict here; its
# states are applied after the model and optimizer are built below.
checkpoint = None
if args.model_checkpoint:
    checkpoint = torch.load(args.model_checkpoint, map_location=device)
    print(args.model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token

# Get the vocabulary size
vocab_size = len(tokenizer)
print("Vocabulary Size:", vocab_size)

# Early stopping parameters
best_perplexity = float("inf")
best_iter = 0
evaluations_since_improvement = 0
# De-duplicate the combined dataset (the original comment said "train the
# tokenizer", but no tokenizer training happens here)
print(np.max([len(t) for t in tokenized_data]))
data_len = len(tokenized_data)
print(type(tokenized_data))
# Convert each item in tokenized_data to its string representation
stringified_data = [json.dumps(item, sort_keys=True) for item in tokenized_data]
# Use a set to find unique elements
unique_stringified_data = set(stringified_data)
# Convert the unique string representations back to JSON objects
combined_records = [json.loads(item) for item in unique_stringified_data]
print('len combined_records:', len(combined_records))
random.shuffle(combined_records)
unique_records = combined_records

# Actual training data
print('len unique_records:', len(unique_records))
uniques_filtered = unique_records
random.shuffle(uniques_filtered)
print('before sample', len(uniques_filtered))
sample_size_int = int(np.round(len(uniques_filtered) * args.sample_size))
uniques_filtered = uniques_filtered[0:sample_size_int]
print('after sample', len(uniques_filtered))
train_tokenized, val_tokenized = train_test_split(uniques_filtered, train_size=args.train_ratio)
train_lens = [len(t) for t in train_tokenized]
val_lens = [len(t) for t in val_tokenized]
max_len = np.max([*train_lens, *val_lens])
print("max_len:", max_len)
print("target_tokens:", target_tokens)
print("target_tokens/max_len:", target_tokens / max_len)
block_size = max_len

# Append EOS to every record that still fits within block_size. (The original
# repeated this step verbatim after the dry run below, which appended a second
# EOS to every record; the duplicate was dropped.)
train_tokenized = [record + [tokenizer.eos_token_id] for record in train_tokenized if len(record) + 1 <= block_size]
val_tokenized = [record + [tokenizer.eos_token_id] for record in val_tokenized if len(record) + 1 <= block_size]
print(block_size)
batch_size = int(np.round(target_tokens / max_len))

# Dry run of the batch packer
sampled_train, remainder = create_batches_v2(train_tokenized, block_size, batch_size)

print('total tokens', np.sum(train_lens))
# Print the values for debugging
print("block_size:", block_size)
print("batch_size:", batch_size)
print("Sum of train_lens:", np.sum(train_lens))
# Calculate the denominator
denominator = block_size * batch_size
print("Denominator (block_size * batch_size):", denominator)
# Check if the denominator is zero to avoid a division-by-zero error
if denominator == 0:
    print("Denominator is zero. Cannot divide by zero.")
    epoch_iters = 0  # or handle this case as appropriate for your program
else:
    # Calculate epoch_iters
    epoch_iters = int(np.round(np.sum(train_lens) / denominator))
    print("epoch_iters calculated:", epoch_iters)
print("epoch_iters", epoch_iters)

prune_iter = 0
prune_pcts = [.01, .02, .03, .05, .08, .21, .39, .55, .89]  # pruning leftovers, unused in this script
max_iters = epoch_iters * args.epochs
eval_iters = int(np.round(max_iters * args.eval_ratio))
print('eval_iters:', eval_iters)
print("max_iters", max_iters)
print_iters = int(np.round(max_iters / args.epochs))
print(len(sampled_train))
print(batch_size)
print(args.epochs)
print(len(sampled_train) / batch_size)
print((len(sampled_train) / batch_size) * args.epochs)
# 1 epoch
check_perplexity_iter = int(np.round(max_iters / args.epochs))
model = BigramNeuralNetwork(vocab_size, hidden_neurons=[args.n_embed], n_heads=args.n_heads, n_embed=args.n_embed, depth=args.depth, n_layers=args.n_layers, dropout=args.dropout)
model = model.to(device)
total_params = count_parameters(model)
print(f"Total number of parameters in the model: {total_params}")

available_train_tokenized = train_tokenized.copy()
available_val_tokenized = val_tokenized.copy()
print('len available_train_tokenized', len(available_train_tokenized))

# How many records to expose to the packer per iteration: roughly enough, on
# average, to fill one batch, tripled for slack
shuffle_threshold = int(np.round(denominator / np.mean([len(t) for t in available_train_tokenized])))
shuffle_threshold = shuffle_threshold * 3
print('shuffle_threshold', shuffle_threshold)

os.environ["WANDB_MODE"] = "offline"
wandb.init(project="Selective State Space Attention")

optimizer = Lion(model.parameters(), lr=args.peak_lr, weight_decay=args.weight_decay)
optimizer.zero_grad()

# Apply a resume checkpoint, if one was given, now that model and optimizer exist
if checkpoint is not None:
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch = checkpoint["epoch"]
for iter in tqdm(range(0, max_iters)):
    # Rotate a window of records through the batch packer; anything the packer
    # couldn't use goes back into the pool
    old_available_train_tokenized = available_train_tokenized.copy()
    train_subset = available_train_tokenized[0:shuffle_threshold]
    train_data, returned_train_tokenized = create_batches_v2(train_subset, block_size, batch_size)
    available_train_tokenized = old_available_train_tokenized[shuffle_threshold:] + returned_train_tokenized
    if len(train_data) == 0:
        # Pool exhausted: replenish it with the full training set and retry
        available_train_tokenized = train_tokenized + available_train_tokenized
        old_available_train_tokenized = available_train_tokenized.copy()
        train_subset = available_train_tokenized[0:shuffle_threshold]
        train_data, returned_train_tokenized = create_batches_v2(train_subset, block_size, batch_size)
        available_train_tokenized = old_available_train_tokenized[shuffle_threshold:] + returned_train_tokenized
    train_data_flat = [np.hstack(t) for t in train_data]
    train_padded = torch.cat([torch.tensor(t, dtype=torch.long) for t in train_data_flat], dim=0)
    x_tr, y_tr = get_batch(train_padded, args)

    # Eval loss (the original tested iter == max_iters, a value range() never reaches)
    if ((iter % eval_iters) == 0) or (iter == max_iters - 1):
        old_available_val_tokenized = available_val_tokenized.copy()
        valid_subset = available_val_tokenized[0:shuffle_threshold]
        val_data, returned_val_tokenized = create_batches_v2(valid_subset, block_size, batch_size)
        if len(val_data) == 0:
            available_val_tokenized = val_tokenized + available_val_tokenized
            old_available_val_tokenized = available_val_tokenized.copy()
            valid_subset = available_val_tokenized[0:shuffle_threshold]
            val_data, returned_val_tokenized = create_batches_v2(valid_subset, block_size, batch_size)
        available_val_tokenized = old_available_val_tokenized[shuffle_threshold:] + returned_val_tokenized
        val_data_flat = [np.hstack(t) for t in val_data]
        val_padded = torch.cat([torch.tensor(t, dtype=torch.long) for t in val_data_flat], dim=0)
        x_te, y_te = get_batch(val_padded, args)
        # Generate a sample from the model (a quarter of a block's worth of tokens)
        output = model.generate(
            torch.zeros((1, 2), dtype=torch.long).to(device),
            max_new_tokens=int(np.round(block_size / 4))
        )[0].tolist()
        # Split the output on the EOS token
        split_output = []
        temp = []
        for token_id in output:
            if token_id == tokenizer.eos_token_id:
                split_output.append(temp)
                temp = []
            else:
                temp.append(token_id)
        if temp:
            split_output.append(temp)
        # Decode each segment and print, one segment per paragraph
        for segment in split_output:
            print(decode(segment))
            print()
        # After generating, estimate the eval loss
        losses = estimate_loss(x_te, y_te)
        loss = losses[0]
        perplexity = losses[1]
        losses_data["test"].append(loss)
        print(f"Step {iter}, test loss:{loss:.4f}, perplexity:{perplexity:.4f}")
        eval_loss = loss
        wandb.log(
            {
                "iteration": iter,
                "eval_loss": eval_loss,
                "perplexity": perplexity,
                "evaluations_since_improvement": evaluations_since_improvement,
            })
        # Start checking for perplexity after a certain number of iterations
        #if iter >= check_perplexity_iter:
        save_path = "models"
        # Check for improvement and save the best model
        if perplexity < best_perplexity:
            print(f"Perplexity improved at iteration {iter}, prior best: {best_perplexity}, new best: {perplexity}")
            best_perplexity = perplexity
            best_iter = iter
            evaluations_since_improvement = 0
            MODEL_CHECKPOINT = f"./{save_path}/model_{best_iter}_{best_perplexity}.pt"
            # Save the best model (MODEL_CHECKPOINT is already fully formatted;
            # the original's extra .format() call was a no-op)
            torch.save(
                {
                    "epoch": best_iter,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": losses,
                },
                MODEL_CHECKPOINT,
            )
        else:
            print(f"{perplexity}, no improvement. Evaluations since improvement: {evaluations_since_improvement}")
            evaluations_since_improvement += 1
            # Early stopping check (intentionally disabled: patience is tracked
            # but never triggers a break)
            if evaluations_since_improvement >= args.patience:
                #print(f"Early stopping triggered at iteration {iter}, perplexity: {perplexity}")
                #break
                pass
    # Forward pass on the training batch (runs every iteration)
    logits, loss = model(x_tr, y_tr)
    # Backward pass and clipping in full precision. (The original also called
    # optimizer.zero_grad(set_to_none=True) here, before backward, which wiped
    # the accumulated gradients every iteration and silently disabled the
    # gradient accumulation below.)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
    # Step only every grad_steps iterations (gradient accumulation)
    if (iter % args.grad_steps == 0) and (iter > 0):
        optimizer.step()  # Update model parameters
        optimizer.zero_grad()  # Zero gradients after update
        # scheduler.step()  # Uncomment if you are using a learning rate scheduler
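    # Note: the loss is not scaled by 1 / grad_steps, so each step applies the
    # *sum* of the accumulated gradients rather than their mean. A common
    # variant (an alternative, not what this script does) scales the loss
    # before backward:
    #
    #   (loss / args.grad_steps).backward()
    #
    # which makes one accumulated step match a single batch of
    # batch_size * grad_steps sequences.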
    # Print the training loss (convert to float; it may be in half precision)
    print_loss = loss.item()
    losses_data["train"].append(print_loss)
    print(f"Step {iter}, train loss:{print_loss:.4f}")
    # Log metrics to wandb
    wandb.log({
        "iteration": iter,
        "train_loss": print_loss,
        "eval_loss": eval_loss,
        "evaluations_since_improvement": evaluations_since_improvement,
    })
# Store the best perplexity and the final parameter values
results = {
    'best_perplexity': best_perplexity,
    'final_peak_lr': args.peak_lr,
    'final_dropout': args.dropout,
    'final_weight_decay': args.weight_decay,
    'final_grad_clip': args.grad_clip,
    # Add other final parameter values if necessary
}
df = pd.DataFrame([results])

# Construct the filename from the rounded best perplexity
perplexity = results['best_perplexity']
rounded_perplexity = round(perplexity, 2)
filename = f"best_perplexity_{rounded_perplexity}.csv"
# Save the DataFrame to a CSV file
df.to_csv(filename, index=False)

# Finish wandb run
wandb.finish()

print("Training Losses:")
for i, loss in enumerate(losses_data["train"]):
    print(f"Iteration {i}: Train Loss = {loss}")

# Print testing losses (logged every eval_iters iterations)
print("\nTesting Losses:")
for i, loss in enumerate(losses_data["test"]):
    eval_iter = i * eval_iters
    print(f"Iteration {eval_iter}: Eval Loss = {loss}")

# Save the testing losses to CSV (the original multiplied by epoch_iters here,
# inconsistent with the eval_iters spacing used above)
df = pd.DataFrame({'Iteration': [i * eval_iters for i in range(len(losses_data["test"]))],
                   'Eval Loss': losses_data["test"]})
csv_file_path = 'testing_losses.csv'
df.to_csv(csv_file_path, index=False)
print(f"Testing losses have been saved to {csv_file_path}")
# Reload the best checkpoint and generate a final sample
save_path = "models"
MODEL_CHECKPOINT = f"./{save_path}/model_{best_iter}_{best_perplexity}.pt"
print('best', MODEL_CHECKPOINT)
# Load the checkpoint data and update the model state
checkpoint = torch.load(MODEL_CHECKPOINT)
model.load_state_dict(checkpoint['model_state_dict'])

# Generate from the model
output = model.generate(
    torch.zeros((1, 2), dtype=torch.long).to(device),
    max_new_tokens=block_size
)[0].tolist()
# Split the output on the EOS token
split_output = []
temp = []
for token_id in output:
    if token_id == tokenizer.eos_token_id:
        split_output.append(temp)
        temp = []
    else:
        temp.append(token_id)
if temp:
    split_output.append(temp)
# Decode each segment and print, one segment per paragraph
for segment in split_output:
    print(decode(segment))
    print()