Skip to content

Instantly share code, notes, and snippets.

@AnchorBlues
Last active July 19, 2019 14:18
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AnchorBlues/b51836c96b90b25b35f209ce7ac8f522 to your computer and use it in GitHub Desktop.
Save AnchorBlues/b51836c96b90b25b35f209ce7ac8f522 to your computer and use it in GitHub Desktop.
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.externals import joblib
import torch
from torch import nn
import torch.nn.functional as F
from tensorflow.keras.datasets import mnist
from skorch import NeuralNet
class CNN(nn.Module):
    """Small two-stage convolutional classifier producing 10 log-probabilities.

    Input is a batch of single-channel images (``Conv2d(1, ...)``). Each of the
    two stages is a 3x3 convolution followed by 2x2 max-pooling and ReLU; the
    second stage also applies channel-wise dropout (``nn.Dropout2d``) to the
    convolution output. Two fully connected layers then map the flattened
    features to 10 classes, and ``log_softmax`` turns them into
    log-probabilities (suitable for ``nn.NLLLoss``).

    NOTE: ``fc1`` expects 64 * 5 * 5 = 1600 features, so only inputs whose
    spatial size shrinks to 5x5 after both stages work (e.g. 28x28 digits).
    """

    def __init__(self):
        super().__init__()
        # Construction order matters: parameters are initialized from the
        # global torch RNG in this order, so a fixed seed gives fixed weights.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2_drop = nn.Dropout2d()
        # in_features = channels * height * width after the second stage
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities of shape (N, 10)."""
        # Stage 1: conv -> 2x2 max-pool -> ReLU; (N, 32, 13, 13) for 28x28 input.
        hidden = F.relu(F.max_pool2d(self.conv1(x), 2))

        # Stage 2: conv -> channel dropout -> 2x2 max-pool -> ReLU; (N, 64, 5, 5).
        conv_out = self.conv2_drop(self.conv2(hidden))
        hidden = F.relu(F.max_pool2d(conv_out, 2))

        # Collapse channel, height and width into a single feature axis.
        flat_dim = hidden.size(1) * hidden.size(2) * hidden.size(3)
        hidden = hidden.view(-1, flat_dim)

        # Classifier head; dropout is active only while the module is training.
        hidden = F.dropout(F.relu(self.fc1(hidden)), training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=-1)
if __name__ == "__main__":
    # Seed PyTorch's global RNG so runs are reproducible.
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the MNIST data.
    print("data is loading...")
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    def _prepare_images(images):
        """Scale pixels by 1/255, insert a channel axis at dim 1, cast to float32.

        Assumes 0-255 pixel intensities, as Keras's MNIST loader provides.
        """
        return np.expand_dims(images / 255, 1).astype(np.float32)

    x_train = _prepare_images(x_train)
    x_test = _prepare_images(x_test)
    # nn.NLLLoss requires int64 class indices as targets.
    y_train = y_train.astype(np.int64)

    print("model is setting...")
    net_kwargs = {
        "max_epochs": 10,
        "optimizer": torch.optim.Adam,
        "lr": 0.001,
        "device": device,
        "batch_size": 128,
        "criterion": nn.NLLLoss,
    }
    model = NeuralNet(CNN, **net_kwargs)
    model.fit(x_train, y_train)

    # Round-trip the fitted net through joblib (pickle-based) and evaluate the
    # reloaded copy, presumably to show the skorch wrapper survives persistence.
    # NOTE(review): `joblib` is imported from `sklearn.externals`, which was
    # removed in scikit-learn 0.23; newer installs need `import joblib`.
    print("model is saving ... ")
    filename = "model_NeuralNet.obj"
    joblib.dump(model, filename)
    del model
    print("model is loading ...")
    model = joblib.load(filename)

    print("evaluating...")
    # NeuralNet.predict returns the module's forward output: one log-softmax
    # score per class for every test image.
    class_scores = model.predict(x_test)
    assert class_scores.shape == (len(x_test), 10)
    predicted = np.argmax(class_scores, axis=1)
    print(f"acc:{accuracy_score(y_test, predicted)}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment