Skip to content

Instantly share code, notes, and snippets.

@lostella
Created October 10, 2018 20:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lostella/261fd5d08dfb5e2054c4d01a7e2bc88e to your computer and use it in GitHub Desktop.
Save lostella/261fd5d08dfb5e2054c4d01a7e2bc88e to your computer and use it in GitHub Desktop.
import mxnet as mx
class MyBlock(mx.gluon.HybridBlock):
    """Stack of LSTM cells unrolled over a fixed number of time steps.

    The cells are wrapped in a ``HybridSequentialRNNCell`` and unrolled with
    ``layout="NTC"`` and ``merge_outputs=True``, so the forward pass returns
    the stack's outputs for every time step merged into a single array.

    Parameters
    ----------
    hidden_size : int, optional
        Number of hidden units in every ``LSTMCell`` (default 20).
    num_layers : int, optional
        Number of stacked ``LSTMCell`` layers (default 3).
    seq_len : int, optional
        Number of time steps passed to ``unroll`` (default 10). Because the
        layout is ``"NTC"``, axis 1 of the input must have this length.
    **kwargs
        Forwarded unchanged to ``mx.gluon.HybridBlock.__init__``
        (e.g. ``prefix`` or ``params``).

    Raises
    ------
    ValueError
        If ``hidden_size``, ``num_layers`` or ``seq_len`` is less than 1.
    """

    def __init__(self, *, hidden_size=20, num_layers=3, seq_len=10, **kwargs):
        # Validate before building anything so no half-constructed block exists.
        for arg_name, value in (
            ("hidden_size", hidden_size),
            ("num_layers", num_layers),
            ("seq_len", seq_len),
        ):
            if value < 1:
                raise ValueError(f"{arg_name} must be a positive integer, got {value!r}")
        super().__init__(**kwargs)
        # Plain Python int: used as the static unroll length in hybrid_forward.
        self._seq_len = seq_len
        # name_scope() gives the child cells' parameters this block's prefix.
        with self.name_scope():
            self.lstm = mx.gluon.rnn.HybridSequentialRNNCell()
            for _ in range(num_layers):
                self.lstm.add(mx.gluon.rnn.LSTMCell(hidden_size=hidden_size))

    def hybrid_forward(self, F, seq):
        """Unroll the LSTM stack over ``seq`` and return the merged outputs.

        ``F`` is supplied by Gluon (ndarray or symbol API); it is not used
        directly because ``unroll`` infers the API from ``seq`` itself.
        No ``begin_state`` is passed, so ``unroll`` uses the cells' default
        initial states. The final states are discarded.
        """
        outputs, _ = self.lstm.unroll(
            inputs=seq,
            length=self._seq_len,
            layout="NTC",
            merge_outputs=True,
        )
        return outputs
# Build, initialize and hybridize the model with its default configuration.
block = MyBlock()
block.initialize()
block.hybridize()

# Random input in NTC layout: 32 sequences, 10 time steps, 5 features each.
# The time dimension must match the sequence length MyBlock unrolls over (10 here).
data = mx.nd.random.normal(shape=(32, 10, 5))

# A hybridized block has a cached graph to export only after it has been run
# forward at least once, so this call is required even though `output` is unused.
output = block(data)

# export() writes "<prefix>-symbol.json" and "<prefix>-<epoch:04d>.params";
# derive both file names from the same values so they cannot drift apart.
MODEL_PREFIX = "./model"
EPOCH = 0
block.export(path=MODEL_PREFIX, epoch=EPOCH)

# Reload the exported graph and weights. A hybridized block with a single
# positional input names that input "data" in the exported symbol.
symbol = mx.gluon.SymbolBlock.imports(
    symbol_file=f"{MODEL_PREFIX}-symbol.json",
    input_names=["data"],
    param_file=f"{MODEL_PREFIX}-{EPOCH:04d}.params",
    ctx=mx.current_context(),
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment