Skip to content

Instantly share code, notes, and snippets.

Last active February 20, 2019 18:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save muammar/6d298f7fa5700efdbd01ae697f63e089 to your computer and use it in GitHub Desktop.
Save muammar/6d298f7fa5700efdbd01ae697f63e089 to your computer and use it in GitHub Desktop.
Example error of a neural network trained for regression with PyTorch
import time
import datetime
import torch
from import parity
class NeuralNetwork(torch.nn.Module):
    """Neural Network Regression with Pytorch

    One feed-forward network is built per element symbol.  The raw output of
    each network is rescaled with a per-symbol trainable ``slope`` and
    ``intercept`` and the resulting atomic energies are summed per image.

    Parameters
    ----------
    hiddenlayers : tuple
        Structure of hidden layers in the neural network.
    activation : str
        The activation function. One of ``'tanh'``, ``'relu'`` or ``'celu'``.
    """

    def __init__(self, hiddenlayers=(3, 3), activation='relu'):
        super(NeuralNetwork, self).__init__()
        self.hiddenlayers = hiddenlayers
        self.activation = activation

    def prepare_model(self, input_dimension, data=None):
        """Prepare the model

        Builds ``self.linears`` (a ``ModuleDict`` mapping each element symbol
        to a ``Sequential`` network) and registers the per-symbol
        ``intercept_<symbol>`` / ``slope_<symbol>`` parameters.

        Parameters
        ----------
        input_dimension : int
            Input's dimension.
        data : object
            DataSet object created from the handler. It must expose
            ``unique_element_symbols['trainingset']``, ``max_energy`` and
            ``min_energy``.

        Raises
        ------
        ValueError
            If ``data`` is not provided.
        KeyError
            If ``self.activation`` is not a supported activation name.
        """
        if data is None:
            raise ValueError('prepare_model() requires a data object')

        activation = {'tanh': torch.nn.Tanh, 'relu': torch.nn.ReLU,
                      'celu': torch.nn.CELU}
        activation_class = activation[self.activation]

        print('Model Training')
        print('Number of hidden-layers: {}' .format(len(self.hiddenlayers)))
        print('Structure of Neural Net: {}' .
              format('(input, ' + str(self.hiddenlayers)[1:-1] + ', output)'))

        unique_element_symbols = data.unique_element_symbols['trainingset']

        # Consecutive pairs of this list are the (in, out) sizes of each
        # linear layer: input -> hidden layers -> single output neuron.
        dimensions = [input_dimension] + list(self.hiddenlayers) + [1]
        output_index = len(self.hiddenlayers)

        symbol_model_pair = []
        self.output_layer_index = {}

        for symbol in unique_element_symbols:
            # Initial guesses for the affine rescaling of the network output:
            # mid-point and half-range of the energies in the data set.
            intercept = (data.max_energy + data.min_energy) / 2.
            intercept = torch.nn.Parameter(torch.tensor(intercept,
                                                        requires_grad=True))
            slope = (data.max_energy - data.min_energy) / 2.
            slope = torch.nn.Parameter(torch.tensor(slope, requires_grad=True))

            print(intercept, slope)

            self.register_parameter('intercept_' + symbol, intercept)
            self.register_parameter('slope_' + symbol, slope)

            linears = []
            for index, (inp_dimension, out_dimension) in enumerate(
                    zip(dimensions[:-1], dimensions[1:])):
                linears.append(torch.nn.Linear(inp_dimension, out_dimension))
                # Every layer but the output one is followed by the
                # activation function.
                if index != output_index:
                    linears.append(activation_class())

            # Index of the output layer, counting linear layers only.
            self.output_layer_index[symbol] = output_index

            # Stacking up the layers.
            symbol_model_pair.append([symbol, torch.nn.Sequential(*linears)])

        self.linears = torch.nn.ModuleDict(symbol_model_pair)
        print(self.linears)

        # Iterate over all modules and just initialize those that are a
        # linear layer.
        # NOTE(review): the initializer statement was missing from the source
        # this was recovered from; Xavier uniform is assumed -- confirm.
        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, X):
        """Forward propagation

        This is forward propagation and it returns the energy of each image
        as the sum of its rescaled atomic energies.

        Parameters
        ----------
        X : dict
            Feature space keyed by image hash. Each value is an iterable of
            ``(symbol, features)`` pairs, one per atom.

        Returns
        -------
        outputs : tensor
            One energy per image, in the iteration order of ``X``.
        """
        outputs = []

        for key in X:
            image = X[key]
            atomic_energies = []

            for symbol, x in image:
                x = self.linears[symbol](x)
                # Registered parameters are reachable as attributes, so there
                # is no need to scan named_parameters() for every atom.
                slope = getattr(self, 'slope_' + symbol)
                intercept = getattr(self, 'intercept_' + symbol)
                atomic_energies.append((slope * x) + intercept)

            atomic_energies = torch.cat(atomic_energies)
            outputs.append(torch.sum(atomic_energies))

        return torch.stack(outputs)

    def train(self, inputs=True, targets=None, model=None, data=None,
              optimizer=None, lr=None, weight_decay=None, regularization=None,
              epochs=100, convergence=None, lossfxn=None):
        """Train the model

        This method shadows ``torch.nn.Module.train(mode)``. To keep
        ``model.train()``, ``model.train(False)`` and ``model.eval()``
        working, a call whose first argument is a ``bool`` is forwarded to
        the ``torch.nn.Module`` implementation; any other call runs the
        training loop described below.

        Parameters
        ----------
        inputs : dict
            Dictionary with hashed feature space.
        targets : list
            The expected values that the model has to learn aka y.
        model : object
            The NeuralNetwork class. Defaults to ``self``.
        data : object
            DataSet object created from the handler. It must expose
            ``atoms_per_image``.
        optimizer : object
            A torch optimizer. When None, Adam is built from ``lr`` and
            ``weight_decay``.
        lr : float
            Learning rate. When None the optimizer's own default is used.
        weight_decay : float
            Weight decay passed to the optimizer. When None the optimizer's
            own default is used.
        regularization : float
            This is the L2 regularization. It is not the same as weight decay.
            Accepted for interface compatibility; not used by this
            implementation.
        epochs : int
            Number of full training cycles.
        convergence : dict
            Instead of using epochs, users can set a convergence criterion.
            Training stops once the reported rmse drops below
            ``convergence['energy']``.
        lossfxn : obj
            A loss function object. Custom loss functions are not supported
            yet.

        Raises
        ------
        TypeError
            If ``targets`` is missing.
        NotImplementedError
            If ``lossfxn`` is given.
        ValueError
            If ``data`` is not provided.

        Notes
        -----
        After training, the loss curves are plotted with matplotlib and the
        module-level ``parity`` helper is called, so both must be importable.
        """
        # torch.nn.Module.eval() is implemented as ``self.train(False)``.
        if isinstance(inputs, bool):
            return super(NeuralNetwork, self).train(inputs)

        if targets is None:
            raise TypeError("train() missing required argument: 'targets'")

        # Fail before any work is done instead of in the first epoch.
        if lossfxn is not None:
            raise NotImplementedError('I do not know what to do')

        if data is None:
            raise ValueError('train() requires a data object')

        if model is None:
            model = self

        targets = torch.tensor(targets, requires_grad=False)

        # Define optimizer
        if optimizer is None:
            # Only forward what the caller actually set; Adam rejects None.
            adam_kwargs = {}
            if lr is not None:
                adam_kwargs['lr'] = lr
            if weight_decay is not None:
                adam_kwargs['weight_decay'] = weight_decay
            optimizer = torch.optim.Adam(model.parameters(), **adam_kwargs)

        print('{:6s} {:19s} {:8s}'.format('Epoch', 'Time Stamp', 'Loss'))
        print('{:6s} {:19s} {:8s}'.format('------',
                                          '-------------------', '---------'))
        initial_time = time.time()

        _loss = []
        _rmse = []
        epoch = 0

        while True:
            epoch += 1
            outputs = model(inputs)

            loss, rmse = self.loss_function(outputs, targets, optimizer, data)
            # Store plain floats so the history does not keep autograd graphs
            # alive and can be plotted directly.
            _loss.append(loss.item())
            _rmse.append(rmse)

            ts = time.time()
            ts = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d '
                                                              '%H:%M:%S')
            print('{:6d} {} {:8e} {:8f}' .format(epoch, ts, loss, rmse))

            # ``>=`` so that epochs <= 0 still terminates after one cycle.
            if convergence is None and epoch >= epochs:
                break
            elif (convergence is not None and rmse < convergence['energy']):
                break

        training_time = time.time() - initial_time

        print('Training the model took {}...' .format(training_time))
        print(outputs)
        print(targets)

        import matplotlib.pyplot as plt
        plt.plot(list(range(epoch)), _loss, label='loss')
        plt.plot(list(range(epoch)), _rmse, label='rmse/atom')
        plt.legend(loc='upper left')
        plt.show()

        parity(outputs.detach().numpy(), targets.detach().numpy())

    def loss_function(self, outputs, targets, optimizer, data):
        """Default loss function

        If user does not input loss function we provide a squared-error loss
        on per-atom energies: ``0.5 * sum(((outputs - targets) / natoms)^2)``.
        Calling this method also performs one optimization step (gradients
        are cleared, back-propagated and applied through ``optimizer``).

        Parameters
        ----------
        outputs : tensor
            Outputs of the model.
        targets : tensor
            Expected value of outputs.
        optimizer : obj
            An optimizer object to minimize the loss function error.
        data : obj
            A data object from mlchem. It must expose ``atoms_per_image``.

        Returns
        -------
        loss : tensor
            The value of the loss function.
        rmse : float
            Square root of ``loss``, reported as the error per atom.
        """
        optimizer.zero_grad()  # clear previous gradients

        criterion = torch.nn.MSELoss(reduction='sum')
        atoms_per_image = torch.tensor(data.atoms_per_image,
                                       requires_grad=False,
                                       dtype=outputs.dtype)
        outputs_atom = torch.div(outputs, atoms_per_image)
        targets_atom = torch.div(targets, atoms_per_image)

        loss = criterion(outputs_atom, targets_atom) * .5
        # NOTE(review): the backward/step calls were missing from the source
        # this was recovered from; they are placed here because this method
        # owns the optimizer -- confirm.
        loss.backward()
        optimizer.step()

        rmse = torch.sqrt(loss).item()

        return loss, rmse
Copy link

muammar commented Feb 4, 2019

Output of running script above

Model Training
Number of hidden-layers: 2
Structure of Neural Net: (input, 400, 400, output)
Parameter containing:
tensor(-3.6438593864, requires_grad=True) Parameter containing:
tensor(0.0028590607, requires_grad=True)
  (Cu): Sequential(
    (0): Linear(in_features=8, out_features=400, bias=True)
    (1): ReLU()
    (2): Linear(in_features=400, out_features=400, bias=True)
    (3): ReLU()
    (4): Linear(in_features=400, out_features=1, bias=True)

Epoch  Time Stamp          Loss    
------ ------------------- ---------
     1 2019-02-20 10:08:03 9.920674e-06 0.003150
     2 2019-02-20 10:08:03 9.538131e-06 0.003088
     3 2019-02-20 10:08:03 9.195143e-06 0.003032
     4 2019-02-20 10:08:03 8.905964e-06 0.002984
     5 2019-02-20 10:08:03 8.672600e-06 0.002945
     6 2019-02-20 10:08:03 8.488381e-06 0.002913
     7 2019-02-20 10:08:03 8.340588e-06 0.002888
     8 2019-02-20 10:08:03 8.215396e-06 0.002866
     9 2019-02-20 10:08:03 8.098184e-06 0.002846
    10 2019-02-20 10:08:03 7.976991e-06 0.002824
    11 2019-02-20 10:08:03 7.848244e-06 0.002801
    12 2019-02-20 10:08:03 7.708543e-06 0.002776
    13 2019-02-20 10:08:03 7.558343e-06 0.002749
    14 2019-02-20 10:08:03 7.399570e-06 0.002720
    15 2019-02-20 10:08:03 7.231776e-06 0.002689
    16 2019-02-20 10:08:03 7.055909e-06 0.002656
    17 2019-02-20 10:08:03 6.874913e-06 0.002622
    18 2019-02-20 10:08:03 6.688748e-06 0.002586
    19 2019-02-20 10:08:03 6.495355e-06 0.002549
    20 2019-02-20 10:08:03 6.293727e-06 0.002509
    21 2019-02-20 10:08:03 6.084394e-06 0.002467
    22 2019-02-20 10:08:03 5.864506e-06 0.002422
    23 2019-02-20 10:08:03 5.631352e-06 0.002373
    24 2019-02-20 10:08:03 5.385046e-06 0.002321
    25 2019-02-20 10:08:03 5.124490e-06 0.002264
    26 2019-02-20 10:08:03 4.850896e-06 0.002202
    27 2019-02-20 10:08:03 4.565629e-06 0.002137
    28 2019-02-20 10:08:03 4.270406e-06 0.002066
    29 2019-02-20 10:08:03 3.971060e-06 0.001993
    30 2019-02-20 10:08:03 3.665750e-06 0.001915
    31 2019-02-20 10:08:03 3.360189e-06 0.001833
    32 2019-02-20 10:08:03 3.054921e-06 0.001748
    33 2019-02-20 10:08:03 2.754281e-06 0.001660
    34 2019-02-20 10:08:03 2.458273e-06 0.001568
    35 2019-02-20 10:08:03 2.169354e-06 0.001473
    36 2019-02-20 10:08:03 1.888878e-06 0.001374
    37 2019-02-20 10:08:03 1.619132e-06 0.001272
    38 2019-02-20 10:08:03 1.361321e-06 0.001167
    39 2019-02-20 10:08:03 1.119603e-06 0.001058
    40 2019-02-20 10:08:03 8.962875e-07 0.000947
    41 2019-02-20 10:08:03 6.948774e-07 0.000834
    42 2019-02-20 10:08:03 5.181893e-07 0.000720
    43 2019-02-20 10:08:03 3.684362e-07 0.000607
    44 2019-02-20 10:08:03 2.454774e-07 0.000495
    45 2019-02-20 10:08:03 1.495960e-07 0.000387
    46 2019-02-20 10:08:03 7.985818e-08 0.000283
    47 2019-02-20 10:08:03 3.384619e-08 0.000184
    48 2019-02-20 10:08:03 8.882864e-09 0.000094
    49 2019-02-20 10:08:03 2.185772e-09 0.000047
    50 2019-02-20 10:08:03 1.046797e-08 0.000102
    51 2019-02-20 10:08:03 3.013091e-08 0.000174
    52 2019-02-20 10:08:03 5.732628e-08 0.000239
    53 2019-02-20 10:08:03 8.805344e-08 0.000297
    54 2019-02-20 10:08:03 1.181469e-07 0.000344
    55 2019-02-20 10:08:03 1.443455e-07 0.000380
    56 2019-02-20 10:08:03 1.642641e-07 0.000405
    57 2019-02-20 10:08:03 1.764898e-07 0.000420
    58 2019-02-20 10:08:03 1.806377e-07 0.000425
    59 2019-02-20 10:08:03 1.772145e-07 0.000421
    60 2019-02-20 10:08:03 1.676956e-07 0.000410
    61 2019-02-20 10:08:03 1.530057e-07 0.000391
    62 2019-02-20 10:08:03 1.349278e-07 0.000367
    63 2019-02-20 10:08:03 1.145393e-07 0.000338
    64 2019-02-20 10:08:03 9.338206e-08 0.000306
    65 2019-02-20 10:08:03 7.279360e-08 0.000270
    66 2019-02-20 10:08:03 5.376961e-08 0.000232
    67 2019-02-20 10:08:03 3.742230e-08 0.000193
    68 2019-02-20 10:08:03 2.403945e-08 0.000155
    69 2019-02-20 10:08:03 1.392948e-08 0.000118
    70 2019-02-20 10:08:03 6.882885e-09 0.000083
    71 2019-02-20 10:08:03 2.573529e-09 0.000051
    72 2019-02-20 10:08:03 5.660752e-10 0.000024
    73 2019-02-20 10:08:03 3.742855e-10 0.000019
    74 2019-02-20 10:08:03 1.514451e-09 0.000039
    75 2019-02-20 10:08:03 3.598899e-09 0.000060
    76 2019-02-20 10:08:03 6.195393e-09 0.000079
    77 2019-02-20 10:08:03 8.959091e-09 0.000095
    78 2019-02-20 10:08:03 1.149559e-08 0.000107
    79 2019-02-20 10:08:03 1.359345e-08 0.000117
    80 2019-02-20 10:08:03 1.508400e-08 0.000123
    81 2019-02-20 10:08:03 1.587210e-08 0.000126
    82 2019-02-20 10:08:03 1.597419e-08 0.000126
    83 2019-02-20 10:08:03 1.545939e-08 0.000124
    84 2019-02-20 10:08:03 1.451104e-08 0.000120
    85 2019-02-20 10:08:03 1.306447e-08 0.000114
    86 2019-02-20 10:08:03 1.139801e-08 0.000107
    87 2019-02-20 10:08:03 9.582664e-09 0.000098
    88 2019-02-20 10:08:03 7.716636e-09 0.000088
    89 2019-02-20 10:08:03 5.900489e-09 0.000077
    90 2019-02-20 10:08:03 4.236540e-09 0.000065
    91 2019-02-20 10:08:03 2.840380e-09 0.000053
    92 2019-02-20 10:08:03 1.724629e-09 0.000042
    93 2019-02-20 10:08:03 9.029577e-10 0.000030
    94 2019-02-20 10:08:03 3.696243e-10 0.000019
    95 2019-02-20 10:08:03 9.418955e-11 0.000010
    96 2019-02-20 10:08:03 2.742695e-11 0.000005
    97 2019-02-20 10:08:03 1.182912e-10 0.000011
    98 2019-02-20 10:08:03 3.190337e-10 0.000018
    99 2019-02-20 10:08:03 5.844640e-10 0.000024
   100 2019-02-20 10:08:03 8.814425e-10 0.000030
Training the model took 0.6154983043670654...
tensor([-14.5869894028, -14.5638799667], grad_fn=<StackBackward>)
tensor([-14.5868730545, -14.5640010834])

Copy link

muammar commented Feb 20, 2019

Loss Function

Parity Plot

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment