Skip to content

Instantly share code, notes, and snippets.

@poutyface
Created November 8, 2021 23:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save poutyface/04b38696a93d37031086ab9833f93541 to your computer and use it in GitHub Desktop.
Save poutyface/04b38696a93d37031086ab9833f93541 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import math
import random
from PIL import ImageOps, ImageEnhance, ImageFilter
import numpy as np
from RandAugment import RandAugment
#import torch_xla
#import torch_xla.core.xla_model as xm
#device = xm.xla_device()
# Runtime switches. TPU picks the XLA optimizer step in train(); HALF casts the
# model and its inputs to float16.
TPU = False
HALF = False

print(torch.__version__)
# Turn on cuDNN's algorithm autotuner.
torch.backends.cudnn.benchmark = True

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Per-channel statistics passed to transforms.Normalize in both pipelines.
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

# Training pipeline: random crop/flip, RandAugment, then tensor conversion and
# normalisation.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(
        32,
        scale=(0.8, 1.0),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.RandomHorizontalFlip(),
    RandAugment(2, 5),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# Evaluation pipeline: deterministic resize, tensor conversion, normalisation.
transform_test = transforms.Compose([
    transforms.Resize(32, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
class Cutout(object):
    """Zero out one randomly placed square patch of an image tensor.

    The patch is centred on a uniformly sampled pixel and clipped at the image
    border, so it can be smaller than ``length`` x ``length``.
    """

    def __init__(self, length):
        # Nominal side length of the erased square, in pixels.
        self.length = length

    def __call__(self, img):
        """Erase the patch from ``img`` in place and return the same tensor.

        Dimensions 1 and 2 of ``img`` are used as height and width; the same
        2-D mask is broadcast over dimension 0.
        """
        height, width = img.size(1), img.size(2)
        half = self.length // 2

        # Row is sampled before column so the NumPy random stream is consumed
        # in a fixed order.
        centre_y = np.random.randint(height)
        centre_x = np.random.randint(width)

        top, bottom = np.clip([centre_y - half, centre_y + half], 0, height)
        left, right = np.clip([centre_x - half, centre_x + half], 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.

        img *= torch.from_numpy(keep).expand_as(img)
        return img
#transform_train.transforms.append(Cutout(8))
#transform_train.transforms.append(Cutout(10))
# Mini-batch size shared by the training and evaluation loaders.
BATCH_SIZE = 64

# CIFAR-10 training split (downloaded into ./data if missing), reshuffled
# every epoch.
train_dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform_train,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# CIFAR-10 test split, iterated in a fixed order.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)
class AuxiliaryHeadCIFAR(nn.Module):
    """Auxiliary classification head applied to an intermediate feature map.

    The layer geometry assumes an 8x8 spatial input with ``C`` channels and
    produces ``num_classes`` logits per sample.
    """

    def __init__(self, C, num_classes):
        """assuming input size 8x8"""
        super().__init__()
        stages = [
            # NOTE: in-place ReLU, so the tensor handed to forward() is modified.
            nn.ReLU(inplace=True),
            # 5x5 window with stride 3 turns an 8x8 map into 2x2.
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 2x2 kernel collapses the 2x2 map to a single position.
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        ]
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        """Return class logits for the feature map ``x``."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)
class ImageProcess(nn.Module):
    """Convolutional stem: 3x3 conv to ``ch`` channels, SiLU, stride-2 max pool.

    The convolution keeps the spatial size (padding 1); the pooling step halves
    it, e.g. 32x32 -> 16x16.
    """

    def __init__(self, ch=32):
        super().__init__()
        stem = [
            nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        self.conv = nn.Sequential(*stem)

    def forward(self, x):
        """Run the stem on a batch of 3-channel images."""
        return self.conv(x)
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    No attention mask is applied, so every token attends to every other token
    (this is *not* causal attention).

    Args:
        emb_dim: Width of each token embedding; must be a multiple of
            ``n_head``.
        n_head: Number of attention heads.

    Raises:
        ValueError: If ``n_head`` is not positive or does not divide
            ``emb_dim``.
    """

    def __init__(self, emb_dim, n_head):
        super().__init__()
        # Fail here rather than with an opaque ``view`` size error in forward().
        if n_head <= 0 or emb_dim % n_head != 0:
            raise ValueError(
                "emb_dim (%r) must be a positive multiple of n_head (%r)"
                % (emb_dim, n_head)
            )
        # key, query, value projections for all heads
        self.key = nn.Linear(emb_dim, emb_dim, bias=False)
        self.query = nn.Linear(emb_dim, emb_dim, bias=False)
        self.value = nn.Linear(emb_dim, emb_dim, bias=False)
        self.attn_drop = nn.Dropout(0.1)
        # output projection
        self.proj = nn.Linear(emb_dim, emb_dim)
        self.n_head = n_head
        # Not used by forward(); kept so existing state_dict keys stay valid.
        self.norm = nn.LayerNorm(emb_dim)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, C) and return a (B, T, C) tensor."""
        B, T, C = x.size()
        head_dim = C // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, hs): heads become a batch-like dimension.
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(x))
        q = split_heads(self.query(x))
        v = split_heads(self.value(x))

        # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T), scaled by 1/sqrt(hs)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)

        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)
class SelfAttention(nn.Module):
    """Transformer-style block preceded by a depthwise 3x3 convolution.

    The input is a (B, s*s, emb_dim) token sequence laid out row-major on an
    ``s`` x ``s`` grid. The attention branch sees tokens that were first mixed
    with their spatial neighbours by the depthwise convolution; both the
    attention branch and the MLP branch are added residually.

    Args:
        emb_dim: Channel width of every token.
        n_head: Number of heads passed to ``Attention``.
        s: Side length of the square token grid.
    """

    def __init__(self, emb_dim, n_head, s):
        super().__init__()
        self.emb_dim = emb_dim
        self.s = s
        # Depthwise (groups == channels) 3x3 convolution, size-preserving.
        self.c = nn.Conv2d(emb_dim, emb_dim, 3, 1, 1, groups=emb_dim, bias=False)
        self.norm1 = nn.LayerNorm(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)
        # Not used by forward(); kept so existing state_dict keys stay valid.
        self.norm3 = nn.LayerNorm(emb_dim)
        self.attn = Attention(emb_dim, n_head)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 2 * emb_dim),
            nn.GELU(),
            nn.Linear(2 * emb_dim, emb_dim),
        )

    def drop_path(self, x, drop_prob):
        """Randomly zero whole samples of ``x`` (stochastic depth).

        In eval mode, or when ``drop_prob`` is 0, ``x`` is returned unchanged.
        Otherwise each sample along dim 0 is kept with probability
        ``1 - drop_prob`` and the survivors are rescaled by ``1 / keep_prob``.
        ``x`` is expected to be 3-D; the mask has shape (B, 1, 1).

        Raises:
            ValueError: If ``drop_prob`` is outside ``[0, 1)`` while training.
        """
        if not self.training or drop_prob == 0.0:
            return x
        if not 0.0 < drop_prob < 1.0:
            raise ValueError("drop_prob must be in [0, 1), got %r" % (drop_prob,))
        keep_prob = 1. - drop_prob
        # Build the mask next to ``x`` (same device and dtype) instead of on the
        # CPU, so the block does not depend on a module-level ``device``.
        mask = x.new_empty((x.size(0), 1, 1)).bernoulli_(keep_prob)
        # Out-of-place, so tensors saved for backward are never overwritten.
        return x / keep_prob * mask

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, s*s, emb_dim)."""
        B, _, _ = x.size()
        # Tokens are (B, T, C); the convolution needs (B, C, H, W). A transpose
        # is required before the reshape, otherwise token and channel axes are
        # interleaved and the 3x3 kernel no longer sees spatial neighbours.
        grid = x.transpose(1, 2).reshape(B, self.emb_dim, self.s, self.s)
        grid = self.c(grid)
        x1 = grid.flatten(2).transpose(1, 2)  # back to (B, T, C)
        x = x + self.drop_path(self.attn(self.norm1(x1)), 0.0)
        x = x + self.drop_path(self.mlp(self.norm2(x)), 0.0)
        return x
import math
class Encoder(nn.Module):
    """Three-stage token encoder with 2x2 max-pooling between stages.

    Stage layout (fixed): 16x16 grid of 32-d tokens -> pool -> 8x8 grid,
    projected to 92-d -> pool -> 4x4 grid, projected to 256-d.

    NOTE: ``size``, ``emb_dim``, ``n_head`` and ``n_layers`` are accepted so
    existing call sites keep working, but none of them influences the layers
    built here.
    """

    def __init__(self, size=16, emb_dim=32, n_head=4, n_layers=1):
        super().__init__()
        self.b1 = nn.Sequential(*[SelfAttention(32, 2, 16) for _ in range(1)])
        self.l1 = nn.Linear(32, 92)
        self.p1 = nn.MaxPool2d(2, 2)
        self.b2 = nn.Sequential(*[SelfAttention(92, 2, 8) for _ in range(1)])
        self.l2 = nn.Linear(92, 256)
        self.p2 = nn.MaxPool2d(2, 2)
        self.b3 = nn.Sequential(*[SelfAttention(256, 2, 4) for _ in range(2)])

    @staticmethod
    def _pool_tokens(x, pool, side):
        """Apply ``pool`` to a (B, side*side, C) sequence on its 2-D grid."""
        B, _, C = x.size()
        # (B, T, C) -> (B, C, H, W). The transpose is essential: reshaping
        # directly would interleave token and channel axes, so the pooling
        # window would not cover spatially adjacent tokens.
        grid = x.transpose(1, 2).reshape(B, C, side, side)
        return pool(grid).flatten(2).transpose(1, 2)

    def forward(self, x):
        """Encode ``x`` of shape (B, 256, 32) into a (B, 16, 256) tensor."""
        x = self.b1(x)
        x = self.l1(self._pool_tokens(x, self.p1, 16))
        x = self.b2(x)
        x = self.l2(self._pool_tokens(x, self.p2, 8))
        x = self.b3(x)
        return x
class TF(nn.Module):
    """Image classifier: conv stem, sinusoidal positions, Encoder, linear head.

    Pipeline: ``ImageProcess`` stem -> flatten to tokens -> add fixed
    sinusoidal position embedding -> ``Encoder`` -> mean over tokens ->
    linear classifier.

    Args:
        n_class: Number of output classes.
        emb_dim: Forwarded to ``Encoder``.
        n_head: Forwarded to ``Encoder``.
        n_layers: Forwarded to ``Encoder``.

    The position table is built for 16*16 tokens of width 32, so the stem
    output must be (B, 32, 16, 16) — presumably 32x32 input images; confirm
    against the data pipeline.
    """

    def __init__(self, n_class=10, emb_dim=192, n_head=3, n_layers=7):
        super().__init__()
        self.image_process = ImageProcess(32)
        self.flattener = nn.Flatten(2, 3)
        self.enc = Encoder(size=32, emb_dim=emb_dim, n_head=n_head, n_layers=n_layers)
        # Fixed (non-trainable) sinusoidal position embedding.
        self.pe = nn.Parameter(self.emb_sine(16*16, 32), requires_grad=False)
        # The encoder's output width is set by its final projection ``l2``, not
        # by ``emb_dim``. Sizing the heads from it keeps the default
        # ``emb_dim=192`` from producing a classifier that cannot consume the
        # encoder output.
        enc_dim = self.enc.l2.out_features
        # Not used by forward(); kept so existing state_dict keys stay valid.
        self.seq_pool = nn.Linear(enc_dim, 1)
        self.last_layer = nn.Linear(enc_dim, n_class)

    def emb_sine(self, n_channels, dim):
        """Return a (1, n_channels, dim) sinusoidal position table.

        Row ``p`` holds ``sin`` (even columns) and ``cos`` (odd columns) of
        ``p / 10000 ** (2 * (i // 2) / dim)`` for column ``i``.
        """
        angles = [
            [p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
            for p in range(n_channels)
        ]
        pe = torch.tensor(angles, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(pe[:, 0::2])
        pe[:, 1::2] = torch.cos(pe[:, 1::2])
        return pe.unsqueeze(0)

    def forward(self, x):
        """Return class logits of shape (B, n_class) for an image batch."""
        x = self.image_process(x)
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        x = self.flattener(x).transpose(-2, -1)
        # Out-of-place add: ``x`` is a transposed view of the stem output.
        x = x + self.pe
        x = self.enc(x)
        x = x.mean(dim=1)  # average over tokens
        x = self.last_layer(x)
        return x
def CT():
    """Build the ``TF`` classifier configured for 10 classes.

    Uses ``emb_dim=256``, ``n_head=1`` and ``n_layers=6``.
    """
    return TF(n_class=10, emb_dim=256, n_head=1, n_layers=6)
def resnet18(pretrained=False, **kwargs):
    """Return a freshly constructed ``BNet``.

    ``pretrained`` and ``**kwargs`` are accepted but ignored.
    """
    # NOTE(review): ``BNet`` is not defined in the visible part of this file;
    # calling this raises NameError unless it is provided elsewhere.
    return BNet()
class LabelSmoothingCrossEntropy(nn.Module):
    """Negative log-likelihood on logits, blended with a uniform-label term.

    The per-sample loss is ``confidence * nll + smoothing * mean(-log_prob)``
    with ``confidence = 1 - smoothing``.
    """

    def __init__(self, smoothing=0.1):
        """Store the smoothing factor.

        :param smoothing: label smoothing factor, must be below 1.0
        """
        super().__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def _compute_losses(self, x, target):
        """Return the un-reduced loss, one value per row of ``x``."""
        log_prob = F.log_softmax(x, dim=-1)
        # Log-probability assigned to the true class of every sample.
        picked = log_prob.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        nll_loss = -picked
        # Cross-entropy against the uniform distribution over classes.
        smooth_loss = -log_prob.mean(dim=-1)
        return self.confidence * nll_loss + self.smoothing * smooth_loss

    def forward(self, x, target):
        """Return the batch-mean smoothed loss for logits ``x``."""
        per_sample = self._compute_losses(x, target)
        return per_sample.mean()
#net = resnet18()
# Build the model on the selected device, optionally in float16.
net = CT().to(device)
if HALF:
    net = net.half()

# Label-smoothed loss for training; plain cross-entropy is also instantiated.
train_criterion = LabelSmoothingCrossEntropy().to(device)
criterion = nn.CrossEntropyLoss().to(device)

# Placeholders for per-group parameter lists (left empty here).
biases = []
weights = []

# AdamW over all model parameters. Earlier SGD configurations (Nesterov with
# lr=1e-1 and weight_decay=1e-4; plain momentum with lr=1e-3 and
# weight_decay=1e-6) are not in use.
optimizer = optim.AdamW(net.parameters(), lr=0.001, weight_decay=1e-4)

# Multiply the learning rate by 0.99 on every scheduler step.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)

# Number of training epochs.
EPOCH = 200
def train(epoch, optimizer, scheduler):
    """Train ``net`` for one pass over ``train_loader``.

    Prints the current learning rate, a running loss/accuracy line every 200
    batches and the final epoch accuracy, then advances ``scheduler`` once.

    Args:
        epoch: Epoch index, used only for logging.
        optimizer: Optimizer updating ``net``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per call.
    """
    print('\nEpoch: %d' % epoch)
    net.to(device)
    net.train()
    print(scheduler.get_last_lr())
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        if HALF:
            # Only the images are cast. Targets are class indices used as a
            # ``gather`` index inside the loss and must stay integer.
            inputs = inputs.half()
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = train_criterion(outputs, targets)
        loss.backward()
        if TPU:
            # NOTE(review): ``xm`` comes from the torch_xla import that is
            # commented out at the top of the file; enable it before setting
            # TPU = True.
            xm.optimizer_step(optimizer, barrier=True)
        else:
            optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 200 == 0:
            print('%.3f | %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    # max(total, 1) keeps the summary line from dividing by zero when the
    # loader yields no batches.
    print('Acc: %.3f%% (%d/%d)' % (100.*correct/max(total, 1), correct, total))
    scheduler.step()
def test(epoch, optimizer):
    """Evaluate ``net`` on ``test_loader`` and print its top-1 accuracy.

    Args:
        epoch: Unused; kept so existing call sites keep working.
        optimizer: Unused; kept so existing call sites keep working.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            if HALF:
                # Only the images are cast; targets are integer class indices
                # compared against the predicted indices below.
                inputs = inputs.half()
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    # max(total, 1) avoids a division by zero when the loader is empty.
    print('Eval %.3f%% (%d/%d)' % (100.*correct/max(total, 1), correct, total))
# One training pass followed by one evaluation pass per epoch.
for epoch in range(EPOCH):
    train(epoch, optimizer, scheduler)
    test(epoch, optimizer)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment