
@franperezlopez
Last active April 13, 2019 11:21

Export fast.ai models to ONNX

fast.ai is an amazing library that gives us access to a vast array of pre-trained models, but it can also be used as a foundation for new models.

The fast.ai tabular model is a great model that uses:

  • an embedding layer to represent categorical features
  • a configurable number of hidden layers applied to the concatenated embeddings and continuous features
  • batch normalization, dropout and weight decay to regularize the model (preventing overfitting and allowing the model to train faster)
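To make the list above concrete, here is a minimal PyTorch sketch of a model with the same ingredients. It is a simplified, hypothetical `MiniTabular` class, not fast.ai's `TabularModel`; the constructor arguments (`emb_szs`, `n_cont`, `hidden`, `emb_p`, `p`) are assumptions chosen for illustration, and the order of ReLU, BatchNorm1d and Dropout inside each hidden layer is just one arbitrary choice.

```python
import torch
import torch.nn as nn


class MiniTabular(nn.Module):
    """Hypothetical, simplified tabular model (not fast.ai's TabularModel)."""

    def __init__(self, emb_szs, n_cont, hidden=(200, 100), out_sz=1, emb_p=0.0, p=0.0):
        super().__init__()
        # one embedding table per categorical column: (number of levels, embedding width)
        self.embeds = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs])
        self.emb_drop = nn.Dropout(emb_p)
        self.bn_cont = nn.BatchNorm1d(n_cont)  # normalizes the continuous features
        n_in = sum(nf for _, nf in emb_szs) + n_cont
        layers = []
        for n_out in hidden:
            # one arbitrary ordering of the regularization layers
            layers += [nn.Linear(n_in, n_out), nn.ReLU(), nn.BatchNorm1d(n_out), nn.Dropout(p)]
            n_in = n_out
        layers.append(nn.Linear(n_in, out_sz))
        self.layers = nn.Sequential(*layers)

    def forward(self, x_cat, x_cont):
        # x_cat: (batch, n_cat) integer codes; x_cont: (batch, n_cont) floats
        x = torch.cat([emb(x_cat[:, i]) for i, emb in enumerate(self.embeds)], dim=1)
        x = self.emb_drop(x)
        # hidden layers see embeddings and normalized continuous features together
        x = torch.cat([x, self.bn_cont(x_cont)], dim=1)
        return self.layers(x)
```

With `emb_szs=[(14, 7)]` and `n_cont=24`, the first hidden layer receives 7 + 24 = 31 features, the same `in_features` that appears in the printout further down. Weight decay does not appear in the module at all: in PyTorch it is an optimizer setting, e.g. `torch.optim.AdamW(model.parameters(), weight_decay=0.01)`.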

The code is so simple that it takes little time to read through. After reading it, I wanted to make my "own" version of the tabular model, mainly to try a new way of composing the hidden layers. This model uses Batch Norm + Dropout + Linear + Activation Function (ReLU), but some papers claim that the activation function should be placed before the Batch Norm and Dropout layers (the outputs of activation functions are not always zero-centered, so batch normalization applied afterwards can help to "align" the data). So I implemented a new function to test this claim: https://gist.github.com/c42784c5d929e983bf068e488ba1e196
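The layer-composition experiment can be sketched as a small helper that emits the same modules in either order. `hidden_block` and `build_layers` are hypothetical names, not the functions in the linked gist: with `act_before_bn=False` a block is Linear, BatchNorm1d, Dropout, ReLU, and with `act_before_bn=True` the activation is moved in front of batch norm and dropout.

```python
import torch.nn as nn


def hidden_block(n_in, n_out, p=0.0, act_before_bn=False):
    """Hypothetical helper returning one hidden block in the requested order.

    act_before_bn=False -> Linear, BatchNorm1d, [Dropout], ReLU
    act_before_bn=True  -> Linear, ReLU, BatchNorm1d, [Dropout]
    Dropout is omitted when p == 0.
    """
    lin = nn.Linear(n_in, n_out)
    bn = nn.BatchNorm1d(n_out)
    act = nn.ReLU(inplace=True)
    drop = [nn.Dropout(p)] if p > 0 else []
    if act_before_bn:
        return [lin, act, bn, *drop]
    return [lin, bn, *drop, act]


def build_layers(sizes, ps, act_before_bn=False):
    """sizes = [n_in, hidden_1, ..., hidden_k, n_out]; ps = one dropout rate per hidden layer."""
    layers = []
    for n_in, n_out, p in zip(sizes[:-2], sizes[1:-1], ps):
        layers += hidden_block(n_in, n_out, p, act_before_bn)
    layers.append(nn.Linear(sizes[-2], sizes[-1]))  # final output layer, no activation
    return nn.Sequential(*layers)
```

Called as `build_layers([31, 200, 100, 1], [0.0, 0.005])`, the default order yields the same sequence of layer types as the `(layers)` block printed below; passing `act_before_bn=True` gives the variant the papers argue for, so both can be trained and compared with everything else unchanged.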

This function is embedded in a new class named TabularExperimentModel that makes some minor tweaks to the original TabularModel. Additionally, I created a factory method (tabularexperiment_learner) to create instances of the new class. The full code is available as a gist.

So far so good. Now it is time to check the model architecture. The first command you will probably try is model.summary().

https://gist.github.com/e09dcbc7d64f838214c81b4151ba6f82

TabularExperimentModel(
  (embeds): ModuleList(
    (0): Embedding(14, 7)
  )
  (emb_drop): Dropout(p=0.001)
  (bn_cont): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layers): Sequential(
    (0): Linear(in_features=31, out_features=200, bias=True)
    (1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Linear(in_features=200, out_features=100, bias=True)
    (4): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): Dropout(p=0.005)
    (6): ReLU(inplace)
    (7): Linear(in_features=100, out_features=1, bias=True)
  )
)

https://gist.github.com/9bcabea17b3ac1e32b7339bff9f0201d

======================================================================
Layer (type)         Output Shape         Param #    Trainable 
======================================================================
Embedding            [7]                  98         True      
______________________________________________________________________
Dropout              [7]                  0          False     
______________________________________________________________________
BatchNorm1d          [24]                 48         True      
______________________________________________________________________
Linear               [200]                6,400      True      
______________________________________________________________________
BatchNorm1d          [200]                400        True      
______________________________________________________________________
ReLU                 [100]                0          False     
______________________________________________________________________
Linear               [100]                20,100     True      
______________________________________________________________________
BatchNorm1d          [100]                200        True      
______________________________________________________________________
Dropout              [100]                0          False     
______________________________________________________________________
Linear               [1]                  101        True      
______________________________________________________________________

Total params: 27,347
Total trainable params: 27,347
Total non-trainable params: 0
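The parameter counts in the table can be checked by hand: an embedding stores one vector per level, a linear layer stores a weight matrix plus a bias, and BatchNorm1d learns a scale and a shift per feature (its running mean and variance are buffers, not parameters). Dropout and ReLU have no parameters. A small arithmetic sketch in plain Python, with helper names that are purely illustrative:

```python
def embedding_params(n_levels, dim):
    return n_levels * dim  # one dim-sized vector per categorical level


def batchnorm_params(n):
    return 2 * n  # learnable scale (weight) and shift (bias); running stats are buffers


def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix plus bias vector


counts = {
    "Embedding(14, 7)": embedding_params(14, 7),
    "BatchNorm1d(24)": batchnorm_params(24),
    "Linear(31, 200)": linear_params(31, 200),
    "BatchNorm1d(200)": batchnorm_params(200),
    "Linear(200, 100)": linear_params(200, 100),
    "BatchNorm1d(100)": batchnorm_params(100),
    "Linear(100, 1)": linear_params(100, 1),
}
total = sum(counts.values())
print(total)  # 27347
```

The sum reproduces the 27,347 total reported by the summary, which confirms that the table is only missing parameter-free operations.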

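But you will notice that these reports do not show everything: the final sigmoid is missing, and the summary lists only one of the two ReLU layers. They fall short because of the dynamic nature of PyTorch: the sigmoid, for example, is computed inside the forward method instead of being registered as a layer, so a report built from the module tree cannot see it.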
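The effect is easy to reproduce with a toy module. `WithRange` below is a hypothetical example, not the fast.ai code: it applies a y_range-style scaling with a functional `torch.sigmoid` call inside `forward`, so printing the model or walking `model.modules()` only reveals the Linear layer.

```python
import torch
import torch.nn as nn


class WithRange(nn.Module):
    """Hypothetical toy model: the sigmoid lives in forward, not in the module tree."""

    def __init__(self, y_range=(0.0, 5.0)):
        super().__init__()
        self.lin = nn.Linear(3, 1)
        self.y_range = y_range

    def forward(self, x):
        x = self.lin(x)
        # functional call: nothing is registered as a submodule for this step
        return (self.y_range[1] - self.y_range[0]) * torch.sigmoid(x) + self.y_range[0]


model = WithRange()
print(model)  # lists only the `lin` submodule
module_types = [type(m).__name__ for m in model.modules()]
print(module_types)  # ['WithRange', 'Linear']
```

Any report that hooks or enumerates registered modules has the same blind spot; only running data through `forward` reveals the sigmoid.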

The Open Neural Network eXchange format (ONNX) is an open standard for distributing neural network models. It provides interoperability between frameworks (PyTorch, TensorFlow, CNTK, ...). For example, a model can be created and trained with PyTorch on Linux and then exported for inference with CoreML on iOS devices.

As an additional bonus, third-party applications can provide functionality across frameworks. For example, Netron is a cross-platform neural network visualizer that works with PyTorch, TensorFlow and many other frameworks.

To get a visualization of the network, you need a small amount of data to feed forward through the network, and then you save the model to a file:

https://gist.github.com/ca53809ed433248bf6af2e7db3428147

In the code above, the second parameter of torch.onnx.export is a tuple that will be passed to the forward method of TabularExperimentModel (remember that this method receives two parameters: the categorical and the continuous features).
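A minimal sketch of the export step, assuming a model whose `forward` takes `(x_cat, x_cont)`. `TwoInputModel` and `export_onnx` are hypothetical stand-ins for TabularExperimentModel and the code in the gist, and the input/output names are arbitrary labels. The key detail is that the dummy inputs are passed as a tuple, which is unpacked into the two arguments of `forward` while the model runs:

```python
import torch
import torch.nn as nn


class TwoInputModel(nn.Module):
    """Hypothetical stand-in whose forward takes (x_cat, x_cont)."""

    def __init__(self, n_levels=14, emb_dim=7, n_cont=24):
        super().__init__()
        self.emb = nn.Embedding(n_levels, emb_dim)
        self.bn_cont = nn.BatchNorm1d(n_cont)
        self.lin = nn.Linear(emb_dim + n_cont, 1)

    def forward(self, x_cat, x_cont):
        x = torch.cat([self.emb(x_cat[:, 0]), self.bn_cont(x_cont)], dim=1)
        return self.lin(x)


def export_onnx(model, path, bs=2, n_levels=14, n_cont=24):
    model.eval()  # export with inference-mode BatchNorm/Dropout
    x_cat = torch.randint(0, n_levels, (bs, 1))  # categorical codes must be integers
    x_cont = torch.randn(bs, n_cont)             # continuous features are floats
    # the tuple (x_cat, x_cont) is unpacked into forward(x_cat, x_cont)
    torch.onnx.export(model, (x_cat, x_cont), path,
                      input_names=["x_cat", "x_cont"], output_names=["output"])
```

For example, `export_onnx(TwoInputModel(), "model.onnx")` writes a file that can be opened in Netron. With a fast.ai Learner you would pass `learn.model` instead, creating the dummy tensors on the same device as the model. Naming the graph inputs and outputs is optional, but it makes the exported graph slightly easier to read.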

It is more informative to compare models graphically than through a "text report". For example, the following figure compares the hidden layers generated by the default model (base) with those of the new model (exp):

You can also see that, even though the layers are not named as they are in PyTorch, it is easy to identify each one. Bear in mind that this is a small model; for bigger models this can become troublesome. Another option is to view the node graph in TensorBoard.

Because the network is fed with data during the export, the graph also contains nodes that are created dynamically. For example, the forward method applies a sigmoid scaling when the y_range parameter is present. This node was not displayed by learn_exp.model.summary(), but it does appear in the ONNX graph:

In the previous figure, selecting the Mul node shows where the data for each of its inputs comes from: in this case, one input is the output of the sigmoid and the other is a constant (self.y_range[1]-self.y_range[0]). As there is no labeling, the nodes are named in order, which makes it hard to follow the data flow across the graph.
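The same data-flow question can be asked programmatically. As a rough illustration, the sketch below uses `torch.jit.trace` (the tracing mechanism that `torch.onnx.export` relied on in the PyTorch versions of that era) on a hypothetical `WithRange` toy model, then looks up which nodes produce the inputs of the `aten::mul` node. This inspects the traced TorchScript graph, not the ONNX file itself.

```python
import torch
import torch.nn as nn


class WithRange(nn.Module):
    """Hypothetical toy model reproducing the y_range scaling described above."""

    def __init__(self, y_range=(0.0, 5.0)):
        super().__init__()
        self.lin = nn.Linear(3, 1)
        self.y_range = y_range

    def forward(self, x):
        scale = self.y_range[1] - self.y_range[0]  # plain Python float, recorded as a constant
        return scale * torch.sigmoid(self.lin(x)) + self.y_range[0]


# tracing runs real data through forward and records every op that executes
traced = torch.jit.trace(WithRange(), torch.randn(2, 3))
graph = traced.inlined_graph  # submodule calls (self.lin) inlined into one graph

kinds = [node.kind() for node in graph.nodes()]
print(kinds)  # includes 'aten::sigmoid' and 'aten::mul'

# follow the data flow backwards: which node produced each input of the multiplication?
mul = next(node for node in graph.nodes() if node.kind() == "aten::mul")
producers = [inp.node().kind() for inp in mul.inputs()]
print(producers)
```

One producer of the multiplication is `aten::sigmoid`; the other should be a constant node holding the scale, which mirrors what Netron shows when the Mul node is selected.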
