Last active
February 20, 2022 12:40
-
-
Save yashbonde/027e1abb33c49d3db4b109d0f69919f5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. | |
# Licensed under the Apache License, Version 2.0 | |
# Modifications copyright Yash Bonde (C) 2021 Nimblebox.ai, Inc. | |
# This file is peak Google! <3 | |
# How far can you push Python before it's just too hard? | |
from typing import Any, Dict, Iterable, List, Tuple, Optional | |
import random | |
import itertools | |
import numpy as np | |
from time import time | |
from tqdm import tqdm | |
import jax | |
import haiku as hk | |
import jax.numpy as jnp | |
import logging | |
# Configure the root logger (getLogger() with no name returns it) to pass INFO
# and above, so the logging.info(...) progress messages used throughout this
# file are not filtered out by the default WARNING level.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
################################################################################ | |
# Transformer | |
# =========== | |
# Just a causal attention transformer model (GPT) | |
# Data set is a string sequence sampled with overlaps | |
################################################################################ | |
class CausalSelfAttention(hk.MultiHeadAttention):
    """Multi-head attention in which a position only sees itself and earlier ones."""

    def __call__(self, q, k=None, v=None, mask=None) -> jnp.ndarray:
        """Apply attention with a lower-triangular mask over the sequence.

        Args:
            q: queries; must be rank 3 (``[B, T, D]``).
            k: keys; ``q`` is used when omitted.
            v: values; ``q`` is used when omitted.
            mask: optional extra mask, multiplied elementwise with the
                lower-triangular ``[T, T]`` mask built here.

        Raises:
            ValueError: if ``q`` is not rank 3.
        """
        if q.ndim != 3:
            raise ValueError('Expect queries of shape [B, T, D].')

        # Self-attention is the default: missing keys/values fall back to q.
        keys = q if k is None else k
        values = q if v is None else v

        num_positions = q.shape[1]
        lower_triangular = np.tril(np.ones((num_positions, num_positions)))
        if mask is None:
            combined_mask = lower_triangular
        else:
            combined_mask = mask * lower_triangular

        return super().__call__(q, keys, values, combined_mask)
class DenseBlock(hk.Module):
    """Feed-forward block: widen with a linear layer, GELU, project back."""

    def __init__(self, init_scale, widening_factor=4, name=None):
        """Store the initializer scale and the hidden-layer widening factor."""
        super().__init__(name=name)
        self._widening_factor = widening_factor
        self._init_scale = init_scale

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Map ``x`` through the two linear layers, keeping its last dimension."""
        width = x.shape[-1]
        w_init = hk.initializers.VarianceScaling(self._init_scale)
        # Both layers share the same initializer; the hidden layer is
        # ``widening_factor`` times wider than the input.
        expand = hk.Linear(self._widening_factor * width, w_init=w_init)
        project = hk.Linear(width, w_init=w_init)
        return project(jax.nn.gelu(expand(x)))
class Transformer(hk.Module):
    """Stack of residual blocks: LayerNorm -> causal attention, LayerNorm -> MLP."""

    def __init__(self, num_heads: int, num_layers: int, dropout_rate: float, name=None):
        """Store the stack hyper-parameters."""
        super().__init__(name=name)
        self._dropout_rate = dropout_rate
        self._num_layers = num_layers
        self._num_heads = num_heads

    def __call__(
        self,
        h: jnp.ndarray,
        mask: Optional[jnp.ndarray] = None,
        is_training: bool = True
    ):
        """Run ``h`` through every layer and a final LayerNorm.

        Args:
            h: input embeddings; the last axis is the model width.
            mask: optional mask; it is indexed as ``mask[:, None, None, :]``
                before being handed to the attention layers, so it is expected
                to be 2-D.
            is_training: dropout is applied only when this is true.

        Returns:
            The transformed embeddings after the final LayerNorm.
        """

        def layer_norm(name):
            # Every LayerNorm in the stack uses the same configuration.
            return hk.LayerNorm(
                axis=-1, create_scale=True, create_offset=True, name=name
            )

        init_scale = 2. / self._num_layers
        if is_training:
            rate = self._dropout_rate
        else:
            rate = 0.0
        attn_mask = None if mask is None else mask[:, None, None, :]

        for layer in range(self._num_layers):
            # Attention sub-block (pre-norm, residual).
            normed = layer_norm(f"h{layer}_ln_1")(h)
            attention = CausalSelfAttention(
                num_heads=self._num_heads,
                key_size=32,
                w_init_scale=init_scale,
                value_size=None,
                model_size=h.shape[-1],
                name=f"h{layer}_attn"
            )
            attn_out = attention(normed, mask=attn_mask)
            # The RNG key is drawn only after the sub-layer has run, so the
            # order in which keys are consumed stays fixed.
            h = h + hk.dropout(hk.next_rng_key(), rate, attn_out)

            # Feed-forward sub-block (pre-norm, residual).
            normed = layer_norm(f"h{layer}_ln_2")(h)
            dense_out = DenseBlock(init_scale, name=f"h{layer}_dense")(normed)
            h = h + hk.dropout(hk.next_rng_key(), rate, dense_out)

        return layer_norm("ln_f")(h)
# data | |
def infinite_shuffle(iterable: Iterable, buffer_size: int):
    """Yield items of ``iterable`` forever, approximately shuffled.

    The iterable is cycled endlessly. A buffer of ``buffer_size`` items is
    filled and shuffled; afterwards each incoming item replaces a uniformly
    chosen buffer slot and the evicted item is yielded.

    Args:
        iterable: source of items; must contain at least one item.
        buffer_size: number of items held in the shuffle buffer; must be >= 1.

    Yields:
        Items drawn from ``iterable``, without end.

    Raises:
        ValueError: if ``buffer_size`` is smaller than 1 or ``iterable`` is
            empty. As this is a generator, the error surfaces on the first
            ``next()`` call rather than when the function is called.
    """
    if buffer_size < 1:
        raise ValueError(f'buffer_size must be >= 1, got {buffer_size}.')

    ds = itertools.cycle(iterable)
    buf = list(itertools.islice(ds, buffer_size))
    if not buf:
        # cycle() over an empty iterable yields nothing. Without this check a
        # bare next(ds) would raise StopIteration inside the generator, which
        # Python turns into an opaque RuntimeError.
        raise ValueError('iterable must contain at least one item.')
    random.shuffle(buf)

    # cycle() over a non-empty iterable never ends, so neither does this loop.
    for item in ds:
        idx = random.randint(0, buffer_size - 1)  # Inclusive
        result, buf[idx] = buf[idx], item
        yield result
class Dataset:
    """Endless iterator of language-modelling mini-batches from one ASCII file."""

    def __init__(self, path: str, batch_size: int, sequence_length: int):
        """Load a single-file ASCII dataset in memory.

        The text is tokenized one character per token (its ASCII codepoint),
        trimmed to a whole number of batches and cut into non-overlapping
        crops of ``sequence_length + 1`` tokens.

        Raises:
            ValueError: if the file is not ASCII, contains a null byte, or
                yields fewer than 10 full batches.
        """
        self.vocab_size = 128
        self._batch_size = batch_size

        with open(path, 'r') as handle:
            text = handle.read()

        if not text.isascii():
            raise ValueError('Loaded corpus is not ASCII.')
        # Reserve 0 codepoint for pad token.
        if '\0' in text:
            raise ValueError('Corpus must not contain null byte.')

        # Tokenize by taking ASCII codepoints.
        token_ids = np.array(list(map(ord, text))).astype(np.int32)
        assert np.min(token_ids) > 0
        assert np.max(token_ids) < self.vocab_size  # Double-checking ASCII codepoints.

        # One extra token per crop so inputs and targets can be offset by one.
        crop_len = sequence_length + 1
        num_batches, ragged = divmod(token_ids.size, batch_size * crop_len)
        if ragged:
            token_ids = token_ids[:-ragged]
        crops = token_ids.reshape([-1, crop_len])

        if num_batches < 10:
            raise ValueError(f'Only {num_batches} batches; consider a shorter sequence or a smaller batch.')

        self._ds = infinite_shuffle(crops, batch_size * 10)

    def __next__(self):
        """Yield next mini-batch as a dict with ``obs`` and ``target`` arrays."""
        rows = []
        for _ in range(self._batch_size):
            rows.append(next(self._ds))
        stacked = np.stack(rows)
        # Create the language modeling observation/target pairs: the target is
        # the observation shifted left by one token.
        return {'obs': stacked[:, :-1], 'target': stacked[:, 1:]}

    def __iter__(self):
        """Return self; the dataset is its own iterator."""
        return self

    @staticmethod
    def decode(tokens: List[int]):
        """Turn a sequence of codepoints back into a string."""
        return ''.join(map(chr, tokens))
def generate(forward_fn, config, num_steps: int, state: Dict[str, Any], text: str = 'This is a RSockClient.'):
    """Greedily extend ``text`` by ``num_steps`` characters.

    Args:
        forward_fn: callable invoked as
            ``forward_fn(params, rng, {"obs": tokens}, is_training=False)``
            that returns logits whose last axis indexes the vocabulary.
        config: object exposing ``m``, the maximum context length.
        num_steps: number of characters to append.
        state: dict holding at least ``"params"`` and ``"rng"``.
        text: prompt; each character is encoded as its codepoint.

    Returns:
        The prompt followed by the generated characters.

    Raises:
        ValueError: if ``text`` is empty while ``num_steps`` is positive,
            because there is no position to predict from.
    """
    if num_steps > 0 and not text:
        raise ValueError('text must contain at least one character to seed generation.')

    tokens = np.array([[ord(c) for c in text]])
    for _ in range(num_steps):
        # Feed only the most recent ``config.m`` tokens. The window must be
        # taken along the sequence axis (axis 1): slicing axis 0 would select
        # batch rows and leave the sequence unbounded.
        context = tokens[:, -config.m:]
        logits = forward_fn(state["params"], state["rng"], {"obs": context}, is_training=False)
        # Greedy decoding: most likely token at the last position, per row.
        next_token = np.asarray(logits[:, -1].argmax(axis=-1))
        tokens = np.concatenate([tokens, next_token[:, None]], axis=1)
    return Dataset.decode(tokens[0])
################################################################################ | |
# Training | |
# ======== | |
# Haiku is purely functional, so the forward operations must be written as
# plain functions. This is the structure of the code:
# forward_fn -> jnp.ndarray | |
# lm_loss_fn -> jnp.ndarray | |
# On data: | |
# Since haiku really really is functional, nothing can have a side effect.
# Thus any data object that has to be stored is written in OOP
# style.
# On code:
# There are two different inits, ``__init__`` and ``init``: the
# former is called when the updater is created and the latter is called
# just before the loop.
# | |
# Then there is the ``update`` method, which is called on every iteration. | |
# The simplicity of this approach! | |
################################################################################ | |
import functools | |
import optax | |
import os | |
import pickle | |
def build_forward_fn(config: Dict):
    """Return the model's forward function, closed over ``config``.

    ``config`` must expose ``vocab_size``, ``c`` (embedding width), ``m``
    (number of positional embeddings), ``num_heads``, ``num_layers`` and
    ``dropout_rate`` as attributes.
    """

    def _forward(data, is_training: bool = False) -> jnp.ndarray:
        """Forward pass: embed ``data['obs']``, run the transformer, emit logits."""
        token_ids = data['obs']
        num_positions = token_ids.shape[1]
        # Positions whose token id is not greater than 0 are masked out.
        padding_mask = jnp.greater(token_ids, 0)

        # Embed the input tokens and positions.
        init = hk.initializers.TruncatedNormal(stddev=0.02)
        token_embs = hk.Embed(config.vocab_size, config.c, w_init=init)(token_ids)
        position_table = hk.get_parameter(
            'pos_embs', [config.m, config.c], init=init
        )
        hidden = token_embs + position_table[:num_positions]

        # Run the transformer over the inputs.
        encoder = Transformer(
            num_heads=config.num_heads,
            num_layers=config.num_layers,
            dropout_rate=config.dropout_rate
        )
        hidden = encoder(hidden, padding_mask, is_training)

        # Reverse the embeddings (untied).
        return hk.Linear(config.vocab_size)(hidden)

    return _forward
def lm_loss_fn(forward_fn, vocab_size: int, params, rng, data: Dict[str, jnp.ndarray], is_training: bool = True) -> jnp.ndarray:
    """Compute the loss on data wrt params.

    The loss is the softmax cross-entropy between the logits returned by
    ``forward_fn(params, rng, data, is_training)`` and ``data['target']``,
    averaged over the positions where ``data['obs']`` is greater than 0.

    Args:
        forward_fn: callable returning logits whose last axis has
            ``vocab_size`` entries.
        vocab_size: number of classes used to one-hot encode the targets.
        params: model parameters, passed through to ``forward_fn``.
        rng: random key, passed through to ``forward_fn``.
        data: dict with integer arrays ``obs`` and ``target``.
        is_training: passed through to ``forward_fn``.

    Returns:
        Scalar mean loss; 0 when no position is unmasked.
    """
    logits = forward_fn(params, rng, data, is_training)
    targets = jax.nn.one_hot(data['target'], vocab_size)
    assert logits.shape == targets.shape

    mask = jnp.greater(data['obs'], 0)
    token_loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    # Guard the denominator: a batch with no unmasked position would otherwise
    # give 0 / 0 = NaN and poison the parameters through the update.
    num_tokens = jnp.maximum(jnp.sum(mask), 1)
    return jnp.sum(token_loss * mask) / num_tokens
class ParamUpdater:
    """A stateless abstraction around an init_fn/update_fn pair.

    This extracts some common boilerplate from the training loop. All
    training state is passed in and returned as a dict with the keys
    ``step``, ``rng``, ``opt_state`` and ``params``.
    """

    def __init__(self, net_init, loss_fn, optimizer: optax.GradientTransformation):
        """Keep the network initializer, the loss function and the optimizer."""
        self._opt = optimizer
        self._loss_fn = loss_fn
        self._net_init = net_init

    @functools.partial(jax.jit, static_argnums=0)
    def init(self, rng, data) -> Dict:
        """Initializes state of the updater.

        ``rng`` is split in two: one half initializes the network, the other
        is stored in the returned state for later updates.
        """
        carried_rng, init_rng = jax.random.split(rng)
        params = self._net_init(init_rng, data)
        return {
            'step': np.array(0),
            'rng': carried_rng,
            'opt_state': self._opt.init(params),
            'params': params,
        }

    @functools.partial(jax.jit, static_argnums=0)
    def update(self, state: Dict[str, Any], data: Dict[str, jnp.ndarray]) -> Tuple[Dict, Dict]:
        """Updates the state using some data and returns metrics.

        Returns:
            ``(new_state, metrics)`` where ``metrics`` holds the step number
            before the update and the loss on ``data``.
        """
        # One key is consumed by this step, the other is carried forward.
        step_rng, carried_rng = jax.random.split(state['rng'])
        loss, grads = jax.value_and_grad(self._loss_fn)(state['params'], step_rng, data)
        updates, new_opt_state = self._opt.update(grads, state['opt_state'])
        new_params = optax.apply_updates(state['params'], updates)

        new_state = dict(
            step=state['step'] + 1,
            rng=carried_rng,
            opt_state=new_opt_state,
            params=new_params,
        )
        metrics = dict(step=state['step'], loss=loss)
        return new_state, metrics
class CheckpointingUpdater:
    """A didactic checkpointing wrapper around an Updater.

    A more mature checkpointing implementation might:
    - Use np.savez() to store the core data instead of pickle.
    - Not block JAX async dispatch.
    - Automatically garbage collect old checkpoints.
    Again since haiku is functional anything that has to be stored is written
    in OOPs.

    Checkpoints are files named ``checkpoint_<step>.pkl`` (step zero-padded to
    at least 7 digits) inside ``checkpoint_dir``. Only files matching exactly
    that pattern are considered when resuming.
    """

    _PREFIX = 'checkpoint_'
    _SUFFIX = '.pkl'

    def __init__(self, inner: 'ParamUpdater', checkpoint_dir: str, checkpoint_every_n: int = 10000):
        """Wrap ``inner``, checkpointing every ``checkpoint_every_n`` steps."""
        self._inner = inner
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_every_n = checkpoint_every_n

    def _checkpoint_step(self, filename):
        """Return the step encoded in ``filename``, or None if it is not a checkpoint."""
        if not (filename.startswith(self._PREFIX) and filename.endswith(self._SUFFIX)):
            return None
        digits = filename[len(self._PREFIX):-len(self._SUFFIX)]
        if digits.isascii() and digits.isdigit():
            return int(digits)
        return None

    def _checkpoint_paths(self):
        """Return checkpoint file names in the directory, oldest step first."""
        names = [
            p for p in os.listdir(self._checkpoint_dir)
            if self._checkpoint_step(p) is not None
        ]
        # Sort by the numeric step: plain string order breaks once the step
        # no longer fits the zero-padded width.
        return sorted(names, key=self._checkpoint_step)

    def init(self, rng, data):
        """Initialize experiment state.

        Returns the state from the checkpoint with the highest step if one
        exists; otherwise creates the directory and returns a fresh state from
        the wrapped updater.
        """
        if not os.path.exists(self._checkpoint_dir) or not self._checkpoint_paths():
            os.makedirs(self._checkpoint_dir, exist_ok=True)
            return self._inner.init(rng, data)

        checkpoint = os.path.join(self._checkpoint_dir, self._checkpoint_paths()[-1])
        logging.info('Loading checkpoint from %s', checkpoint)
        # NOTE(security): pickle.load can execute arbitrary code. Only resume
        # from checkpoint directories whose contents you trust.
        with open(checkpoint, 'rb') as f:
            return pickle.load(f)

    def update(self, state, data):
        """Update experiment state, saving a checkpoint first when one is due."""
        # NOTE: This blocks until `state` is computed. If you want to use JAX async
        # dispatch, maintain state['step'] as a NumPy scalar instead of a JAX array.
        # Context: https://jax.readthedocs.io/en/latest/async_dispatch.html
        step = int(state['step'])
        if step % self._checkpoint_every_n == 0:
            path = os.path.join(self._checkpoint_dir,
                                'checkpoint_{:07d}.pkl'.format(step))
            checkpoint_state = jax.device_get(state)
            logging.info('Serializing experiment state to %s', path)
            # Write to a temporary name and rename afterwards, so an
            # interrupted write never leaves a truncated file under a valid
            # checkpoint name. The '.tmp' name is ignored when resuming.
            tmp_path = path + '.tmp'
            with open(tmp_path, 'wb') as f:
                pickle.dump(checkpoint_state, f)
            os.replace(tmp_path, path)
        return self._inner.update(state, data)
################################################################################ | |
# Main | |
# ==== | |
# Haiku is purely functional, so the forward operations must be written as
# plain functions. This is the structure of the code:
# forward_fn -> jnp.ndarray | |
# lm_loss_fn -> jnp.ndarray | |
################################################################################ | |
class Config(dict):
    """A dict whose items can also be read and written as attributes.

    ``cfg.name`` reads ``cfg['name']`` and ``cfg.name = value`` writes it.
    Reading a missing name raises ``AttributeError`` (not ``KeyError``), so
    ``hasattr``, ``getattr`` with a default and ``copy.deepcopy`` work.
    """

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so dict methods and
        # dunder attributes are unaffected.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}'
            ) from None

    __setattr__ = dict.__setitem__
def main(
    filepath: str,
    m: int = 64,
    c: int = 8,
    num_heads: int = 8,
    num_layers: int = 6,
    batch_size: int = 8,
    dropout_rate: float = 0.1,
    grad_clip_value: float = 1.0,
    learning_rate: float = 0.001,
    checkpoint_dir: str = './checkpoints',
    max_steps: int = 10000,
    log_every: int = 1000,
    seed: int = 428
):
    """Train an ASCII language model on filepath.

    Args:
        filepath: path of the ASCII text file to train on.
        m: sequence length; also the number of positional embeddings.
        c: embedding width.
        num_heads: attention heads per layer.
        num_layers: number of transformer layers.
        batch_size: sequences per mini-batch.
        dropout_rate: dropout probability used while training.
        grad_clip_value: global-norm gradient clipping threshold.
        learning_rate: Adam learning rate.
        checkpoint_dir: directory where checkpoints are written and resumed from.
        max_steps: number of update steps to run.
        log_every: log metrics and a generated sample every this many steps.
        seed: seed of the JAX PRNG key used to initialize the parameters.

    Raises:
        ValueError: if ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError(f'log_every must be >= 1, got {log_every}.')

    config = Config(
        vocab_size = 128, # fixed for ASCII
        filepath = filepath,
        m = m,
        c = c,
        num_heads = num_heads,
        num_layers = num_layers,
        batch_size = batch_size,
        dropout_rate = dropout_rate,
        grad_clip_value = grad_clip_value,
        learning_rate = learning_rate,
        checkpoint_dir = checkpoint_dir,
    )

    train_dataset = Dataset(
        path = filepath,
        batch_size = batch_size,
        sequence_length = m
    )

    # Set up the model, loss, and updater.
    forward_fn = hk.transform(build_forward_fn(config))
    generate_fn = functools.partial(generate, forward_fn.apply, config)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, config.vocab_size)

    optimizer = optax.chain(
        optax.clip_by_global_norm(config.grad_clip_value),
        optax.adam(config.learning_rate, b1=0.9, b2=0.99)
    )

    updater = ParamUpdater(forward_fn.init, loss_fn, optimizer)
    updater = CheckpointingUpdater(updater, config.checkpoint_dir)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(seed)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time()
    last_logged_step = -1
    pbar = tqdm(range(max_steps))
    for step in pbar:
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
        # We use JAX runahead to mask data preprocessing and JAX dispatch overheads.
        # Using values from state/metrics too often will block the runahead and can
        # cause these overheads to become more prominent.
        if step % log_every == 0:
            now = time()
            # Count the steps actually run since the last report. The first
            # report covers a single step, not ``log_every`` of them.
            steps_done = step - last_logged_step
            steps_per_sec = steps_done / max(now - prev_time, 1e-9)
            prev_time, last_logged_step = now, step
            metrics.update({'steps_per_sec': steps_per_sec})

            # generate a sample
            sample = generate_fn(32, state)
            logging.info({k: float(v) for k, v in metrics.items()})
            logging.info('Generated sample: %s', sample)
if __name__ == '__main__':
    # Expose main() as the command-line entry point through python-fire.
    from fire import Fire
    Fire(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment