How to prepare and evaluate on hellaswag in JAX
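Each HellaSwag example is a context ("ctx") with four candidate endings and an integer "label" marking the correct one. The script below tokenizes context + ending for all four candidates, runs the model over each, averages the cross-entropy over the ending tokens only, and picks the ending with the lowest loss.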
import json
from tqdm import tqdm
import numpy as np
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
import optax
import optax.tree_utils as otu
import tensorflow as tf
from transformers import AutoTokenizer
hellaswag_val_path = "data/hellaswag_val.jsonl"
batch_size = 32
block_size = 1024
model_id = "mistralai/Mistral-7B-v0.3"


def prepare_hellaswag(
    data_path: str, batch_size: int, block_size: int, tf_prefetch: int = 2
):
    """Read file and tokenize the hellaswag dataset."""
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, trust_remote_code=True, use_fast=True
    )

    all_data = []
    all_beginning_lengths = []
    all_seq_lengths = []
    all_labels = []
    with open(data_path, "r") as f:
        # iterate over lines and tokenize each context/ending pair
        for line in tqdm(f, total=10042):
            item = json.loads(line)
            context = item["ctx"]
            endings = item["endings"]
            correct_end = item["label"]
            beginning_length = len(tokenizer(context)["input_ids"])

            data_to_concat = []
            beginning_lengths_to_concat = []
            seq_lengths_to_concat = []
            for ending in endings:
                output = tokenizer(context + " " + ending)["input_ids"]
                output_len = len(output)
                # pad to block_size with eos tokens
                if output_len < block_size:
                    output = output + [tokenizer.eos_token_id] * (
                        block_size - output_len
                    )
                # truncate so max length is block_size
                output = output[:block_size]
                data_to_concat.append(output)
                beginning_lengths_to_concat.append(beginning_length)
                # clamp the stored length so it never exceeds the padded block
                seq_lengths_to_concat.append(min(output_len, block_size))

            # uint16 keeps the arrays small; assumes the vocab fits in 16 bits
            all_data.append(np.array(data_to_concat, dtype=np.uint16))
            all_beginning_lengths.append(
                np.array(beginning_lengths_to_concat, dtype=np.int32)
            )
            all_seq_lengths.append(
                np.array(seq_lengths_to_concat, dtype=np.int32)
            )
            all_labels.append(int(correct_end))

    all_data = np.array(all_data, dtype=np.uint16)  # (10042, 4, block_size)
    all_beginning_lengths = np.array(
        all_beginning_lengths, dtype=np.int32
    )  # (10042, 4)
    all_seq_lengths = np.array(all_seq_lengths, dtype=np.int32)  # (10042, 4)
    all_labels = np.array(all_labels, dtype=np.int32)  # (10042,)
    ds = tf.data.Dataset.from_tensor_slices(
        (all_data, all_beginning_lengths, all_seq_lengths, all_labels)
    )
    # ds = ds.repeat()
    ds = ds.batch(
        batch_size // jax.process_count(),
        drop_remainder=True,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.prefetch(tf_prefetch)
    ds = ds.as_numpy_iterator()
    return ds
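
# Quick shape check (a sketch; assumes a single host, so
# jax.process_count() == 1 and the per-host batch equals batch_size):
#
#   ds = prepare_hellaswag(hellaswag_val_path, batch_size, block_size)
#   tokens, begin_lens, seq_lens, labels = next(ds)
#   tokens.shape      -> (batch_size, 4, block_size)
#   begin_lens.shape  -> (batch_size, 4)
#   seq_lens.shape    -> (batch_size, 4)
#   labels.shape      -> (batch_size,)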


def hs_eval_step_unreduced(
    state: TrainState,
    tokens: jnp.ndarray,
    begin_lens: jnp.ndarray,
    seq_lens: jnp.ndarray,
) -> jnp.ndarray:
    """Mean cross-entropy per completion, over the ending tokens only."""
    compute_dtype = jnp.float32
    logits = state.apply_fn(
        otu.tree_cast(state.params, compute_dtype), tokens[:, :-1]
    )
    assert logits.dtype == compute_dtype
    logits = logits.astype(jnp.float32)
    targets = tokens[:, 1:]
    losses = optax.softmax_cross_entropy_with_integer_labels(logits, targets)

    @jax.vmap
    def unreduced_losses(loss, begin_len, seq_len):
        # mask out the context and padding, keeping only the ending tokens
        seq_range = jnp.arange(len(loss))
        seq_mask = jnp.logical_and(
            seq_range < seq_len - 1, seq_range >= begin_len - 1
        ).astype(jnp.bool_)
        loss = loss * seq_mask
        return jnp.sum(loss) / jnp.sum(seq_mask)

    losses = unreduced_losses(losses, begin_lens, seq_lens)
    return losses
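
# Worked example of the masking above: with begin_len = 3 and seq_len = 6,
# losses[i] scores the prediction of tokens[i + 1], so the ending tokens
# tokens[3:6] are scored by losses[2:5]. The mask
# (seq_range >= begin_len - 1) & (seq_range < seq_len - 1) keeps exactly
# indices 2, 3, 4, so neither the context nor the eos padding contributes
# to the per-completion mean.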


def eval_hellaswag(state: TrainState, data, begin_lens, seq_lens, labels):
    """Evaluate the hellaswag dataset."""
    # data comes in shape (b, 4, block_size)
    # begin_lens and seq_lens come in shape (b, 4)
    # labels come in shape (b,)
    bs_in = data.shape[0]
    # flatten the four endings into the batch dim for a single forward pass
    data = jnp.reshape(data, (-1, data.shape[-1]))
    begin_lens = jnp.reshape(begin_lens, (-1,))
    seq_lens = jnp.reshape(seq_lens, (-1,))
    losses = hs_eval_step_unreduced(state, data, begin_lens, seq_lens)
    # the predicted ending is the one with the lowest mean loss
    choices = jnp.argmin(jnp.reshape(losses, (bs_in, 4)), axis=-1)
    correct = jnp.sum(choices == labels)
    accuracy = correct / bs_in
    return accuracy
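
# For speed you would typically jit the eval, e.g.
# eval_hellaswag = jax.jit(eval_hellaswag) (a sketch; assumes state.apply_fn
# traces cleanly under jit). drop_remainder=True in the batching above keeps
# batch shapes static, so it compiles once.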
""" | |
Example usage: | |
ds = prepare_hellaswag( | |
data_path=hellaswag_val_path, | |
batch_size=batch_size, | |
block_size=block_size, | |
tf_prefetch=2, | |
) | |
for hellaswag_batch in ds: | |
accuracy = eval_hellaswag(state, *hellaswag_batch) | |
print(f"Accuracy: {accuracy}") | |
""" |