Skip to content

Instantly share code, notes, and snippets.

@OhadRubin
Last active January 27, 2024 11:50
Show Gist options
  • Save OhadRubin/52c6f912eb91b43893f9e866e011c6e2 to your computer and use it in GitHub Desktop.
Save OhadRubin/52c6f912eb91b43893f9e866e011c6e2 to your computer and use it in GitHub Desktop.
sliding_window_eval.py
from more_itertools import windowed, repeat_last
import numpy as np
from more_itertools import grouper
# Module-level sample values: a run of consecutive integer ids plus a
# window width and stride (names mirror sliding_window's parameters).
# input_ids = np.arange(4097)
width, stride = 16, 4
input_ids = np.arange(0, 65)
# class DummyTokenizer:
# def __init__(self):
def sliding_window(tok_input_ids, width, stride, padding_value = -100, bos_token_id = -2):
    """Yield BOS-prefixed, overlapping next-token windows over a token sequence.

    The sequence is cut into windows of ``width`` tokens with
    ``more_itertools.windowed``; a short final window is filled with
    ``padding_value``.  Each window is prefixed with ``bos_token_id`` and
    split into inputs (everything but the last position) and targets
    (everything but the BOS position).

    Args:
        tok_input_ids: Iterable of integer token ids.
        width: Number of tokens per window.
        stride: If greater than zero, the number of tokens the window
            advances per step; every window after the first then scores
            only its last ``stride`` targets.  Otherwise the window advances
            by ``width - 1`` and every window after the first skips only its
            first target.
        padding_value: Fill value for the final window.  Any token equal to
            it is treated as padding.
        bos_token_id: Token id placed in front of every window.

    Yields:
        dict with four arrays of length ``width``:
            ``input_tokens``: ``bos_token_id`` followed by the window's first
                ``width - 1`` tokens.
            ``targets``: the window's tokens.
            ``attention_mask``: bool, aligned with ``input_tokens``; False
                where the input token is padding.
            ``loss_mask``: bool, aligned with ``targets``; True where the
                target should be scored.  A fresh array is produced for each
                window, so it is safe to collect the yielded dicts.
    """
    if stride > 0:
        step = stride
        later_mask = np.r_[np.zeros(width - stride, dtype=bool), np.ones(stride, dtype=bool)]
    else:
        # Non-positive stride: consecutive windows overlap by exactly one
        # token, and that single overlapping target is not scored again.
        step = width - 1
        later_mask = np.r_[np.zeros(1, dtype=bool), np.ones(width - 1, dtype=bool)]
    # The first window scores every target; all later windows use later_mask.
    first_mask = np.ones(width, dtype=bool)

    windows = windowed(tok_input_ids, n=width, step=step, fillvalue=padding_value)
    for window, base_mask in zip(windows, repeat_last((first_mask, later_mask))):
        tokens = np.array(window, dtype=np.int32)
        attention_mask = tokens != padding_value
        # repeat_last hands back the *same* later_mask object for every
        # window after the first.  Combine into a new array rather than
        # editing it in place, otherwise clearing the padded positions of the
        # final window would also alter masks that were already yielded.
        loss_mask = base_mask & attention_mask
        tokens = np.r_[bos_token_id, tokens]
        yield dict(
            input_tokens=tokens[:-1],
            targets=tokens[1:],
            # Shifted right by one so it lines up with input_tokens; the BOS
            # position is always attended to.
            attention_mask=np.r_[True, attention_mask[:-1]],
            loss_mask=loss_mask,
        )
# Demo with an explicit stride: 19 tokens, windows of 8, advancing 4 per step.
# For each window, print the targets selected by the loss mask, then the
# whole window dict.
for element in sliding_window(np.arange(19), width=8, stride=4):
    input_tokens, targets = element["input_tokens"], element["targets"]
    attention_mask, loss_mask = element["attention_mask"], element["loss_mask"]
    scored_targets = targets[loss_mask]
    print(scored_targets)
    print(element)
# Demo with stride=0 (the function's non-positive-stride branch): 19 tokens,
# windows of 8.  Prints the loss-masked targets and then the full dict for
# every window.
for element in sliding_window(np.arange(19), width=8, stride=0):
    input_tokens, targets, attention_mask, loss_mask = (
        element["input_tokens"],
        element["targets"],
        element["attention_mask"],
        element["loss_mask"],
    )
    print(targets[loss_mask])
    print(element)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment