#!/usr/bin/env python
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Licensed under the Apache License, Version 2.0
# Modifications copyright Yash Bonde (C) 2021 Nimblebox.ai, Inc.
# This file is peak Google! <3
# How far can you push Python before it's just too hard?
from typing import Any, Dict, Iterable, List, Tuple, Optional
import random
import itertools
import numpy as np
from time import time
from tqdm import tqdm
import jax
import haiku as hk
import jax.numpy as jnp
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)
################################################################################
# Transformer
# ===========
# Just a causal-attention transformer model (GPT).
# The dataset is a string sequence split into fixed-length windows; obs/target
# pairs overlap by one token (next-character prediction).
################################################################################
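# For example (a sketch, not part of the pipeline): with sequence_length = 4 the
# window "hello" becomes
#   obs    = "hell" -> [104, 101, 108, 108]
#   target = "ello" -> [101, 108, 108, 111]
# i.e. every position is trained to predict the next ASCII codepoint.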
class CausalSelfAttention(hk.MultiHeadAttention):
  def __call__(self, q, k = None, v = None, mask = None) -> jnp.ndarray:
    if q.ndim != 3:
      raise ValueError('Expect queries of shape [B, T, D].')
    seq_len = q.shape[1]
    causal_mask = np.tril(np.ones((seq_len, seq_len)))
    mask = mask * causal_mask if mask is not None else causal_mask
    return super().__call__(
      q,
      k if k is not None else q,
      v if v is not None else q,
      mask
    )
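# For illustration (a sketch, not used by the code): with seq_len = 4 the causal
# mask built above is the lower-triangular matrix
#   [[1, 0, 0, 0],
#    [1, 1, 0, 0],
#    [1, 1, 1, 0],
#    [1, 1, 1, 1]]
# so position t may only attend to positions <= t.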
class DenseBlock(hk.Module):
  def __init__(self, init_scale, widening_factor = 4, name = None):
    super().__init__(name = name)
    self._init_scale = init_scale
    self._widening_factor = widening_factor

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    hiddens = x.shape[-1]
    initializer = hk.initializers.VarianceScaling(self._init_scale)
    x = hk.Linear(
      self._widening_factor * hiddens, w_init=initializer
    )(x)
    x = jax.nn.gelu(x)
    return hk.Linear(
      hiddens, w_init=initializer
    )(x)
class Transformer(hk.Module):
  def __init__(self, num_heads: int, num_layers: int, dropout_rate: float, name = None):
    super().__init__(name = name)
    self._num_heads = num_heads
    self._num_layers = num_layers
    self._dropout_rate = dropout_rate

  def __call__(
    self,
    h: jnp.ndarray,
    mask: Optional[jnp.ndarray] = None,
    is_training: bool = True
  ):
    init_scale = 2. / self._num_layers
    dropout_rate = self._dropout_rate if is_training else 0.0
    if mask is not None:
      mask = mask[:, None, None, :]
    for i in range(self._num_layers):
      # attention block
      h_norm = hk.LayerNorm(
        axis=-1,
        create_scale=True,
        create_offset=True,
        name=f"h{i}_ln_1"
      )(h)
      h_attn = CausalSelfAttention(
        num_heads=self._num_heads,
        key_size = 32,
        w_init_scale = init_scale,
        value_size = None,
        model_size = h.shape[-1],
        name=f"h{i}_attn"
      )(h_norm, mask=mask)
      h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
      h = h + h_attn
      # dense block
      h_norm = hk.LayerNorm(
        axis=-1,
        create_scale=True,
        create_offset=True,
        name=f"h{i}_ln_2"
      )(h)
      h_dense = DenseBlock(init_scale, name=f"h{i}_dense")(h_norm)
      h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
      h = h + h_dense
    h = hk.LayerNorm(
      axis=-1,
      create_scale=True,
      create_offset=True,
      name="ln_f"
    )(h)
    return h
# data
def infinite_shuffle(iterable: Iterable, buffer_size: int):
  ds = itertools.cycle(iterable)
  buf = [next(ds) for _ in range(buffer_size)]
  random.shuffle(buf)
  while 1:
    item = next(ds)
    idx = random.randint(0, buffer_size - 1)  # Inclusive
    result, buf[idx] = buf[idx], item
    yield result
class Dataset:
  def __init__(self, path: str, batch_size: int, sequence_length: int):
    """Load a single-file ASCII dataset in memory."""
    self.vocab_size = 128
    self._batch_size = batch_size
    with open(path, 'r') as f:
      corpus = f.read()
    if not corpus.isascii():
      raise ValueError('Loaded corpus is not ASCII.')
    if '\0' in corpus:
      # Reserve 0 codepoint for pad token.
      raise ValueError('Corpus must not contain null byte.')
    # Tokenize by taking ASCII codepoints.
    corpus = np.array([ord(c) for c in corpus]).astype(np.int32)
    assert np.min(corpus) > 0
    assert np.max(corpus) < self.vocab_size  # Double-checking ASCII codepoints.
    crop_len = sequence_length + 1
    num_batches, ragged = divmod(corpus.size, batch_size * crop_len)
    if ragged:
      corpus = corpus[:-ragged]
    corpus = corpus.reshape([-1, crop_len])
    if num_batches < 10:
      raise ValueError(f'Only {num_batches} batches; consider a shorter sequence or a smaller batch.')
    self._ds = infinite_shuffle(corpus, batch_size * 10)

  def __next__(self):
    """Yield next mini-batch."""
    batch = [next(self._ds) for _ in range(self._batch_size)]
    batch = np.stack(batch)
    # Create the language modeling observation/target pairs.
    return dict(obs=batch[:, :-1], target=batch[:, 1:])

  def __iter__(self):
    return self

  @staticmethod
  def decode(tokens: List[int]):
    return ''.join(chr(t) for t in tokens)
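# Minimal usage sketch (assumes a plain-ASCII text file at the hypothetical path
# "corpus.txt"):
#   ds = Dataset("corpus.txt", batch_size=8, sequence_length=64)
#   batch = next(ds)                 # {"obs": (8, 64) int32, "target": (8, 64) int32}
#   Dataset.decode(batch["obs"][0])  # back to a readable string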
def generate(forward_fn, config, num_steps: int, state: Dict[str, Any], text: str = 'This is a RSockClient.'):
  tokens = np.array([[ord(c) for c in text]])
  for _ in range(num_steps):
    # Crop to the last `config.m` tokens so the input never exceeds the
    # learned positional embeddings.
    output = forward_fn(state["params"], state["rng"], {"obs": tokens[:, -config.m:]}, is_training = False)
    out_tokens = output[:, -1].argmax(axis=-1)
    tokens = np.concatenate([tokens, [out_tokens]], axis=1)
  return Dataset.decode(tokens[0])
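# Greedy-decoding sketch (forward_fn/config/state as produced by main() below):
#   sample = generate(forward_fn.apply, config, num_steps=32, state=state)
#   # returns the prompt plus 32 argmax-decoded ASCII characters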
################################################################################
# Training
# ========
# haiku is purely functional, so the forward operations must be written down
# as functions. This is the structure of the code:
#   forward_fn -> jnp.ndarray
#   lm_loss_fn -> jnp.ndarray
# On data:
#   Since haiku really, really is functional, nothing can have a side effect.
#   Thus any object that has to hold state (like the dataset) is written in
#   plain OOP style.
# On code:
#   There are two different inits, ``__init__`` and ``init``: the former is
#   called when the updater is created and the latter is called just before
#   the training loop.
#
#   Then there is the ``update`` method, which is called on every iteration.
#   The simplicity of this approach!
################################################################################
import functools
import optax
import os
import pickle
def build_forward_fn(config: Dict):
  def _forward(data, is_training: bool = False) -> jnp.ndarray:
    """Forward pass."""
    tokens = data['obs']
    input_mask = jnp.greater(tokens, 0)
    seq_length = tokens.shape[1]
    # Embed the input tokens and positions.
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(config.vocab_size, config.c, w_init=embed_init)
    token_embs = token_embedding_map(tokens)
    positional_embeddings = hk.get_parameter(
      'pos_embs',
      [config.m, config.c],
      init = embed_init
    )
    input_embeddings = token_embs + positional_embeddings[:seq_length]
    # Run the transformer over the inputs.
    transformer = Transformer(
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      dropout_rate=config.dropout_rate
    )
    output_embeddings = transformer(
      input_embeddings,
      input_mask,
      is_training
    )
    # Reverse the embeddings (untied).
    return hk.Linear(config.vocab_size)(output_embeddings)
  return _forward
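# How this fits together (a sketch; `config` and `data` are built in main()):
#   forward_fn = hk.transform(build_forward_fn(config))
#   params = forward_fn.init(jax.random.PRNGKey(0), data)           # pytree of weights
#   logits = forward_fn.apply(params, jax.random.PRNGKey(1), data)  # [B, T, vocab_size]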
def lm_loss_fn(forward_fn, vocab_size: int, params, rng, data: Dict[str, jnp.ndarray], is_training: bool = True) -> jnp.ndarray:
  """Compute the loss on data wrt params."""
  logits = forward_fn(params, rng, data, is_training)
  targets = jax.nn.one_hot(data['target'], vocab_size)
  assert logits.shape == targets.shape
  mask = jnp.greater(data['obs'], 0)
  loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
  loss = jnp.sum(loss * mask) / jnp.sum(mask)
  return loss
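# In equation form, the masked loss above is
#   L = sum_i m_i * (-log p_theta(target_i | obs_{<=i})) / sum_i m_i
# where m_i = 1 for non-pad tokens (codepoint > 0) and m_i = 0 otherwise.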
class ParamUpdater:
  """A stateless abstraction around an init_fn/update_fn pair.

  This extracts some common boilerplate from the training loop.
  """
  def __init__(self, net_init, loss_fn, optimizer: optax.GradientTransformation):
    self._net_init = net_init
    self._loss_fn = loss_fn
    self._opt = optimizer

  @functools.partial(jax.jit, static_argnums=0)
  def init(self, rng, data) -> Dict:
    """Initializes state of the updater."""
    out_rng, init_rng = jax.random.split(rng)
    params = self._net_init(init_rng, data)
    opt_state = self._opt.init(params)
    out = dict(
      step=np.array(0),
      rng=out_rng,
      opt_state=opt_state,
      params=params,
    )
    return out

  @functools.partial(jax.jit, static_argnums=0)
  def update(self, state: Dict[str, Any], data: Dict[str, jnp.ndarray]) -> Tuple[Dict, Dict]:
    """Updates the state using some data and returns metrics."""
    rng, new_rng = jax.random.split(state['rng'])
    params = state['params']
    loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)
    updates, opt_state = self._opt.update(g, state['opt_state'])
    params = optax.apply_updates(params, updates)
    new_state = {
      'step': state['step'] + 1,
      'rng': new_rng,
      'opt_state': opt_state,
      'params': params,
    }
    metrics = {
      'step': state['step'],
      'loss': loss,
    }
    return new_state, metrics
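# Usage sketch (mirrors what main() does further down):
#   updater = ParamUpdater(forward_fn.init, loss_fn, optimizer)
#   state = updater.init(jax.random.PRNGKey(428), next(train_dataset))
#   state, metrics = updater.update(state, next(train_dataset))
#   metrics['loss']  # scalar cross-entropy for this batch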
class CheckpointingUpdater:
  """A didactic checkpointing wrapper around an Updater.

  A more mature checkpointing implementation might:
    - Use np.savez() to store the core data instead of pickle.
    - Not block JAX async dispatch.
    - Automatically garbage collect old checkpoints.

  Again, since haiku is functional, anything that has to be stored is kept in
  plain OOP style.
  """
  def __init__(self, inner: ParamUpdater, checkpoint_dir: str, checkpoint_every_n: int = 10000):
    self._inner = inner
    self._checkpoint_dir = checkpoint_dir
    self._checkpoint_every_n = checkpoint_every_n
  def _checkpoint_paths(self):
    return [p for p in os.listdir(self._checkpoint_dir) if 'checkpoint_' in p]

  def init(self, rng, data):
    """Initialize experiment state."""
    if not os.path.exists(self._checkpoint_dir) or not self._checkpoint_paths():
      os.makedirs(self._checkpoint_dir, exist_ok=True)
      return self._inner.init(rng, data)
    else:
      checkpoint = os.path.join(self._checkpoint_dir, max(self._checkpoint_paths()))
      logging.info('Loading checkpoint from %s', checkpoint)
      with open(checkpoint, 'rb') as f:
        state = pickle.load(f)
      return state

  def update(self, state, data):
    """Update experiment state."""
    # NOTE: This blocks until `state` is computed. If you want to use JAX async
    # dispatch, maintain state['step'] as a NumPy scalar instead of a JAX array.
    # Context: https://jax.readthedocs.io/en/latest/async_dispatch.html
    step = np.array(state['step'])
    if step % self._checkpoint_every_n == 0:
      path = os.path.join(self._checkpoint_dir,
                          'checkpoint_{:07d}.pkl'.format(step))
      checkpoint_state = jax.device_get(state)
      logging.info('Serializing experiment state to %s', path)
      with open(path, 'wb') as f:
        pickle.dump(checkpoint_state, f)
    state, out = self._inner.update(state, data)
    return state, out
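# Each checkpoint written above is just a pickled state dict; a restore sketch
# (the path is illustrative):
#   with open('./checkpoints/checkpoint_0000000.pkl', 'rb') as f:
#     state = pickle.load(f)
#   # state keys: 'step', 'rng', 'opt_state', 'params'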
################################################################################
# Main
# ====
# haiku is purely functional, so the forward operations must be written down
# as functions. This is the structure of the code:
#   forward_fn -> jnp.ndarray
#   lm_loss_fn -> jnp.ndarray
################################################################################
class Config(dict):
  __getattr__ = dict.__getitem__
  __setattr__ = dict.__setitem__
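# Config is just a dict with attribute access, e.g.:
#   cfg = Config(m=64, c=8)
#   cfg.m == cfg['m']  # True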
def main(
  filepath: str,
  m: int = 64,
  c: int = 8,
  num_heads: int = 8,
  num_layers: int = 6,
  batch_size: int = 8,
  dropout_rate: float = 0.1,
  grad_clip_value: float = 1.0,
  learning_rate: float = 0.001,
  checkpoint_dir: str = './checkpoints',
  max_steps: int = 10000,
  log_every: int = 1000
):
  """Train an ASCII language model on the text file at `filepath`."""
  config = Config(
    vocab_size = 128,  # fixed for ASCII
    filepath = filepath,
    m = m,
    c = c,
    num_heads = num_heads,
    num_layers = num_layers,
    batch_size = batch_size,
    dropout_rate = dropout_rate,
    grad_clip_value = grad_clip_value,
    learning_rate = learning_rate,
    checkpoint_dir = checkpoint_dir,
  )
  train_dataset = Dataset(
    path = filepath,
    batch_size = batch_size,
    sequence_length = m
  )
  # Set up the model, loss, and updater.
  forward_fn = hk.transform(build_forward_fn(config))
  generate_fn = functools.partial(generate, forward_fn.apply, config)
  loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, config.vocab_size)
  optimizer = optax.chain(
    optax.clip_by_global_norm(config.grad_clip_value),
    optax.adam(config.learning_rate, b1=0.9, b2=0.99)
  )
  updater = ParamUpdater(forward_fn.init, loss_fn, optimizer)
  updater = CheckpointingUpdater(updater, config.checkpoint_dir)
  # Initialize parameters.
  logging.info('Initializing parameters...')
  rng = jax.random.PRNGKey(428)
  data = next(train_dataset)
  state = updater.init(rng, data)
  logging.info('Starting train loop...')
  prev_time = time()
  pbar = tqdm(range(max_steps))
  for step in pbar:
    data = next(train_dataset)
    # print({k: v.shape for k, v in data.items()})
    state, metrics = updater.update(state, data)
    # We use JAX runahead to mask data preprocessing and JAX dispatch overheads.
    # Using values from state/metrics too often will block the runahead and can
    # cause these overheads to become more prominent.
    if step % log_every == 0:
      steps_per_sec = log_every / (time() - prev_time)
      prev_time = time()
      metrics.update({'steps_per_sec': steps_per_sec})
      # generate a sample
      sample = generate_fn(32, state)
      logging.info({k: float(v) for k, v in metrics.items()})
      logging.info('Generated sample: %s', sample)
if __name__ == '__main__':
  import fire
  fire.Fire(main)
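# CLI sketch via python-fire (the script name and corpus path are illustrative):
#   python train_lm.py --filepath=corpus.txt --batch_size=8 --max_steps=10000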