train.py
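# Requirements (assumed pip package names): MEGABYTE-pytorch, datasets, numpy, torch, tqdm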
from MEGABYTE_pytorch import MEGABYTE
import datetime
import time
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
print(torch.__version__)

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
PRIME_LEN = 100
SEQ_LEN = 8192
# helpers
def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))
# instantiate GPT-like decoder model
model = MEGABYTE(
    num_tokens = 256,
    dim = (768, 512, 256),
    depth = (6, 4, 2),
    max_seq_len = (512, 4, 4),
    flash_attn = False
).cuda()
# prepare German Wikipedia data
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len
#dataset = load_dataset("wikipedia", "20220301.de", split="train[1000:20000]")
#texts = dataset['text']
# Load the German Wikipedia dataset
dataset = load_dataset("wikipedia", "20220301.de") #28GB RAM Usage
# Convert the text to byte values (one byte per character, so every value fits the 256-token vocabulary)
texts = dataset['train']['text'] # Extract the article texts
x = np.array([], dtype=np.uint8) # Initialize the byte array
for text in texts:
    byte_text = np.frombuffer(text.encode('latin-1', errors='replace'), dtype=np.uint8) # Characters outside Latin-1 become '?'
    x = np.concatenate((x, byte_text)) # Append the bytes to the array
    if len(x) >= int(95e6): # Once the array holds 95 million bytes, stop
        x = x[:int(95e6)] # Truncate the array to exactly 95 million bytes
        break
# Split the data into training and validation sets
train_x, valid_x = np.split(x, [int(0.9 * len(x))])
data_train, data_val = map(torch.from_numpy, (train_x, valid_x))
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
# Convert the first 8000 bytes back into a string
#train_text = "".join(chr(c) for c in data_train[:8000].tolist())
#print(train_text)
#input("Press enter to continue execution...")
# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# training
start_time = datetime.datetime.now()
print(f"Start time: {start_time.strftime('%H:%M:%S')}")
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()

        # Save a checkpoint before each generation step
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            'loss': loss,
        }, 'megabyte-wikide.pt')

        inp = random.choice(val_dataset)[:-1]
        prime_inp = inp[:PRIME_LEN]
        prime = decode_tokens(prime_inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(prime_inp[None, :])
        sample = sample.flatten(1)
        output_str = decode_tokens(sample[0][PRIME_LEN:])
        print(output_str)
torch.save(model.state_dict(), 'megabyte-final.pt')
# Record the end time
end_time = datetime.datetime.now()
print(f"End time: {end_time.strftime('%H:%M:%S')}")
# Compute the duration and print it
duration = end_time - start_time
total_seconds = int(duration.total_seconds())
hours, remainder = divmod(total_seconds, 60*60)
minutes, seconds = divmod(remainder, 60)
print(f"Duration: {hours:02}:{minutes:02}:{seconds:02}")