Skip to content

Instantly share code, notes, and snippets.

@mmsamiei
Created October 26, 2019 07:46
Show Gist options
  • Save mmsamiei/263df6f10935fbcf69eeb0f32dfe4ca6 to your computer and use it in GitHub Desktop.
Save mmsamiei/263df6f10935fbcf69eeb0f32dfe4ca6 to your computer and use it in GitHub Desktop.
from torchtext.data import BucketIterator, interleave_keys
batch_size = 32


def batch_size_fn(new, count, sofar):
    """Effective-batch-size accumulator for torchtext's BucketIterator.

    Each example contributes a weight that grows with its combined
    query+response length, so batches of long examples contain fewer
    examples.  BucketIterator wraps a batch once the accumulated value
    reaches the iterator's ``batch_size``.

    Args:
        new:   the next example to add; must expose ``query`` and
               ``response`` sequences (token lists).
        count: number of examples already in the pending batch (unused
               here, but required by the batch_size_fn signature).
        sofar: effective batch size accumulated so far.

    Returns:
        The updated effective batch size, or the global ``batch_size``
        itself for very long pairs, which forces an immediate wrap.
    """
    total = len(new.query) + len(new.response)

    # Extremely long pairs get a batch to themselves: returning the
    # global threshold makes the iterator close the batch right away.
    if total > 500:
        return batch_size

    # (length threshold, weight) pairs, checked longest-first.
    weights = ((300, 16), (200, 8), (120, 4), (60, 2), (30, 1), (20, 0.5))
    for threshold, weight in weights:
        if total > threshold:
            return sofar + weight

    # Very short pairs are nearly free.
    return sofar + 0.25
# Bucket examples by interleaved (query, response) length so that pairs
# of similar size land in the same batch, minimizing padding; the custom
# batch_size_fn turns batch_size into a length-weighted budget rather
# than a fixed example count.
train_iterator = BucketIterator(
    dataset=train_dataset,
    batch_size=batch_size,
    batch_size_fn=batch_size_fn,
    device=device,
    sort_key=lambda x: interleave_keys(len(x.query), len(x.response)),
    sort=True,
    repeat=False,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment