@shang-vikas
Last active August 23, 2018 09:15
import warnings
from collections import OrderedDict

import numpy as np
import torch
from torch import nn, optim
from torchvision.datasets import mnist

# Download MNIST on first use so the script runs from a clean checkout.
ds_t = mnist.MNIST('./data/', download=True)
class SimpleNet(nn.Module):
    """Simple fully connected classifier with a configurable stack of hidden layers."""

    def __init__(self, ipt_sz, h_units=[128, 128], num_classes=2):
        super().__init__()
        self.h_units = h_units
        self.num_classes = num_classes
        self._c = 0  # counter used to name the ReLU layers
        self.ipt_sz = ipt_sz
        self.__linear = nn.Linear
        self.__relu = nn.ReLU
        self._optim = None
        self._build()

    def _build(self):
        # First hidden layer + activation, then one (Linear, ReLU) pair per remaining width.
        lyr_ = []
        lyr_.append(('l0', nn.Linear(self.ipt_sz, self.h_units[0])))
        lyr_.append(('r0', nn.ReLU()))
        lyr_.extend([func(idx) for idx in range(1, len(self.h_units))
                     for func in (self._linear, self._relu)])
        self._fc = nn.Sequential(OrderedDict(lyr_))
        # The classification head is kept separate so it can stay trainable when freezing.
        self._clf = nn.Sequential(OrderedDict([
            ('last_ly', nn.Linear(self.h_units[-1], self.num_classes))]))

    def _linear(self, idx):
        return 'l' + str(idx), self.__linear(self.h_units[idx - 1], self.h_units[idx])

    def _relu(self, idx):
        self._c += 1
        return 'r' + str(self._c), self.__relu()

    def forward(self, x):
        y = self._fc(x)
        y = self._clf(y)
        return y
    def train_(self, num_epochs, dtrain, batch_size=32, optimizer=None,
               criterion=None, lr=0.01, verbose=50):
        if num_epochs == 0: raise ValueError("num_epochs can't be zero")
        if (not isinstance(dtrain[0], torch.Tensor)) or len(dtrain[0]) == 0 or len(dtrain[1]) == 0:
            raise ValueError("data can't be empty")
        if optimizer is None: raise ValueError("optimizer can't be None")
        if criterion is None: raise ValueError("criterion can't be None")
        # Only the parameters that are not frozen are handed to the optimizer.
        self._optim = optimizer(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
        for epoch in range(num_epochs):
            for batch in range(dtrain[0].size(0) // batch_size):
                x = dtrain[0][batch * batch_size:batch_size * (batch + 1)].reshape(-1, self.ipt_sz).float()
                y = dtrain[1][batch * batch_size:batch_size * (batch + 1)]
                out = self.forward(x)
                loss = criterion(out, y)
                self._optim.zero_grad()
                loss.backward()
                self._optim.step()
                if batch % verbose == 0:
                    print('epoch:{} batch:{} loss:{}'.format(epoch, batch, loss.item()))
    def freeze(self, _group_name=[]):
        self.check_ls(_group_name)
        if _group_name == []:
            warnings.warn("All the layers except the last one will be frozen")
        self.__operation(False, _group_name)

    def unfreeze(self, _group_name=[]):
        self.check_ls(_group_name)
        if _group_name == []:
            warnings.warn('All the layers are now training!!')
        self.__operation(True, _group_name)

    def __operation(self, flag, _lgrps=[]):
        _len = len(list(self.named_parameters()))
        for idx, (name, param) in enumerate(self.named_parameters()):
            # With no group names given, freezing still leaves the last layer
            # (its weight and bias) trainable.
            if _lgrps == [] and idx > _len - 3 and (not flag):
                param.requires_grad = True
            elif _lgrps == []:
                param.requires_grad = flag
            elif name.split('.')[0] in _lgrps:
                param.requires_grad = flag

    def check_ls(self, _l):
        if not isinstance(_l, list):
            raise TypeError("Please pass a list")

    def des(self):
        print(dir(self))
Mod = SimpleNet(28 * 28, [128, 256, 256, 512], 10)

# Peek at the first rows of the first weight tensor before training.
for p in Mod.parameters():
    print(p[0], p[1])
    break

# NOTE: newer torchvision exposes these tensors as ds_t.data / ds_t.targets.
Mod.train_(2, (ds_t.train_data, ds_t.train_labels),
           optimizer=torch.optim.SGD, criterion=nn.CrossEntropyLoss())

# ...and again after training, to confirm the weights have changed.
for p in Mod.parameters():
    print(p[0], p[1])
    break
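
# A minimal sketch of how the freeze/unfreeze helpers above might be used for
# fine-tuning. The group name '_fc' is an assumption: it is the parameter-name
# prefix that __operation splits on, so passing it freezes the hidden stack
# while the '_clf' head keeps training.
Mod.freeze(['_fc'])                      # stop updating the hidden layers
Mod.train_(1, (ds_t.train_data, ds_t.train_labels),
           optimizer=torch.optim.SGD, criterion=nn.CrossEntropyLoss())
Mod.unfreeze(['_fc'])                    # make the hidden layers trainable again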