Gist: meikuam/aca7a0b9ee6bd3ded4feef92c24c1775
>>> import graphviz
>>> import torch
>>> import torch.nn as nn
>>> from torchviz import make_dot
>>> x = torch.randn([1, 1, 16])
>>> x.shape
torch.Size([1, 1, 16])
>>> rnn = nn.LSTM(
...     input_size=16,
...     hidden_size=8,
...     num_layers=2,
...     batch_first=True)
>>> y = rnn(x)
>>> make_dot(y[0]).render("lstm_torchviz", format="png")
'lstm_torchviz.png'
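`rnn(x)` returns a tuple `(output, (h_n, c_n))`, which is why only `y[0]` (the output tensor) is passed to `make_dot`. The sketch below assumes the shape rules documented for `nn.LSTM` without `proj_size`. It uses a hypothetical helper, `lstm_output_shapes`, that computes the shapes expected for this session without importing PyTorch:

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def lstm_output_shapes(batch: int, seq_len: int, hidden_size: int,
                       num_layers: int = 1, batch_first: bool = False,
                       bidirectional: bool = False) -> Dict[str, Shape]:
    """Hypothetical helper: expected shapes of nn.LSTM outputs (no proj_size)."""
    num_directions = 2 if bidirectional else 1
    if batch_first:
        output = (batch, seq_len, num_directions * hidden_size)
    else:
        output = (seq_len, batch, num_directions * hidden_size)
    # h_n and c_n ignore batch_first: always (directions * num_layers, batch, hidden).
    state = (num_directions * num_layers, batch, hidden_size)
    return {"output": output, "h_n": state, "c_n": state}


# Same configuration as the session: x is [1, 1, 16], hidden_size=8, 2 layers.
shapes = lstm_output_shapes(batch=1, seq_len=1, hidden_size=8,
                            num_layers=2, batch_first=True)
print(shapes)  # → {'output': (1, 1, 8), 'h_n': (2, 1, 8), 'c_n': (2, 1, 8)}
```

In the live session, `[t.shape for t in (y[0], *y[1])]` should report the same three shapes. Note that `h_n`/`c_n` keep the layer dimension first even with `batch_first=True`.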
------------
>>> import hiddenlayer as hl
>>> import torch
>>> import torch.nn as nn
>>> transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.
>>> rnn = nn.LSTM(input_size=16, hidden_size=8, num_layers=2, batch_first=True)
>>> x = torch.randn([1, 1, 16])
>>> graph = hl.build_graph(rnn, x, transforms=transforms)
/home/user/.local/lib/python3.8/site-packages/torch/onnx/symbolic_opset9.py:1801: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model.
warnings.warn("Exporting a model to ONNX with a batch_size other than 1, " +
>>> graph.theme = hl.graph.THEMES['blue'].copy()
>>> graph.save('rnn_hiddenlayer', format='png')
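To cross-check the weight tensors that feed the LSTM in these graphs, the sketch below lists the parameters expected for `nn.LSTM(input_size=16, hidden_size=8, num_layers=2)`. It assumes the parameter naming and shape rules documented for `nn.LSTM` without `proj_size`. The helper `lstm_param_shapes` is hypothetical and does not call PyTorch:

```python
import math
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def lstm_param_shapes(input_size: int, hidden_size: int, num_layers: int = 1,
                      bias: bool = True, bidirectional: bool = False) -> Dict[str, Shape]:
    """Hypothetical helper: expected nn.LSTM parameter names and shapes (no proj_size)."""
    num_directions = 2 if bidirectional else 1
    gate_rows = 4 * hidden_size  # input, forget, cell and output gates stacked row-wise
    shapes: Dict[str, Shape] = {}
    for layer in range(num_layers):
        # Layer 0 reads the input features; deeper layers read the previous layer's output.
        in_features = input_size if layer == 0 else num_directions * hidden_size
        for direction in range(num_directions):
            suffix = f"l{layer}" + ("_reverse" if direction == 1 else "")
            shapes[f"weight_ih_{suffix}"] = (gate_rows, in_features)
            shapes[f"weight_hh_{suffix}"] = (gate_rows, hidden_size)
            if bias:
                shapes[f"bias_ih_{suffix}"] = (gate_rows,)
                shapes[f"bias_hh_{suffix}"] = (gate_rows,)
    return shapes


shapes = lstm_param_shapes(input_size=16, hidden_size=8, num_layers=2)
for name, shape in shapes.items():
    print(name, shape)
total = sum(math.prod(shape) for shape in shapes.values())
print("total parameters:", total)  # → total parameters: 1408
```

In the live session, `{n: tuple(p.shape) for n, p in rnn.named_parameters()}` should give the same mapping. `sum(p.numel() for p in rnn.parameters())` should give the same total.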