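# SimpleNet: a configurable multi-layer perceptron for MNIST classification,
# with helpers to freeze/unfreeze parameter groups by name for
# transfer-learning-style experiments.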
import warnings
from collections import OrderedDict

import torch
from torch import nn
from torchvision.datasets import mnist

# download=True fetches MNIST into ./data/ on the first run.
ds_t = mnist.MNIST('./data/', download=True)
class SimpleNet(nn.Module):
    """A fully connected net: input -> hidden stack (_fc) -> classifier head (_clf)."""
    def __init__(self, ipt_sz, h_units=[128, 128], num_classes=2):
        super().__init__()
        self.h_units = h_units
        self.num_classes = num_classes
        self._c = 0                  # running counter used to name ReLU layers
        self.ipt_sz = ipt_sz
        self.__linear = nn.Linear    # layer factories kept as attributes
        self.__relu = nn.ReLU
        self._optim = None
        self._build()
    def _build(self):
        # First block maps the input to the first hidden width...
        lyr_ = []
        lyr_.append(('l0', nn.Linear(self.ipt_sz, self.h_units[0])))
        lyr_.append(('r0', nn.ReLU()))
        # ...then alternating Linear/ReLU pairs for the remaining hidden widths.
        lyr_.extend([func(idx) for idx in range(1, len(self.h_units))
                     for func in (self._linear, self._relu)])
        self._fc = nn.Sequential(OrderedDict(lyr_))
        # Classifier head: last hidden width -> number of classes.
        self._clf = nn.Sequential(OrderedDict([
            ('last_ly', nn.Linear(self.h_units[-1], self.num_classes))]))
    def _linear(self, idx):
        # Named (key, module) pair for the idx-th hidden Linear layer.
        return 'l' + str(idx), self.__linear(self.h_units[idx - 1], self.h_units[idx])

    def _relu(self, idx):
        self._c += 1
        return 'r' + str(self._c), self.__relu()
    def forward(self, x):
        y = self._fc(x)
        y = self._clf(y)
        return y
    def train_(self, num_epochs, dtrain, batch_size=32, optimizer=None,
               criterion=None, lr=0.01, verbose=50):
        if num_epochs <= 0: raise ValueError("num_epochs must be positive")
        if (not isinstance(dtrain[0], torch.Tensor)) or len(dtrain[0]) == 0 or len(dtrain[1]) == 0:
            raise ValueError("data can't be empty")
        if optimizer is None: raise ValueError("optimizer can't be None")
        if criterion is None: raise ValueError("criterion can't be None")
        # Only optimize parameters that are currently trainable (see freeze/unfreeze).
        self._optim = optimizer(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
        for epoch in range(num_epochs):
            for batch in range(dtrain[0].size(0) // batch_size):
                sl = slice(batch * batch_size, batch_size * (batch + 1))
                out = self.forward(dtrain[0][sl].reshape(-1, self.ipt_sz).float())
                loss = criterion(out, dtrain[1][sl])
                self._optim.zero_grad()
                loss.backward()
                self._optim.step()
                if batch % verbose == 0:
                    print('epoch {} batch {} loss: {:.4f}'.format(epoch, batch, loss.item()))
    def freeze(self, _group_name=[]):
        self.check_ls(_group_name)
        if _group_name == []: warnings.warn("All layers except the last one will be frozen")
        self.__operation(False, _group_name)

    def unfreeze(self, _group_name=[]):
        self.check_ls(_group_name)
        if _group_name == []: warnings.warn("All layers are now training!")
        self.__operation(True, _group_name)
    def __operation(self, flag, _lgrps=[]):
        # Toggle requires_grad. With no groups given, freezing still keeps the
        # last layer (its weight and bias, the final two parameters) trainable.
        _len = len(list(self.named_parameters()))
        for idx, _l in enumerate(self.named_parameters()):
            if _lgrps == [] and idx > _len - 3 and (not flag): _l[1].requires_grad = True
            elif _lgrps == []: _l[1].requires_grad = flag
            elif _l[0].split('.')[0] in _lgrps: _l[1].requires_grad = flag
    def check_ls(self, _l):
        if not isinstance(_l, list): raise TypeError("Please pass a list")

    def des(self):
        # Debugging helper: dump this object's attributes.
        print(dir(self))
Mod = SimpleNet(28 * 28, [128, 256, 256, 512], 10)

# Peek at the first two rows of the first weight matrix before training...
for i in Mod.parameters():
    print(i[0], i[1])
    break

# train_data/train_labels were the tensor attributes in 2018-era torchvision
# (newer releases expose them as ds_t.data / ds_t.targets).
Mod.train_(2, (ds_t.train_data, ds_t.train_labels),
           optimizer=torch.optim.SGD, criterion=nn.CrossEntropyLoss())

# ...and again afterwards, to eyeball that the weights actually changed.
for i in Mod.parameters():
    print(i[0], i[1])
    break
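
# A minimal usage sketch (not in the original script) of the freeze/unfreeze
# API defined above. It assumes group names are the top-level child-module
# names that named_parameters() exposes: '_fc' for the hidden stack and
# '_clf' for the classifier head.
Mod.freeze(['_fc'])    # freeze only the hidden stack
print([n for n, p in Mod.named_parameters() if p.requires_grad])  # head only
Mod.unfreeze(['_fc'])  # make it trainable again

Mod.freeze()           # warns, then freezes everything except the last layer
Mod.train_(1, (ds_t.train_data, ds_t.train_labels),
           optimizer=torch.optim.SGD, criterion=nn.CrossEntropyLoss())
Mod.unfreeze()         # warns, then unfreezes all parameters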