Skip to content

Instantly share code, notes, and snippets.

@peter0749
Created March 13, 2019 08:49
Show Gist options
  • Save peter0749/60963a864176ef598fbae5957efd3942 to your computer and use it in GitHub Desktop.
Save peter0749/60963a864176ef598fbae5957efd3942 to your computer and use it in GitHub Desktop.
from __future__ import print_function

import argparse
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn import Parameter
from torchvision import datasets, transforms


class AngleLinear(nn.Module):
    """Linear layer that emits angular-margin logits.

    For an input of size (B, F) and a weight of size (F, Classnum) the layer
    computes, for every sample/class pair, the cosine of the angle ``theta``
    between the feature vector and the class weight column, and the margin
    function ``phi(theta) = (-1)**k * cos(m * theta) - 2 * k`` with
    ``k = floor(m * theta / pi)``.  Both results are scaled by the feature
    norm before being returned.

    Args:
        in_features: length F of each input feature vector.
        out_features: number of classes (columns of the weight matrix).
        m: integer angular multiplier; must index one of the available
            multiple-angle polynomials (0 through 5).

    Raises:
        ValueError: if ``m`` is outside the supported range.
    """

    # Floor applied to squared norms before the square root so that an
    # all-zero feature row (or weight column) yields zeros instead of a
    # 0/0 NaN in the forward pass and an infinite sqrt gradient in backward.
    _SQ_NORM_EPS = 1e-24

    def __init__(self, in_features, out_features, m = 4):
        super(AngleLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        # Random init, then rescale every column to (approximately) unit norm.
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        # mlambda[j] maps cos(theta) to cos(j * theta).
        self.mlambda = [
            lambda x: x**0,
            lambda x: x**1,
            lambda x: 2*x**2-1,
            lambda x: 4*x**3-3*x,
            lambda x: 8*x**4-8*x**2+1,
            lambda x: 16*x**5-20*x**3+5*x
        ]
        # Reject values that would otherwise fail late in forward() or, for
        # negative m, silently select the wrong polynomial.
        if not 0 <= m < len(self.mlambda):
            raise ValueError(
                'm must be in [0, {}], got {!r}'.format(len(self.mlambda) - 1, m))
        self.m = m

    def forward(self, input):
        """Return ``(cos_theta, phi_theta)``, each of size (B, Classnum).

        Both tensors are multiplied by the per-sample feature norm.
        """
        x = input  # size=(B,F) F is feature len
        # size=(F,Classnum); columns rescaled to (approximately) unit norm.
        ww = self.weight.renorm(2, 1, 1e-5).mul(1e5)

        xlen = x.pow(2).sum(1).clamp(min=self._SQ_NORM_EPS).sqrt()   # size=B
        wlen = ww.pow(2).sum(0).clamp(min=self._SQ_NORM_EPS).sqrt()  # size=Classnum

        cos_theta = x.mm(ww)  # size=(B,Classnum)
        cos_theta = cos_theta / xlen.view(-1, 1) / wlen.view(1, -1)
        cos_theta = cos_theta.clamp(-1, 1)  # guard acos against rounding overshoot
        cos_m_theta = self.mlambda[self.m](cos_theta)

        # k only selects the branch of phi, so it is computed without gradient.
        theta = cos_theta.detach().acos()
        k = (self.m * theta / math.pi).floor()
        sign = 1 - 2 * (k % 2)  # (-1)**k for integer-valued k
        phi_theta = sign * cos_m_theta - 2 * k

        cos_theta = cos_theta * xlen.view(-1, 1)
        phi_theta = phi_theta * xlen.view(-1, 1)
        return (cos_theta, phi_theta)  # tuple of two (B,Classnum) tensors
class AngleLoss(nn.Module):
    """Cross-entropy style loss over ``(cos_theta, phi_theta)`` logit pairs.

    The logit of each sample's target class is moved from ``cos_theta``
    toward ``phi_theta``; all other logits stay at ``cos_theta``.  In
    training mode the blend weight is ``1 / (1 + lamb)``, where ``lamb``
    is recomputed on every call as
    ``max(LambdaMin, LambdaMax / (1 + 0.1 * it))``.  In test mode the
    target logit is ``phi_theta`` outright.

    Args:
        gamma: exponent of the ``(1 - pt) ** gamma`` factor applied to each
            sample's negative log-likelihood (0 disables the factor).
        test: if true, use ``phi_theta`` directly for the target logit and
            leave ``lamb`` untouched.
        mode: ``'mean'`` averages the per-sample losses; any other value
            sums them.
    """

    def __init__(self, gamma=0, test=False, mode='mean'):
        super(AngleLoss, self).__init__()
        self.gamma = gamma
        self.it = 0  # number of forward() calls so far
        self.LambdaMin = 5.0
        self.LambdaMax = 1500.0
        self.lamb = 1500.0
        self.test = test
        self.mode = mode

    def forward(self, input, target):
        """Return the scalar loss for ``input = (cos_theta, phi_theta)``.

        ``target`` holds one class index per sample; it is reshaped to (B, 1).
        """
        self.it += 1
        cos_theta, phi_theta = input
        target = target.view(-1, 1)  # size=(B,1)

        # Float one-hot mask of the target class, size=(B,Classnum).  A float
        # mask replaces deprecated uint8-mask indexing and avoids in-place
        # writes into a tensor that is part of the autograd graph.
        one_hot = torch.zeros_like(cos_theta).scatter_(1, target, 1.0)

        if self.test:
            blend = 1.0
        else:
            self.lamb = max(self.LambdaMin, self.LambdaMax / (1 + 0.1 * self.it))
            blend = 1.0 / (1 + self.lamb)

        # Target logit: cos + blend * (phi - cos); other logits: cos.
        output = cos_theta + one_hot * (phi_theta - cos_theta) * blend

        logpt = F.log_softmax(output, dim=-1)
        logpt = logpt.gather(1, target).view(-1)
        pt = logpt.detach().exp()  # treated as a constant weight

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.mode == 'mean':
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss
class Net(nn.Module):
    """Two-conv-layer network whose classifier head is an AngleLinear layer.

    Args:
        feature: when true, forward() stops after the first fully connected
            layer and returns its ReLU activations instead of the head output.
    """

    def __init__(self, feature=False):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = AngleLinear(500, 10)
        self.feature = feature

    def forward(self, x):
        """Return the fc1 activations, or the output of the AngleLinear head."""
        # Two rounds of conv -> ReLU -> 2x2 max-pool.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)

        # Flatten to (batch, 4*4*50) for the fully connected layer.
        embedding = F.relu(self.fc1(x.view(-1, 4*4*50)))
        if self.feature:
            return embedding
        return self.fc2(embedding)
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch and periodically print the batch loss.

    Args:
        args: parsed options; only ``args.log_interval`` is read here.
        model: network whose output is accepted by AngleLoss.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches with a
            ``dataset`` attribute and a length.
        optimizer: optimizer stepping the model parameters.
        epoch: epoch number, used for logging and for positioning the loss
            schedule (epoch 1 is treated as the first epoch).
    """
    loss_func = AngleLoss()
    # A fresh criterion starts counting iterations at zero, which would
    # restart its lambda annealing at every epoch.  Fast-forward the counter
    # by the batches of the preceding epochs so the schedule keeps decaying
    # over the whole run.
    loss_func.it = max(epoch - 1, 0) * len(train_loader)
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_func(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
    """Evaluate the model on ``test_loader`` and print loss and accuracy.

    The loss is an AngleLoss in test mode, summed over every batch and then
    divided by the dataset size.  Predictions are the argmax of the first
    element of the model output.
    """
    criterion = AngleLoss(test=True, mode='sum')
    model.eval()

    loss_total = 0
    num_correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            # sum up batch loss
            loss_total += criterion(output, target).item()
            # index of the largest first-output score per sample
            pred = output[0].argmax(dim=1, keepdim=True)
            num_correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    loss_total /= num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        loss_total, num_correct, num_samples,
        100. * num_correct / num_samples))
def main():
    """Parse command-line options, then train and evaluate Net on MNIST.

    Trains with SGD for ``--epochs`` epochs, evaluating after each one, and
    writes the model state dict to ``mnist_cnn.pt`` when ``--save-model``
    is given.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument(
        '--batch-size', type=int, default=64, metavar='N',
        help='input batch size for training (default: 64)')
    parser.add_argument(
        '--test-batch-size', type=int, default=1000, metavar='N',
        help='input batch size for testing (default: 1000)')
    parser.add_argument(
        '--epochs', type=int, default=10, metavar='N',
        help='number of epochs to train (default: 10)')
    parser.add_argument(
        '--lr', type=float, default=0.01, metavar='LR',
        help='learning rate (default: 0.01)')
    parser.add_argument(
        '--momentum', type=float, default=0.5, metavar='M',
        help='SGD momentum (default: 0.5)')
    parser.add_argument(
        '--no-cuda', action='store_true', default=False,
        help='disables CUDA training')
    parser.add_argument(
        '--seed', type=int, default=1, metavar='S',
        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval', type=int, default=10, metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument(
        '--save-model', action='store_true', default=False,
        help='For Saving the current Model')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    # Extra DataLoader options are only passed when running on CUDA.
    if use_cuda:
        kwargs = {'num_workers': 1, 'pin_memory': True}
    else:
        kwargs = {}

    # Both splits use the same tensor conversion and normalisation.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_set = datasets.MNIST('../data', train=True, download=True,
                               transform=preprocess)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
    test_set = datasets.MNIST('../data', train=False, transform=preprocess)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
# Run the training script only when executed directly, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment