#!/usr/bin/env python3
"""
Example of dynamic/adaptive batching.

Author: Stephen Roller (twitter/github @stephenroller)
Public domain licensed. Do whatever you want with this.

Example usage:

$ wget -O bible.txt http://www.gutenberg.org/cache/epub/10/pg10.txt
$ python adaptive.py --nonadaptive bible.txt
Loading data
Number of docs = 24669
Arbitrary line of file: 2:6 But there went up a mist from the earth, and watered the whole face of the ground.
100%|█████████████████████████████| 24669/24669 [00:10<00:00, 2351.76it/s]
WARNING: Errant batch size: torch.Size([29, 216])
Total number of batches = 771
Average non-padding tokens / batch = 1412.0 (full = 16384)
Largest batchsize = torch.Size([32, 200])
Widest batch = torch.Size([32, 512])
Thinnest batch = torch.Size([32, 32])

$ python adaptive.py bible.txt
Loading data
Number of docs = 24669
Arbitrary line of file: 2:6 But there went up a mist from the earth, and watered the whole face of the ground.
100%|█████████████████████████████| 24669/24669 [00:10<00:00, 2394.39it/s]
WARNING: Errant batch size: torch.Size([5, 376])
Total number of batches = 83
Average non-padding tokens / batch = 13119.0 (full = 16384)
Largest batchsize = torch.Size([680, 24])
Widest batch = torch.Size([8, 512])
Thinnest batch = torch.Size([680, 24])
"""
import logging
import argparse
import random

import torch
import tqdm
from transformers import GPT2TokenizerFast

MAX_LENGTH = 512  # max document length
BATCH_SIZE = 32  # worst case (minimum) batch size
TOKENS_PER_BATCH = MAX_LENGTH * BATCH_SIZE
BUFFER_SIZE = 4096  # trade off efficiency for memory
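# With these defaults, TOKENS_PER_BATCH = 512 * 32 = 16384 tokens, which is
# the "full = 16384" figure reported in the example output above.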


def make_documents(filebuf):
    # Gutenberg text is hard-wrapped at 80 characters, so this is roughly the
    # equivalent of splitting up the document by \n\n.
    all_documents = []
    doc = []
    for line in filebuf:
        line = line.rstrip()
        if line:
            doc.append(line)
        elif doc:
            doctxt = " ".join(doc).replace("  ", " ")
            all_documents.append(doctxt)
            doc = []
    if doc:
        doctxt = " ".join(doc).replace("  ", " ")
        all_documents.append(doctxt)
    return all_documents


def adaptive_extract_batch(buf):
    """
    Pull out a batch from the buffer. Modifies the buffer in place.
    """
    # sort to put like examples together
    buf.sort()
    batch = []
    longest = 0
    start = random.randint(0, len(buf) - 1)
    while buf and len(batch) * longest < TOKENS_PER_BATCH:
        cost, r, tokens = buf.pop(start)
        longest = max(len(tokens), longest)
        if longest % 8 != 0:
            # round the batch width up to a multiple of 8 for fp16
            longest = longest + 8 - longest % 8
        batch.append((cost, r, tokens))
        start = max(start - 1, 0)
    # we need to force the batch size to be a multiple of 8 for fp16, so put
    # things back into the buffer
    while len(batch) > 8 and len(batch) % 8 != 0:
        buf.append(batch.pop(0))
    batch_tensor = torch.zeros((len(batch), longest), dtype=torch.int64)
    for i, (_, _, tokens) in enumerate(batch):
        batch_tensor[i, : len(tokens)] = tokens
    return batch_tensor
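

# For intuition: the buffer is sorted by bucketed length and popped backwards
# from a random starting point, so the first document popped bounds the batch
# width. If that width rounds up to 24 tokens, roughly TOKENS_PER_BATCH // 24
# (about 682) documents fit under the token budget, and trimming the row count
# to a multiple of 8 gives the 680 x 24 batch reported in the example output.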


def naive_extract_batch(buf):
    """
    Pull out a batch naively (in order).
    """
    batch = []
    longest = 0
    for i in range(BATCH_SIZE):
        if not buf:
            break
        cost, r, tokens = buf.pop(0)
        batch.append(tokens)
        longest = max(longest, len(tokens))
        if longest % 8 != 0:
            longest = longest + 8 - longest % 8
    batch_tensor = torch.zeros((len(batch), longest), dtype=torch.int64)
    for i, tokens in enumerate(batch):
        batch_tensor[i, : len(tokens)] = tokens
    return batch_tensor


def batcher(stream, tokenizer, extract_fn):
    buf = []
    for item in stream:
        tokenized = tokenizer(
            item, max_length=MAX_LENGTH, return_tensors='pt', truncation=True,
        )['input_ids'][0]
        # cost buckets documents by length (in steps of 8 tokens), so sorting
        # the buffer groups similarly sized documents together; the random
        # value breaks ties so tuples never fall through to comparing tensors
        cost = len(tokenized) // 8
        buf.append((cost, random.random(), tokenized))
        if len(buf) >= BUFFER_SIZE:
            yield extract_fn(buf)
    # ate all the data. handle whatever is leftover in the buffer
    while buf:
        yield extract_fn(buf)
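

# Illustrative sketch only (not part of the original gist): one way `batcher`
# could feed a training loop. `model` and `optimizer` are hypothetical
# stand-ins, e.g. a Hugging Face GPT2LMHeadModel and a torch optimizer; a real
# loop would also mask padding positions out of the loss (labels of -100).
def example_training_loop(docs, tokenizer, model, optimizer):
    for batch in batcher(docs, tokenizer, adaptive_extract_batch):
        output = model(input_ids=batch, labels=batch)
        output.loss.backward()
        optimizer.step()
        optimizer.zero_grad()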


def main(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=argparse.FileType('r'))
    parser.add_argument('--nonadaptive', action='store_true')
    opts = parser.parse_args(args)

    # load up data
    print("Loading data")
    docs = make_documents(opts.file)
    print(f"Number of docs = {len(docs)}")
    print(f"Arbitrary line of file: {docs[42]}")

    # initialize the tokenizer
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    max_width = 0
    max_height = 0
    min_width = MAX_LENGTH + 1
    widest_batch = None
    thinnest_batch = None
    longest_batch = None
    tokens_per_batch = []

    if opts.nonadaptive:
        extract = naive_extract_batch
    else:
        extract = adaptive_extract_batch

    for b, batch in enumerate(batcher(tqdm.tqdm(docs, ncols=74), tokenizer, extract)):
        height, width = batch.shape
        tokens_per_batch.append((batch != 0).sum())
        assert width % 8 == 0
        assert height > 0
        assert width > 0
        if height % 8 != 0:
            print(f"WARNING: Errant batch size: {batch.shape}")
        if height >= max_height:
            longest_batch = batch.shape
            max_height = height
        if width >= max_width:
            widest_batch = batch.shape
            max_width = width
        if width <= min_width:
            thinnest_batch = batch.shape
            min_width = width

    print(f"Total number of batches = {b + 1}")
    tokens_per_batch = sum(tokens_per_batch) / len(tokens_per_batch)
    print(
        f"Average non-padding tokens / batch = {tokens_per_batch:.1f} (full = {TOKENS_PER_BATCH})"
    )
    print(f"Largest batchsize = {longest_batch}")
    print(f"Widest batch = {widest_batch}")
    print(f"Thinnest batch = {thinnest_batch}")


if __name__ == '__main__':
    main()