Skip to content

Instantly share code, notes, and snippets.

@Deepayan137
Last active September 16, 2020 06:36
Show Gist options
  • Save Deepayan137/2c6c546afb63cc0d95f9418051823937 to your computer and use it in GitHub Desktop.
Save Deepayan137/2c6c546afb63cc0d95f9418051823937 to your computer and use it in GitHub Desktop.
sample code
import torch
from torch.utils.data import Dataset
import numpy as np
class DummyDataset(Dataset):
    """Toy dataset of random integer sources paired with derived targets.

    Each source value is drawn once, at construction time, with
    ``np.random.choice(vocab_size, nSamples)``.  The target is computed
    on every access: it is ``2 * x + 1`` unless a fresh
    ``np.random.random()`` draw exceeds ``prob``, in which case it is
    ``x // 2``.  Because the draw happens per access, reading the same
    index twice may yield different targets.

    Keyword Args:
        prob: Threshold compared against ``np.random.random()`` when
            choosing the target function.
        vocab_size: First argument handed to ``np.random.choice``.
        nSamples: Number of samples to generate; also the dataset length.

    Raises:
        KeyError: If any of the keyword arguments above is missing.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.prob = kwargs['prob']
        self.vocab_size = kwargs['vocab_size']
        self.nSamples = kwargs['nSamples']
        self.src_data = np.random.choice(self.vocab_size, self.nSamples)

    def __len__(self):
        """Return the number of samples."""
        return self.nSamples

    def func1(self, x):
        """Return ``x`` floor-divided by two."""
        return x // 2

    def func2(self, x):
        """Return ``2 * x + 1``."""
        return 2 * x + 1

    def get_target(self, x):
        """Return ``func1(x)`` if a random draw exceeds ``prob``, else ``func2(x)``."""
        if np.random.random() > self.prob:
            return self.func1(x)
        return self.func2(x)

    def __getitem__(self, index):
        """Return ``{'src': x, 'tgt': y}`` for the sample at ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[-nSamples, nSamples)``.
        """
        # Raise IndexError rather than using ``assert``: the sequence
        # iteration protocol (``for item in ds``) relies on IndexError to
        # detect the end, and asserts are stripped under ``python -O``.
        if not -self.nSamples <= index < self.nSamples:
            raise IndexError(
                f"index {index} is out of range for dataset of size "
                f"{self.nSamples}"
            )
        x = self.src_data[index]
        y = self.get_target(x)
        return {'src': x, 'tgt': y}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment