Skip to content

Instantly share code, notes, and snippets.

@Orpheon
Created March 20, 2017 14:04
Show Gist options
  • Save Orpheon/4e31e0f71b551a9ffcd36c2ffd97c4ff to your computer and use it in GitHub Desktop.
Save Orpheon/4e31e0f71b551a9ffcd36c2ffd97c4ff to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.utils.data
from torch.autograd import Variable
# Minimal dataset
class Dataset(torch.utils.data.Dataset):
    """Minimal map-style dataset whose sample ``idx`` is ``[idx, idx + 1]``.

    Args:
        length: Number of samples. Defaults to 12, the original fixed size.
    """

    def __init__(self, length=12):
        super(Dataset, self).__init__()
        self._length = length

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        """Return the pair ``[idx, idx + 1]``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Raising is required: this class has no ``__iter__``, so plain
                iteration (``for s in ds``) calls ``__getitem__`` with
                0, 1, 2, ... and only stops on IndexError.
        """
        position = idx + self._length if idx < 0 else idx
        if not 0 <= position < self._length:
            raise IndexError(
                "Dataset index {} out of range for length {}".format(idx, self._length))
        return [position, position + 1]
# Wrap the dataset in a DataLoader with its defaults (batch_size=1, no shuffling),
# so each iteration yields one collated [inputs, labels] batch.
train_dataset = Dataset()
loader = torch.utils.data.DataLoader(train_dataset)
class CNN(nn.Module):
    """Single-layer 1-D convolutional network wrapping one ``nn.Conv1d``.

    The constructor defaults reproduce the original hard-coded layer,
    ``nn.Conv1d(1, 1, 12)``. Pass other values to reuse the model with
    different channel counts or kernel widths.

    Args:
        in_channels: Number of channels in the input signal.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Width of the convolution kernel. Every input sequence
            must be at least this long.
    """

    def __init__(self, in_channels=1, out_channels=1, kernel_size=12):
        super(CNN, self).__init__()
        # Per the PyTorch docs: nn.Conv1d(in_channels, out_channels, kernel_size).
        # The attribute name ``layer`` is kept so existing state_dict keys still match.
        self.layer = nn.Conv1d(in_channels, out_channels, kernel_size)

    def forward(self, x):
        """Apply the convolution.

        Args:
            x: Float input shaped ``(batch, in_channels, length)`` with
                ``length >= kernel_size``, as ``nn.Conv1d`` requires.

        Returns:
            Output shaped ``(batch, out_channels, length - kernel_size + 1)``
            (stride 1, no padding).
        """
        return self.layer(x)
# Forward pass.
# nn.Conv1d needs a float input shaped (batch, channels, length) with
# length >= kernel_size. The loader's default collation instead yields each
# batch as a pair of 1-D integer tensors of shape (batch_size,) = (1,), so
# calling cnn(inputs) per batch fails on rank, dtype and length.
# NOTE(review): CNN() uses kernel_size 12, which equals len(train_dataset)
# (12). So the inputs are concatenated into one length-12 sequence and
# convolved once. Confirm that was the intent.
cnn = CNN()
input_batches = []
label_batches = []
for inputs, labels in loader:
    input_batches.append(inputs)
    label_batches.append(labels)
# Reshape to (1, 1, N): one sequence, one channel, N steps. Cast to float
# to match the Conv1d weights.
inputs = Variable(torch.cat(input_batches).float().view(1, 1, -1))
labels = Variable(torch.cat(label_batches).float().view(1, 1, -1))
outputs = cnn(inputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment