Skip to content

Instantly share code, notes, and snippets.

@Deepayan137
Created September 16, 2020 14:57
Show Gist options
  • Save Deepayan137/5e3febbc8bfc7b926dac472864ce7242 to your computer and use it in GitHub Desktop.
Save Deepayan137/5e3febbc8bfc7b926dac472864ce7242 to your computer and use it in GitHub Desktop.
A toy dataset for a seq2seq implementation.
import torch
from torch.utils.data import Dataset
import numpy as np
import pdb
class DummyDataset(Dataset):
    """Toy sequence-to-sequence dataset of random integer token sequences.

    Each source sequence is a NumPy array of token ids drawn uniformly from
    ``[0, vocab_size)`` with a random length in ``[1, max_len]``. Source
    sequences are sampled once, at construction time. The target is built
    token by token: with probability ``prob`` a token ``x`` maps to
    ``func2(x) = 2*x + 1``, otherwise to ``func1(x) = x // 2``.

    Notes:
        * Targets are re-sampled on every ``__getitem__`` call, so the same
          index can yield different targets across calls/epochs.
        * Target ids can exceed ``vocab_size - 1``: ``2*x + 1`` reaches
          ``2*vocab_size - 1``.
        * All randomness comes from NumPy's global RNG (``np.random``); seed
          it with ``np.random.seed`` for reproducible data.
          NOTE(review): with a multi-worker DataLoader each worker may start
          from the same copied global RNG state — confirm if that matters.

    Args:
        prob: Probability that a token is mapped with ``func2`` rather than
            ``func1``.
        vocab_size: Number of distinct source token ids (falsy -> 10).
        nSamples: Number of samples in the dataset (falsy -> 20).
        max_len: Maximum (inclusive) source sequence length (falsy -> 5).
    """

    def __init__(self, prob, vocab_size=None,
                 nSamples=None, max_len=None):
        self.prob = prob
        # Falsy values (None, 0) fall back to the defaults, as before.
        self.vocab_size = vocab_size or 10
        self.nSamples = nSamples or 20
        self.max_len = max_len or 5
        self.src_data = self._prepare_src_data()

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        """Return ``{'src': ndarray of ids, 'tgt': list of mapped ids}``.

        Raises:
            IndexError: if ``index >= nSamples``.
        """
        if index >= self.nSamples:
            # IndexError (not AssertionError) is what Python's sequence
            # iteration protocol expects to terminate ``for x in ds``; an
            # ``assert`` would also be stripped under ``python -O``.
            raise IndexError(
                f"index {index} out of range for dataset of size {self.nSamples}"
            )
        src = self.src_data[index]
        tgt = [self.get_target_id(x) for x in src]
        return {'src': src, 'tgt': tgt}

    def func1(self, x):
        """Map a token to ``x // 2``."""
        return x // 2

    def func2(self, x):
        """Map a token to ``2*x + 1``."""
        return 2 * x + 1

    def get_target_id(self, x):
        """Randomly map token ``x``: ``func2`` with probability ``prob``, else ``func1``."""
        if np.random.random() > self.prob:
            return self.func1(x)
        return self.func2(x)

    def sample_src_ids(self):
        """Draw one source sequence of uniformly random ids in ``[0, vocab_size)``."""
        src_len = self.get_src_len()
        return np.random.choice(self.vocab_size, src_len)

    def get_src_len(self):
        """Return a random source length in ``[1, max_len]`` (inclusive)."""
        # randint's upper bound is exclusive, hence ``+ 1``. Without it
        # ``max_len`` itself was never produced and ``max_len=1`` raised
        # ValueError (low >= high).
        return np.random.randint(1, self.max_len + 1)

    def _prepare_src_data(self):
        """Pre-sample all ``nSamples`` source sequences."""
        return [self.sample_src_ids() for _ in range(self.nSamples)]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment