Skip to content

Instantly share code, notes, and snippets.

@tomgrek
Created March 23, 2018 02:36
Show Gist options
Save tomgrek/312dbcefa2f102b4ba22c9df817af8b0 to your computer and use it in GitHub Desktop.
chatbot article part 2_a
class DataGenerator():
    """Iterate over a sequence as (two-element context, next element) pairs.

    For each start position ``i`` this yields ``(x, y)`` where ``x`` is
    built from ``dset[i:i+2]`` wrapped in an outer list (so the tensor has a
    leading dimension of size 1) and ``y`` is built from ``[dset[i+2]]``.
    A sequence of length ``L`` therefore yields ``max(L - 2, 0)`` pairs.

    NOTE(review): ``dset`` is presumably a list of integer token ids, since
    it is sliced and passed to ``torch.LongTensor`` -- confirm against callers.

    The object is its own iterator and is single-pass: once exhausted it
    keeps raising ``StopIteration``. Create a new ``DataGenerator`` to
    iterate over the data again.
    """

    def __init__(self, dset):
        self.dset = dset
        self.len = len(self.dset)
        self.idx = 0  # start position of the next context window

    def __len__(self):
        # NOTE(review): this is the length of the underlying sequence, not the
        # number of pairs produced (which is max(len - 2, 0)). Left unchanged
        # because callers may depend on the current value.
        return self.len

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(x, y)`` pair or raise ``StopIteration``."""
        # Check the bound *before* building any tensors: the window starting
        # at idx needs dset[idx + 2] as its target. Using ``>=`` (not ``==``)
        # also stops cleanly when the sequence has fewer than 2 elements,
        # where len - 2 is negative and an equality test is never satisfied.
        if self.idx >= self.len - 2:
            raise StopIteration
        x = Variable(torch.LongTensor([self.dset[self.idx:self.idx + 2]]))
        y = Variable(torch.LongTensor([self.dset[self.idx + 2]]), requires_grad=False)
        self.idx += 1
        return (x, y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment