@muammar
Last active February 20, 2019 18:27
Example error of a neural network trained for regression with PyTorch
import time
import datetime

import torch

from mlchemistry.data.visualization import parity

torch.set_printoptions(precision=10)


class NeuralNetwork(torch.nn.Module):
    """Neural Network Regression with PyTorch

    Parameters
    ----------
    hiddenlayers : tuple
        Structure of hidden layers in the neural network.
    activation : str
        The activation function.
    """

    def __init__(self, hiddenlayers=(3, 3), activation='relu'):
        super(NeuralNetwork, self).__init__()
        self.hiddenlayers = hiddenlayers
        self.activation = activation
    def prepare_model(self, input_dimension, data=None):
        """Prepare the model

        Parameters
        ----------
        input_dimension : int
            Input's dimension.
        data : object
            DataSet object created from the handler.
        """
        activation = {'tanh': torch.nn.Tanh, 'relu': torch.nn.ReLU,
                      'celu': torch.nn.CELU}

        print()
        print('Model Training')
        print('==============')
        print('Number of hidden-layers: {}'.format(len(self.hiddenlayers)))
        print('Structure of Neural Net: {}'
              .format('(input, ' + str(self.hiddenlayers)[1:-1] + ', output)'))

        layers = range(len(self.hiddenlayers) + 1)
        unique_element_symbols = data.unique_element_symbols['trainingset']

        symbol_model_pair = []
        self.output_layer_index = {}

        for symbol in unique_element_symbols:
            linears = []

            # Per-element linear scaling of the raw network output. The
            # intercept and slope are initialized from the energy range of the
            # training data and optimized together with the weights.
            intercept = (data.max_energy + data.min_energy) / 2.
            intercept = torch.nn.Parameter(torch.tensor(intercept,
                                                        requires_grad=True))
            slope = (data.max_energy - data.min_energy) / 2.
            slope = torch.nn.Parameter(torch.tensor(slope,
                                                    requires_grad=True))
            print(intercept, slope)

            intercept_name = 'intercept_' + symbol
            slope_name = 'slope_' + symbol

            self.register_parameter(intercept_name, intercept)
            self.register_parameter(slope_name, slope)

            for index in layers:
                # This is the input layer
                if index == 0:
                    out_dimension = self.hiddenlayers[0]
                    _linear = torch.nn.Linear(input_dimension, out_dimension)
                    linears.append(_linear)
                    linears.append(activation[self.activation]())
                # This is the output layer
                elif index == len(self.hiddenlayers):
                    inp_dimension = self.hiddenlayers[index - 1]
                    out_dimension = 1
                    self.output_layer_index[symbol] = index
                    _linear = torch.nn.Linear(inp_dimension, out_dimension)
                    linears.append(_linear)
                # These are hidden layers
                else:
                    inp_dimension = self.hiddenlayers[index - 1]
                    out_dimension = self.hiddenlayers[index]
                    _linear = torch.nn.Linear(inp_dimension, out_dimension)
                    linears.append(_linear)
                    linears.append(activation[self.activation]())

            # Stacking up the layers.
            linears = torch.nn.Sequential(*linears)
            symbol_model_pair.append([symbol, linears])

        self.linears = torch.nn.ModuleDict(symbol_model_pair)
        print(self.linears)

        # Iterate over all modules and initialize only those that are linear
        # layers.
        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                # torch.nn.init.normal_(m.weight)  # , mean=0, std=0.01)
                torch.nn.init.xavier_uniform_(m.weight)
    def forward(self, X):
        """Forward propagation

        This is forward propagation and it returns the atomic energy.

        Parameters
        ----------
        X : dict
            Dictionary of inputs in the feature space, keyed by image hash.

        Returns
        -------
        outputs : tensor
            Tensor with the predicted energy of each image.
        """
        outputs = []

        for hash in X:
            image = X[hash]
            atomic_energies = []

            for symbol, x in image:
                x = self.linears[symbol](x)

                intercept_name = 'intercept_' + symbol
                slope_name = 'slope_' + symbol

                for name, param in self.named_parameters():
                    if intercept_name == name:
                        intercept = param
                    elif slope_name == name:
                        slope = param

                x = (slope * x) + intercept
                atomic_energies.append(x)

            atomic_energies = torch.cat(atomic_energies)
            image_energy = torch.sum(atomic_energies)
            outputs.append(image_energy)

        outputs = torch.stack(outputs)
        return outputs
    def train(self, inputs, targets, model=None, data=None, optimizer=None,
              lr=None, weight_decay=None, regularization=None, epochs=100,
              convergence=None, lossfxn=None):
        """Train the model

        Parameters
        ----------
        inputs : dict
            Dictionary with hashed feature space.
        targets : list
            The expected values that the model has to learn aka y.
        model : object
            The NeuralNetwork class.
        data : object
            DataSet object created from the handler.
        optimizer : object
            An optimizer object. If None, Adam is used.
        lr : float
            Learning rate.
        weight_decay : float
            Weight decay passed to the optimizer. Default is 0.
        regularization : float
            This is the L2 regularization. It is not the same as weight decay.
        epochs : int
            Number of full training cycles.
        convergence : dict
            Instead of using epochs, users can set a convergence criterion.
        lossfxn : object
            A loss function object.
        """
        # old_state_dict = {}
        # for key in model.state_dict():
        #     old_state_dict[key] = model.state_dict()[key].clone()

        targets = torch.tensor(targets, requires_grad=False)

        # Define the optimizer
        if optimizer is None:
            optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                         weight_decay=weight_decay)

        print()
        print('{:6s} {:19s} {:12s} {:9s}'.format('Epoch', 'Time Stamp', 'Loss',
                                                 'RMSE/atom'))
        print('{:6s} {:19s} {:12s} {:9s}'.format('------', '-------------------',
                                                 '------------', '---------'))

        initial_time = time.time()

        _loss = []
        _rmse = []
        epoch = 0

        while True:
            epoch += 1
            outputs = model(inputs)

            if lossfxn is None:
                loss, rmse = self.loss_function(outputs, targets, optimizer,
                                                data)
            else:
                raise NotImplementedError('I do not know what to do')

            # Store plain floats (not graph-holding tensors) for plotting.
            _loss.append(loss.item())
            _rmse.append(rmse)

            ts = time.time()
            ts = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d '
                                                              '%H:%M:%S')
            print('{:6d} {} {:8e} {:8f}'.format(epoch, ts, loss, rmse))

            if convergence is None and epoch == epochs:
                break
            elif convergence is not None and rmse < convergence['energy']:
                break

        training_time = time.time() - initial_time
        print('Training the model took {}...'.format(training_time))

        print('outputs')
        print(outputs)
        print('targets')
        print(targets)

        import matplotlib.pyplot as plt

        plt.plot(list(range(epoch)), _loss, label='loss')
        plt.plot(list(range(epoch)), _rmse, label='rmse/atom')
        plt.legend(loc='upper left')
        plt.show()

        parity(outputs.detach().numpy(), targets.detach().numpy())

        # new_state_dict = {}
        # for key in model.state_dict():
        #     new_state_dict[key] = model.state_dict()[key].clone()
        # for key in old_state_dict:
        #     if not (old_state_dict[key] == new_state_dict[key]).all():
        #         print('Diff in {}'.format(key))
        #     else:
        #         print('No diff in {}'.format(key))
        # print()

        # for symbol in data.unique_element_symbols['trainingset']:
        #     model = model.linears[symbol]
        #     print('Optimized parameters for {} symbol'.format(symbol))
        #     for index, param in enumerate(model.parameters()):
        #         print('Index {}'.format(index))
        #         print(param)
        #         try:
        #             print('Gradient', param.grad.sum())
        #         except AttributeError:
        #             print('No gradient?')
        #         print()
    def loss_function(self, outputs, targets, optimizer, data):
        """Default loss function

        If the user does not provide a loss function, this mean-squared-error
        loss per atom is used.

        Parameters
        ----------
        outputs : tensor
            Outputs of the model.
        targets : tensor
            Expected value of outputs.
        optimizer : object
            An optimizer object to minimize the loss function error.
        data : object
            A data object from mlchem.

        Returns
        -------
        loss : tensor
            The value of the loss function.
        rmse : float
            Value of the root-mean-square error per atom.
        """
        optimizer.zero_grad()  # clear previous gradients

        criterion = torch.nn.MSELoss(reduction='sum')

        atoms_per_image = torch.tensor(data.atoms_per_image,
                                       requires_grad=False,
                                       dtype=torch.float)

        # Normalize energies to per-atom quantities before computing the loss.
        outputs_atom = torch.div(outputs, atoms_per_image)
        targets_atom = torch.div(targets, atoms_per_image)

        loss = criterion(outputs_atom, targets_atom) * .5
        loss.backward()
        optimizer.step()

        rmse = torch.sqrt(loss).item()
        return loss, rmse
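
For reference, here is a minimal, hypothetical usage sketch (not part of the original gist) showing how the class might be exercised. MockData and the hand-built feature dictionary stand in for the mlchemistry DataSet handler and fingerprint generator; only the attributes that prepare_model() and loss_function() actually read are filled in. The run documented in the comments below used the real handler, with 8-dimensional Cu features and hiddenlayers=(400, 400). Running the sketch end to end also requires matplotlib and mlchemistry for the plots produced inside train().

# --- Hypothetical usage sketch -------------------------------------------
class MockData:
    """Stand-in for the mlchemistry DataSet handler (assumed attributes)."""
    unique_element_symbols = {'trainingset': ['Cu']}
    max_energy = -14.0        # used to initialize the per-element intercept/slope
    min_energy = -15.0
    atoms_per_image = [4, 4]  # read by loss_function()


if __name__ == '__main__':
    data = MockData()

    # Feature space with the structure forward() expects:
    # {hash: [(symbol, feature_vector), ...]}, one entry per image.
    inputs = {
        'image-0': [('Cu', torch.rand(8)) for _ in range(4)],
        'image-1': [('Cu', torch.rand(8)) for _ in range(4)],
    }
    targets = [-14.6, -14.5]

    model = NeuralNetwork(hiddenlayers=(10, 10), activation='relu')
    model.prepare_model(input_dimension=8, data=data)
    model.train(inputs, targets, model=model, data=data,
                lr=1e-3, weight_decay=0., epochs=10)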
muammar commented Feb 4, 2019

Output of running the script above:

Model Training
==============
Number of hidden-layers: 2
Structure of Neural Net: (input, 400, 400, output)
Parameter containing:
tensor(-3.6438593864, requires_grad=True) Parameter containing:
tensor(0.0028590607, requires_grad=True)
ModuleDict(
  (Cu): Sequential(
    (0): Linear(in_features=8, out_features=400, bias=True)
    (1): ReLU()
    (2): Linear(in_features=400, out_features=400, bias=True)
    (3): ReLU()
    (4): Linear(in_features=400, out_features=1, bias=True)
  )
)

Epoch  Time Stamp          Loss         RMSE/atom
------ ------------------- ------------ ---------
     1 2019-02-20 10:08:03 9.920674e-06 0.003150
     2 2019-02-20 10:08:03 9.538131e-06 0.003088
     3 2019-02-20 10:08:03 9.195143e-06 0.003032
     4 2019-02-20 10:08:03 8.905964e-06 0.002984
     5 2019-02-20 10:08:03 8.672600e-06 0.002945
     6 2019-02-20 10:08:03 8.488381e-06 0.002913
     7 2019-02-20 10:08:03 8.340588e-06 0.002888
     8 2019-02-20 10:08:03 8.215396e-06 0.002866
     9 2019-02-20 10:08:03 8.098184e-06 0.002846
    10 2019-02-20 10:08:03 7.976991e-06 0.002824
    11 2019-02-20 10:08:03 7.848244e-06 0.002801
    12 2019-02-20 10:08:03 7.708543e-06 0.002776
    13 2019-02-20 10:08:03 7.558343e-06 0.002749
    14 2019-02-20 10:08:03 7.399570e-06 0.002720
    15 2019-02-20 10:08:03 7.231776e-06 0.002689
    16 2019-02-20 10:08:03 7.055909e-06 0.002656
    17 2019-02-20 10:08:03 6.874913e-06 0.002622
    18 2019-02-20 10:08:03 6.688748e-06 0.002586
    19 2019-02-20 10:08:03 6.495355e-06 0.002549
    20 2019-02-20 10:08:03 6.293727e-06 0.002509
    21 2019-02-20 10:08:03 6.084394e-06 0.002467
    22 2019-02-20 10:08:03 5.864506e-06 0.002422
    23 2019-02-20 10:08:03 5.631352e-06 0.002373
    24 2019-02-20 10:08:03 5.385046e-06 0.002321
    25 2019-02-20 10:08:03 5.124490e-06 0.002264
    26 2019-02-20 10:08:03 4.850896e-06 0.002202
    27 2019-02-20 10:08:03 4.565629e-06 0.002137
    28 2019-02-20 10:08:03 4.270406e-06 0.002066
    29 2019-02-20 10:08:03 3.971060e-06 0.001993
    30 2019-02-20 10:08:03 3.665750e-06 0.001915
    31 2019-02-20 10:08:03 3.360189e-06 0.001833
    32 2019-02-20 10:08:03 3.054921e-06 0.001748
    33 2019-02-20 10:08:03 2.754281e-06 0.001660
    34 2019-02-20 10:08:03 2.458273e-06 0.001568
    35 2019-02-20 10:08:03 2.169354e-06 0.001473
    36 2019-02-20 10:08:03 1.888878e-06 0.001374
    37 2019-02-20 10:08:03 1.619132e-06 0.001272
    38 2019-02-20 10:08:03 1.361321e-06 0.001167
    39 2019-02-20 10:08:03 1.119603e-06 0.001058
    40 2019-02-20 10:08:03 8.962875e-07 0.000947
    41 2019-02-20 10:08:03 6.948774e-07 0.000834
    42 2019-02-20 10:08:03 5.181893e-07 0.000720
    43 2019-02-20 10:08:03 3.684362e-07 0.000607
    44 2019-02-20 10:08:03 2.454774e-07 0.000495
    45 2019-02-20 10:08:03 1.495960e-07 0.000387
    46 2019-02-20 10:08:03 7.985818e-08 0.000283
    47 2019-02-20 10:08:03 3.384619e-08 0.000184
    48 2019-02-20 10:08:03 8.882864e-09 0.000094
    49 2019-02-20 10:08:03 2.185772e-09 0.000047
    50 2019-02-20 10:08:03 1.046797e-08 0.000102
    51 2019-02-20 10:08:03 3.013091e-08 0.000174
    52 2019-02-20 10:08:03 5.732628e-08 0.000239
    53 2019-02-20 10:08:03 8.805344e-08 0.000297
    54 2019-02-20 10:08:03 1.181469e-07 0.000344
    55 2019-02-20 10:08:03 1.443455e-07 0.000380
    56 2019-02-20 10:08:03 1.642641e-07 0.000405
    57 2019-02-20 10:08:03 1.764898e-07 0.000420
    58 2019-02-20 10:08:03 1.806377e-07 0.000425
    59 2019-02-20 10:08:03 1.772145e-07 0.000421
    60 2019-02-20 10:08:03 1.676956e-07 0.000410
    61 2019-02-20 10:08:03 1.530057e-07 0.000391
    62 2019-02-20 10:08:03 1.349278e-07 0.000367
    63 2019-02-20 10:08:03 1.145393e-07 0.000338
    64 2019-02-20 10:08:03 9.338206e-08 0.000306
    65 2019-02-20 10:08:03 7.279360e-08 0.000270
    66 2019-02-20 10:08:03 5.376961e-08 0.000232
    67 2019-02-20 10:08:03 3.742230e-08 0.000193
    68 2019-02-20 10:08:03 2.403945e-08 0.000155
    69 2019-02-20 10:08:03 1.392948e-08 0.000118
    70 2019-02-20 10:08:03 6.882885e-09 0.000083
    71 2019-02-20 10:08:03 2.573529e-09 0.000051
    72 2019-02-20 10:08:03 5.660752e-10 0.000024
    73 2019-02-20 10:08:03 3.742855e-10 0.000019
    74 2019-02-20 10:08:03 1.514451e-09 0.000039
    75 2019-02-20 10:08:03 3.598899e-09 0.000060
    76 2019-02-20 10:08:03 6.195393e-09 0.000079
    77 2019-02-20 10:08:03 8.959091e-09 0.000095
    78 2019-02-20 10:08:03 1.149559e-08 0.000107
    79 2019-02-20 10:08:03 1.359345e-08 0.000117
    80 2019-02-20 10:08:03 1.508400e-08 0.000123
    81 2019-02-20 10:08:03 1.587210e-08 0.000126
    82 2019-02-20 10:08:03 1.597419e-08 0.000126
    83 2019-02-20 10:08:03 1.545939e-08 0.000124
    84 2019-02-20 10:08:03 1.451104e-08 0.000120
    85 2019-02-20 10:08:03 1.306447e-08 0.000114
    86 2019-02-20 10:08:03 1.139801e-08 0.000107
    87 2019-02-20 10:08:03 9.582664e-09 0.000098
    88 2019-02-20 10:08:03 7.716636e-09 0.000088
    89 2019-02-20 10:08:03 5.900489e-09 0.000077
    90 2019-02-20 10:08:03 4.236540e-09 0.000065
    91 2019-02-20 10:08:03 2.840380e-09 0.000053
    92 2019-02-20 10:08:03 1.724629e-09 0.000042
    93 2019-02-20 10:08:03 9.029577e-10 0.000030
    94 2019-02-20 10:08:03 3.696243e-10 0.000019
    95 2019-02-20 10:08:03 9.418955e-11 0.000010
    96 2019-02-20 10:08:03 2.742695e-11 0.000005
    97 2019-02-20 10:08:03 1.182912e-10 0.000011
    98 2019-02-20 10:08:03 3.190337e-10 0.000018
    99 2019-02-20 10:08:03 5.844640e-10 0.000024
   100 2019-02-20 10:08:03 8.814425e-10 0.000030
Training the model took 0.6154983043670654...
outputs
tensor([-14.5869894028, -14.5638799667], grad_fn=<StackBackward>)
targets
tensor([-14.5868730545, -14.5640010834])
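
As a sanity check, the last row of the table can be reproduced from the outputs and targets printed above. The script does not print the atom count per image; 4 atoms per image is assumed here because it is the value consistent with the reported loss.

import math

# Final predictions and reference energies from the run above.
outputs = [-14.5869894028, -14.5638799667]
targets = [-14.5868730545, -14.5640010834]
n_atoms = 4  # assumed atom count per image

# Same expression as loss_function(): 0.5 * sum of squared per-atom errors.
loss = 0.5 * sum(((o - t) / n_atoms) ** 2 for o, t in zip(outputs, targets))
rmse = math.sqrt(loss)

print('{:8e} {:8f}'.format(loss, rmse))  # ~8.814e-10 and ~0.000030, as in epoch 100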

muammar commented Feb 20, 2019

[Loss Function plot: training loss and rmse/atom vs. epoch]

[Parity Plot: predicted vs. target energies]
