Skip to content

Instantly share code, notes, and snippets.

@jonhare
Created March 10, 2020 13:39
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jonhare/73a59dcc5416729548a086a983e81f07 to your computer and use it in GitHub Desktop.
Save jonhare/73a59dcc5416729548a086a983e81f07 to your computer and use it in GitHub Desktop.
import torch
from torchvision import transforms
from torch.utils.data import Dataset
class MyDataset(Dataset):
    """Synthetic dataset of binary images showing points on a random line.

    Items are generated on demand rather than stored. Item ``index`` is
    produced from a random generator seeded with ``index + random_offset``,
    so the same index always yields the same sample, and datasets built with
    different offsets yield different samples.

    Each item is a tuple ``(img, params)``:

    * ``img`` -- float tensor of shape ``(1, dim, dim)`` containing ones at
      the plotted points and zeros elsewhere.
    * ``params`` -- float tensor ``[slope, intercept]`` of the line
      ``y = slope * x + intercept`` the points were taken from.

    Args:
        size: Number of items reported by ``len()``.
        dim: Height and width of the generated images, at least 3.
        random_offset: Value added to the index to form the per-item seed.

    Raises:
        ValueError: If ``dim`` is smaller than 3.
    """

    # Number of x positions drawn per attempt.
    _NUM_POINTS = 20
    # A candidate image is accepted only if more than this many pixels are set.
    _MIN_PIXELS = 2

    def __init__(self, size=5000, dim=40, random_offset=0):
        super(MyDataset, self).__init__()
        # Every x column holds at most one pixel, so an image narrower than
        # three columns can never pass the acceptance test and __getitem__
        # would retry forever.
        if dim < 3:
            raise ValueError("dim must be at least 3, got {}".format(dim))
        self.size = size
        self.dim = dim
        self.random_offset = random_offset

    def __getitem__(self, index):
        """Return the ``(img, params)`` pair for ``index``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if index < 0:
            index = index + size
        if index < 0 or index >= size:
            raise IndexError("{} index out of range".format(self.__class__.__name__))

        # A private generator keeps the sample reproducible without touching
        # the global RNG: nothing has to be saved and restored, and the
        # random state of other devices is left alone.
        gen = torch.Generator()
        gen.manual_seed(int(index + self.random_offset))

        dim = self.dim
        while True:
            dx = torch.randint(-10, 10, (1,), dtype=torch.float, generator=gen)
            dy = torch.randint(-10, 10, (1,), dtype=torch.float, generator=gen)
            c = torch.randint(-20, 20, (1,), dtype=torch.float, generator=gen)
            # dx == 0 gives an infinite or NaN slope; every point is then
            # masked out below, the image stays empty and the draw is retried.
            params = torch.cat((dy / dx, c))

            # Only column 0 (the x positions) is used, but the (N, 2) shape is
            # kept so the generator advances by the same amount per attempt.
            xy = torch.randint(0, dim, (self._NUM_POINTS, 2),
                               dtype=torch.float, generator=gen)
            xs = xy[:, 0]
            ys = (xs * params[0] + params[1]).round()

            # xs is already within [0, dim); keep points whose y lies strictly
            # inside (0, dim). Comparisons with NaN are False, dropping them.
            keep = (ys > 0) & (ys < dim)

            img = torch.zeros(dim, dim)
            # Flip y so that larger y values are drawn nearer the top row.
            rows = (dim - ys[keep]).long()
            cols = xs[keep].long()
            img[rows, cols] = 1

            if img.sum() > self._MIN_PIXELS:
                break

        return img.unsqueeze(0), params

    def __len__(self):
        """Return the number of items in the dataset."""
        return self.size
# Train, validation and test splits. Each item is seeded with
# index + random_offset, so these offsets give the three splits
# non-overlapping seed ranges (0-4999, 33333-33832, 99999-100498).
train_data = MyDataset()
val_data = MyDataset(random_offset=33333, size=500)
test_data = MyDataset(random_offset=99999, size=500)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment