Skip to content

Instantly share code, notes, and snippets.

@terasakisatoshi
Last active June 1, 2022 01:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save terasakisatoshi/5e427413f7e86786b0c423bd0e4add21 to your computer and use it in GitHub Desktop.
Save terasakisatoshi/5e427413f7e86786b0c423bd0e4add21 to your computer and use it in GitHub Desktop.
ディープ何もわからん
import numpy as np
import os
from glob import glob
import torch
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
import imageio
import torchvision
import albumentations as A
from albumentations.pytorch import ToTensorV2
import albumentations.augmentations.functional as F
# torch.manual_seed(2022)  # uncomment to make the whole run reproducible
LABELS = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
# Class name -> integer class index stored alongside each file path.
LABEL2INDEX = {lab: idx for idx, lab in enumerate(LABELS)}


class FlowerDataset(torch.utils.data.Dataset):
    """Per-class 80/10/10 train/val/test split of an image folder tree.

    Expects ``dataset_dir/<label>/*.jpg`` for every label in ``LABELS``. The
    object only holds the split annotations in ``self.anns``: a dict with keys
    "train", "val" and "test", each a list of ``(path, class_index)`` tuples.
    It defines no ``__getitem__``/``__len__`` of its own.

    Args:
        dataset_dir: root directory containing one sub-folder per label.
        seed: optional int. When given, the split uses a private
            ``torch.Generator`` seeded with it, so the split is reproducible.
            When ``None`` (default), the global torch RNG is used.
    """

    def __init__(self, dataset_dir, seed=None):
        self.dataset_dir = dataset_dir
        self.seed = seed
        self.anns = self.load_annotations()

    def load_annotations(self):
        """Split each label's files; return ``{"train": ..., "val": ..., "test": ...}``.

        Splitting per label keeps class proportions (approximately) equal
        across splits. ``int()`` floors the train/val sizes, so rounding
        leftovers end up in the test split.
        """
        # One generator shared by all labels so a given seed fixes the whole split.
        generator = (
            torch.default_generator
            if self.seed is None
            else torch.Generator().manual_seed(self.seed)
        )
        splits = {"train": [], "val": [], "test": []}
        for label in LABELS:
            # Read the instance's directory (previously a module-level global).
            files = sorted(glob(os.path.join(self.dataset_dir, label, "*.jpg")))
            train_size = int(0.8 * len(files))
            val_size = int(0.1 * len(files))
            test_size = len(files) - train_size - val_size
            parts = torch.utils.data.random_split(
                files, [train_size, val_size, test_size], generator=generator
            )
            class_id = LABEL2INDEX[label]
            for name, part in zip(("train", "val", "test"), parts):
                splits[name].extend((f, class_id) for f in part)
        return splits
dataset_dir = os.path.expanduser("~/dataset/flower_photos")

# Every phase currently shares the same preprocessing: resize to 128x128 and
# convert the image array into a tensor. Scaling pixels into [0, 1] happens
# afterwards, via the division by 255 in TransformDataset.
# NOTE(review): ImageNet normalization is disabled here. If
# A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) is
# re-enabled, the manual /255 should be dropped too, since A.Normalize already
# divides by max_pixel_value (255 by default) -- confirm before changing.
train_transform = A.Compose([A.Resize(128, 128), ToTensorV2()])
val_transform = A.Compose([A.Resize(128, 128), ToTensorV2()])
test_transform = val_transform  # val and test share one pipeline object
tfm_recipe = {"train": train_transform, "val": val_transform, "test": test_transform}
class TransformDataset(torch.utils.data.Dataset):
    """Map-style dataset that loads and transforms images from ``(path, class_id)`` pairs.

    Args:
        base_set: indexable sequence of ``(image_path, class_id)`` tuples,
            such as one of the ``FlowerDataset(...).anns`` splits.
        phase: key into the module-level ``tfm_recipe`` ("train", "val" or
            "test") that selects the default pipeline.
        transform: optional albumentations-style callable used instead of
            ``tfm_recipe[phase]``. It is called as ``transform(image=arr)["image"]``.

    ``__getitem__`` returns ``(image_tensor / 255, class_id)``.
    """

    def __init__(self, base_set, phase, transform=None):
        self.base_set = base_set
        self.phase = phase
        self.tfm = tfm_recipe[phase] if transform is None else transform

    def __getitem__(self, idx):
        path, class_id = self.base_set[idx]
        # Decode with Pillow and force 3-channel RGB. Grayscale, CMYK or RGBA
        # files would otherwise produce arrays the 3-channel model can't take.
        # The `with` block closes the underlying file.
        with Image.open(path) as img:
            arr = np.asarray(img.convert("RGB"))
        tfmed = self.tfm(image=arr)["image"]
        # ToTensorV2 keeps the uint8 dtype; true division gives floats in [0, 1].
        return tfmed / 255, class_id

    def __len__(self):
        return len(self.base_set)
# Build the per-class split once, then wrap each split with its phase's pipeline.
dset = FlowerDataset(dataset_dir)
train_set, val_set, test_set = (
    TransformDataset(dset.anns[phase], phase=phase)
    for phase in ("train", "val", "test")
)
def conv_bn(inC, outC, use_bn=True):
    """Return ``Sequential(Conv3x3, [BatchNorm2d], ReLU)`` preserving spatial size.

    The 3x3 convolution uses stride 1 and padding 1. It carries a bias only
    when BatchNorm is disabled, because BatchNorm's learned shift would make a
    conv bias redundant.

    Args:
        inC: number of input channels.
        outC: number of output channels.
        use_bn: insert a ``BatchNorm2d`` between the conv and the ReLU.
    """
    conv = nn.Conv2d(inC, outC, kernel_size=3, stride=1, padding=1, bias=not use_bn)
    norm = [nn.BatchNorm2d(outC)] if use_bn else []
    return nn.Sequential(conv, *norm, nn.ReLU(inplace=True))
def _simple_cnn():
    """Small 3-stage CNN; Flatten size 64*16*16 assumes 128x128 inputs (three 2x pools)."""
    return nn.Sequential(
        conv_bn(3, 16, use_bn=True),
        nn.MaxPool2d(2),
        conv_bn(16, 32, use_bn=True),
        nn.MaxPool2d(2),
        conv_bn(32, 64, use_bn=True),
        nn.MaxPool2d(2),
        nn.Flatten(),
        nn.Linear(64 * 16 * 16, 512),
        nn.ReLU(inplace=True),
        nn.Linear(512, len(LABELS)),
    )


def _resnet18(pretrained):
    """ResNet-18 with its final ``fc`` replaced by a fresh head for ``len(LABELS)`` classes."""
    # NOTE: `pretrained=` is deprecated since torchvision 0.13 in favour of
    # `weights=`; it is kept here for compatibility with older torchvision.
    net = torchvision.models.resnet18(pretrained=pretrained)
    net.fc = nn.Linear(net.fc.in_features, len(LABELS))
    return net


# Candidate backbones; only the one named by MODEL_NAME is actually built.
MODEL_BUILDERS = {
    "simple_cnn": _simple_cnn,
    "resnet18_scratch": lambda: _resnet18(pretrained=False),
    "resnet18_pretrained": lambda: _resnet18(pretrained=True),
}
MODEL_NAME = "resnet18_pretrained"
model = MODEL_BUILDERS[MODEL_NAME]()
model.train()
# Mini-batches of 128 samples; only the training stream is reshuffled each epoch.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=128, shuffle=False)

# Cross-entropy over the class logits, optimized with Adam at learning rate 1e-3.
criterion = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.001)
def _run_epoch(model, batches, criterion, opt=None):
    """Run one pass over ``batches`` and return ``(mean_loss, accuracy_percent)``.

    With ``opt`` given, the model is put in training mode and updated after
    every batch. Without it, the model is put in eval mode and gradients are
    disabled. Both loss and accuracy are averaged per sample, so a smaller
    final batch does not skew them. Returns ``(nan, nan)`` if ``batches`` is
    empty.
    """
    training = opt is not None
    model.train(training)
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for x, y in batches:
            out = model(x)
            loss = criterion(out, y)
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()
            n = x.size(0)
            # assumes criterion returns the batch mean (reduction="mean",
            # the CrossEntropyLoss default), so re-weight by batch size
            loss_sum += loss.item() * n
            correct += (out.argmax(dim=1) == y).sum().item()
            seen += n
    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, 100.0 * correct / seen


for epoch in range(10):
    print(epoch)
    train_loss, accuracy = _run_epoch(model, tqdm(train_loader), criterion, opt)
    print(train_loss, accuracy)
    val_loss, accuracy = _run_epoch(model, tqdm(val_loader), criterion)
    print(val_loss, accuracy)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment