Created
October 26, 2019 07:46
-
-
Save mmsamiei/263df6f10935fbcf69eeb0f32dfe4ca6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torchtext.data import BucketIterator, interleave_keys | |
# Target effective batch size: batch_size_fn accumulates weighted example
# counts toward this cap, at which point the iterator emits a batch.
batch_size = 32
def batch_size_fn(new, count, sofar):
    """Return the running effective batch size after adding *new*.

    Long query/response pairs are weighted as more than one example so
    that batches of long sequences stay small in token terms, while very
    short pairs count as a fraction of an example. The iterator wrapper
    closes the current batch once the returned value reaches the
    module-level ``batch_size``.

    Args:
        new: example being added; must expose ``query`` and ``response``
            sequences (their ``len`` is the only thing read here).
        count: number of examples already in the batch (unused, but part
            of the torchtext ``batch_size_fn`` calling convention).
        sofar: effective batch size accumulated before *new*.
    """
    total_len = len(new.query) + len(new.response)
    # A single very long pair fills the batch on its own.
    if total_len > 500:
        return batch_size
    # (threshold, weight): an example whose combined length exceeds the
    # threshold counts as `weight` examples. Thresholds are checked in
    # descending order, mirroring the original elif ladder.
    tiers = ((300, 16), (200, 8), (120, 4), (60, 2), (30, 1), (20, 0.5))
    for threshold, weight in tiers:
        if total_len > threshold:
            return sofar + weight
    # Very short pairs (<= 20 tokens combined) count as a quarter example.
    return sofar + 0.25
def _length_sort_key(example):
    # Interleave the two lengths so examples with similar query AND
    # response lengths end up adjacent, minimizing padding per batch.
    return interleave_keys(len(example.query), len(example.response))


# Bucketing iterator over the training set: sorts examples by interleaved
# lengths and grows each batch via batch_size_fn's weighted counting.
train_iterator = BucketIterator(
    dataset=train_dataset,
    batch_size=batch_size,
    batch_size_fn=batch_size_fn,
    device=device,
    sort_key=_length_sort_key,
    sort=True,
    repeat=False,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment