Last active
March 20, 2020 12:25
-
-
Save faizankshaikh/382c31bd4663bc6aabbfc1c8ce40f3c5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import numpy as np | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
import torch.utils.data.distributed | |
from torchvision import datasets, transforms | |
from torch.utils.data import DataLoader | |
## check how to create model input | |
class Model(nn.Module):
    """Two-layer fully connected classifier that outputs per-class log-probabilities.

    ``model_pars`` must provide the integer layer sizes ``input_shape``,
    ``hidden_shape`` and ``output_shape``. ``data_pars`` and ``compute_pars``
    are accepted for interface compatibility and are not used here.
    """

    def __init__(self, model_pars=None, data_pars=None, compute_pars=None):
        super(Model, self).__init__()
        if model_pars is None:
            raise TypeError(
                "Model requires model_pars with 'input_shape', 'hidden_shape' and 'output_shape'"
            )
        input_shape = model_pars["input_shape"]
        hidden_shape = model_pars["hidden_shape"]
        output_shape = model_pars["output_shape"]
        self.fc1 = nn.Linear(input_shape, hidden_shape)
        self.fc2 = nn.Linear(hidden_shape, output_shape)

    def forward(self, x):
        """Return log-probabilities of shape (batch, output_shape).

        All trailing elements of each sample are flattened into ``input_shape``
        features. For example, (batch, 1, 28, 28) images work with input_shape=784.
        """
        # Flatten to (batch, input_shape). This used to be hard-coded to 784.
        x = x.reshape(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Normalise over the class dimension (dim=1), not the batch dimension.
        # F.nll_loss expects each row to be a log-probability distribution.
        return F.log_softmax(x, dim=1)
def _train(model, data_loader, epoch, optimizer): | |
model.train() | |
for batch_idx, (data, target) in enumerate(data_loader): | |
optimizer.zero_grad() | |
output = model(data) | |
loss = F.nll_loss(output, target) | |
loss.backward() | |
optimizer.step() | |
return model | |
def _pred(model, data_loader): | |
model.eval() | |
test_loss = 0 | |
pred = torch.tensor([], dtype=torch.long) | |
with torch.no_grad(): | |
for data, target in data_loader: | |
output = model(data) | |
test_loss += F.nll_loss(output, target, size_average=False).item() | |
output = output.data.max(1)[1] | |
pred = torch.cat((pred, output), 0) | |
test_loss /= len(data_loader.dataset) | |
return pred.numpy() | |
def fit(model, data_pars=None, compute_pars=None, out_pars=None, **kw):
    """Train ``model`` with SGD on the training split described by ``data_pars``.

    Args:
        model: torch module to train (updated in place).
        data_pars: dataset options forwarded to ``get_dataset``. A copy with
            ``train=True`` is passed, so the caller's dict is not mutated.
        compute_pars: must contain ``learning_rate``. ``epoch`` is optional
            (default 1) and is forwarded to ``_train``.
        out_pars: unused. Kept for interface compatibility.

    Returns:
        ``(model, sess)``, where ``sess`` is always None for this PyTorch model.

    Raises:
        KeyError: if ``compute_pars`` has no ``learning_rate``.
    """
    compute_pars = compute_pars or {}
    # Copy instead of writing into the caller's dict (or a shared default).
    train_pars = dict(data_pars or {}, train=True)
    learning_rate = compute_pars["learning_rate"]
    data_loader = get_dataset(train_pars)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    model = _train(model, data_loader, compute_pars.get("epoch", 1), optimizer)
    sess = None
    return model, sess
def predict(model, sess=None, data_pars=None, compute_pars=None, out_pars=None, **kw):
    """Return predicted class indices for the evaluation split in ``data_pars``.

    Args:
        model: trained torch module.
        sess: unused. Kept for interface compatibility.
        data_pars: dataset options forwarded to ``get_dataset``. A copy with
            ``train=False`` is passed, so the caller's dict is not mutated.
        compute_pars, out_pars: unused. Kept for interface compatibility.

    Returns:
        Whatever ``_pred`` returns for the loader built from ``data_pars``.
    """
    # Copy instead of writing into the caller's dict (or a shared default).
    eval_pars = dict(data_pars or {}, train=False)
    data_loader = get_dataset(eval_pars)
    return _pred(model, data_loader)
def get_dataset(data_pars=None, **kw):
    """Build a DataLoader over one MNIST split.

    Args:
        data_pars: dict with these keys:
            - ``mode``: only ``'test'`` is supported.
            - ``path``: root directory passed to ``datasets.MNIST`` (with
              ``download=True``).
            - ``train`` (optional, default True): truthy selects the training
              split, falsy selects the test split.
            - ``batch_size`` (optional): defaults to 64 for the training split
              and 1000 for the test split.

    Returns:
        A ``DataLoader`` over the requested split.

    Raises:
        ValueError: if ``data_pars`` is missing/empty or ``mode`` is not ``'test'``.
    """
    # Fail loudly instead of implicitly returning None for unsupported input.
    if not data_pars or data_pars.get("mode") != "test":
        raise ValueError("get_dataset only supports data_pars with mode='test'")
    train = bool(data_pars.get("train", True))
    # (0.1307, 0.3081) are presumably the MNIST pixel mean/std.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # Load only the split that was requested. Previously both splits were built.
    dataset = datasets.MNIST(
        root=data_pars["path"],
        download=True,
        train=train,
        transform=transform,
    )
    batch_size = data_pars.get("batch_size", 64 if train else 1000)
    return DataLoader(dataset, batch_size=batch_size)
def get_params(param_pars=None, **kw):
    """Split ``param_pars`` into the four parameter groups used by this module.

    Args:
        param_pars: optional dict that may hold ``model_pars``, ``data_pars``,
            ``compute_pars`` and ``out_pars`` sub-dicts. Missing groups default
            to empty dicts.

    Returns:
        Tuple ``(model_pars, data_pars, compute_pars, out_pars)`` of shallow
        copies, so callers can modify them without touching ``param_pars``.
    """
    param_pars = param_pars or {}
    return tuple(
        dict(param_pars.get(key) or {})
        for key in ("model_pars", "data_pars", "compute_pars", "out_pars")
    )
def test(data_path="", pars_choice="", config_mode=""):
    """Placeholder end-to-end check that is not implemented yet.

    It accepts ``data_path``, ``pars_choice`` and ``config_mode`` but ignores
    them, does no work, and returns None.
    """
    # TODO: wire up get_params -> Model -> fit -> predict once configs exist.
    return
if __name__ == "__main__":
    # Script entry point: run the self-test with the "test01" parameter choice.
    test(pars_choice="test01")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment