Last active
July 19, 2019 14:18
-
-
Save AnchorBlues/b51836c96b90b25b35f209ce7ac8f522 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sklearn.metrics import accuracy_score | |
from sklearn.externals import joblib | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from tensorflow.keras.datasets import mnist | |
from skorch import NeuralNet | |
class CNN(nn.Module):
    """Small two-convolution classifier for 28x28 images such as MNIST.

    ``forward`` returns log-probabilities (``log_softmax``), so pair this
    module with ``nn.NLLLoss`` during training.

    Args:
        in_channels: number of input image channels (default 1, grayscale).
        num_classes: number of output classes (default 10, one per digit).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super(CNN, self).__init__()
        # Layers are created in the same order as before so that a fixed
        # torch seed yields the same initial weights.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2_drop = nn.Dropout2d()
        # Flattened feature size = channels * height * width after the second
        # conv + pool stage for a 28x28 input: 64 * 5 * 5 = 1600.  A different
        # input resolution would need a different value here.
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a (N, in_channels, 28, 28) batch to (N, num_classes) log-probs."""
        # 28x28 -> conv(k=3) 26x26 -> max-pool(2) 13x13
        x = F.relu(F.max_pool2d(self.conv1(x), 2))  # (N, 32, 13, 13)
        # 13x13 -> conv(k=3) 11x11 -> channel dropout -> max-pool(2) 5x5
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # (N, 64, 5, 5)
        # Flatten channel, height and width into a single feature axis (1600).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # Functional dropout is only active while the module is in training mode.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)
def preprocess_images(images):
    """Scale uint8 images to [0, 1] float32 and add a channel axis.

    Args:
        images: array of shape (N, H, W) with pixel values in 0..255.

    Returns:
        float32 array of shape (N, 1, H, W), the NCHW layout ``nn.Conv2d``
        expects.
    """
    return np.expand_dims(images / 255, 1).astype(np.float32)


if __name__ == "__main__":
    torch.manual_seed(0)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load the MNIST data.
    print("data is loading...")
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Normalize, add the channel axis and cast to float32.
    x_train = preprocess_images(x_train)
    x_test = preprocess_images(x_test)
    # NLLLoss requires int64 (long) class-index targets.
    y_train = y_train.astype(np.int64)

    print("model is setting...")
    model = NeuralNet(
        CNN,
        max_epochs=10,
        optimizer=torch.optim.Adam,
        lr=0.001,
        device=device,
        batch_size=128,
        # CNN.forward returns log-probabilities, hence NLLLoss.
        criterion=nn.NLLLoss,
        # Reshuffle the training data every epoch; stated explicitly so the
        # behaviour does not depend on the skorch default.
        iterator_train__shuffle=True,
    )
    model.fit(x_train, y_train)

    # Round-trip the fitted model through disk to demonstrate persistence.
    # NOTE(review): `sklearn.externals.joblib` (imported at the top of the
    # file) was removed in scikit-learn 0.23; on newer versions import the
    # standalone `joblib` package instead.
    print("model is saving ... ")
    filename = "model_NeuralNet.obj"
    joblib.dump(model, filename)
    del model
    print("model is loading ...")
    # joblib files are pickles: only load files this script wrote itself.
    model = joblib.load(filename)

    print("evaluating...")
    # NeuralNet.predict is expected to return the CNN's per-class output,
    # one row per test image.  Checked explicitly because `assert` is
    # stripped when Python runs with -O.
    pred = model.predict(x_test)
    expected_shape = (len(x_test), 10)
    if pred.shape != expected_shape:
        raise ValueError(
            "unexpected prediction shape {}, expected {}".format(
                pred.shape, expected_shape))
    pred = pred.argmax(axis=1)
    print("acc:{}".format(accuracy_score(y_test, pred)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment