# Load libraries
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Compose
# Model definition
class NNModel(nn.Module):
    def __init__(self, x, y):
        super(NNModel, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(x * y, 120),
            nn.ReLU(),
            nn.Linear(120, 60),
            nn.ReLU(),
            nn.Linear(60, 10),
        )

    def forward(self, x):
        # Flatten each image to a vector, then return raw logits
        # for the 10 digit classes.
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
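# Quick shape check (illustrative sketch; the names below are hypothetical):
# a dummy batch of four 1x28x28 images should yield a (4, 10) logits tensor.
_checkModel = NNModel(28, 28)
_dummyBatch = torch.zeros(4, 1, 28, 28)
assert _checkModel(_dummyBatch).shape == (4, 10)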
def Training(dataLoader, model, lossFunc, optimizer):
    size = len(dataLoader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataLoader):
        # Compute the prediction loss
        pred = model(X)
        loss = lossFunc(pred, y)
        # Backpropagate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Periodic progress report over the dataset
        if batch % 100 == 0:
            current = batch * len(X)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")
def Test(dataloader, model, lossFunc):
    size = len(dataloader.dataset)
    numBatches = len(dataloader)
    model.eval()
    testLoss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            testLoss += lossFunc(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    # Average the summed batch losses over batches, accuracy over samples
    testLoss /= numBatches
    correct /= size
    print(f"Accuracy: {(100*correct):>0.1f}%, Avg loss: {testLoss:>8f}")
# Load the data
dataPath = os.path.join(os.path.dirname(__file__), 'data')
transform = Compose([ToTensor(), transforms.Normalize((0.5,), (0.5,))])  # normalize to [-1, 1]
trainDataset = datasets.MNIST(
    root=dataPath,
    train=True,
    download=True,
    transform=transform
)
trainDataLoader = DataLoader(
    dataset=trainDataset,
    batch_size=100,
    shuffle=True,
    drop_last=True
)
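# The training loop below scores on the training loader; for a held-out
# accuracy, the MNIST test split could be loaded the same way (sketch,
# the loader names are illustrative):
testDataset = datasets.MNIST(
    root=dataPath,
    train=False,
    download=True,
    transform=transform
)
testDataLoader = DataLoader(dataset=testDataset, batch_size=100)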
# Build the model
model = NNModel(28, 28)
# CrossEntropyLoss applies log-softmax itself, so the model outputs raw logits
lossFunc = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
for epoch in range(50):
    # Train
    Training(trainDataLoader, model, lossFunc, optimizer)
    # Evaluate (on the training set here)
    Test(trainDataLoader, model, lossFunc)
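# After training, a held-out evaluation and a checkpoint might look like
# this (sketch; the file name is hypothetical):
Test(testDataLoader, model, lossFunc)
torch.save(model.state_dict(), os.path.join(os.path.dirname(__file__), 'mnist_model.pth'))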