Skip to content

Instantly share code, notes, and snippets.

@ZaydH
Created May 18, 2019 03:43
Show Gist options
Save ZaydH/ba4685d1b49563bc8d76c7ee6f45cef5 to your computer and use it in GitHub Desktop.
Skorch Bug - Batch Size 1
import torch
import torch.nn as nn
from skorch import NeuralNetClassifier
num_classes = 5  # number of target classes
input_width = 10  # number of input features per sample


class Module(nn.Module):
    """Small MLP classifier: Linear -> BatchNorm1d -> Linear -> softmax.

    ``forward`` returns class probabilities rather than raw logits, which is
    what skorch's ``NeuralNetClassifier`` expects with its default
    ``NLLLoss`` criterion.

    Args:
        in_features: Number of input features. ``None`` (default) means use
            the module-level ``input_width``.
        n_classes: Number of output classes. ``None`` (default) means use the
            module-level ``num_classes``.
        hidden_width: Width of the hidden layer (default 64, as before).
    """

    def __init__(self, in_features=None, n_classes=None, hidden_width=64):
        super().__init__()
        # Resolve the defaults at call time, exactly like the original, which
        # read the module-level constants inside __init__.
        if in_features is None:
            in_features = input_width
        if n_classes is None:
            n_classes = num_classes
        self._mod = nn.Sequential(
            nn.Linear(in_features, hidden_width),
            # NOTE: in training mode BatchNorm1d rejects a batch holding a
            # single sample (only one value per channel to normalize).
            nn.BatchNorm1d(hidden_width),
            nn.Linear(hidden_width, n_classes),
        )

    def forward(self, x):
        """Return class probabilities; assumes x is (batch, in_features).

        skorch's NeuralNetClassifier defaults to NLLLoss and applies the log
        to the module output itself, so the module must emit probabilities.
        Returning raw logits (as the final Linear layer produces) would feed
        negative values into that log and yield NaN losses.
        """
        return torch.softmax(self._mod(x), dim=-1)
# Bug reproduction driver. Per the gist title this triggers a failure tied
# to a batch of size 1; presumably skorch's default internal train/valid
# split leaves a final training batch with a single sample, which the
# BatchNorm1d layer cannot handle in training mode -- TODO confirm.
bs = 32  # mini-batch size passed to skorch
num_ele = 44  # total number of samples generated

# Random float features and integer class labels in [0, num_classes).
x = torch.rand(num_ele, input_width)
y = torch.randint(low=0, high=num_classes, size=(num_ele,))

mod = NeuralNetClassifier(module=Module(), batch_size=bs)
mod.fit(x, y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment