Skip to content

Instantly share code, notes, and snippets.

@lostella
Created October 11, 2018 11:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lostella/9a790fd89726c1741a1fcf4194a5dac6 to your computer and use it in GitHub Desktop.
Save lostella/9a790fd89726c1741a1fcf4194a5dac6 to your computer and use it in GitHub Desktop.
import mxnet as mx
class MyBlock(mx.gluon.HybridBlock):
    """Hybridizable block that unrolls one LSTM cell over a fixed-length sequence.

    Parameters
    ----------
    hidden_size : int, default 20
        Number of hidden units of the wrapped ``LSTMCell``.
    seq_len : int, default 10
        Number of time steps passed to ``LSTMCell.unroll``. ``unroll`` needs an
        explicit length, and under hybridization the input is a symbol, so the
        length is fixed here rather than read from the data. Inputs must have
        exactly this many steps along the time axis.
    **kwargs
        Forwarded unchanged to ``HybridBlock.__init__`` (e.g. ``prefix``, ``params``).
    """

    def __init__(self, hidden_size=20, seq_len=10, **kwargs):
        super().__init__(**kwargs)
        self._seq_len = seq_len
        # Create the child cell inside name_scope so its parameters get this
        # block's prefix.
        with self.name_scope():
            self.lstmcell = mx.gluon.rnn.LSTMCell(hidden_size=hidden_size)

    def hybrid_forward(self, F, seq):
        """Run the LSTM over ``seq`` and return the per-step outputs.

        ``seq`` uses the "NTC" layout (batch, time, channels). The outputs of all
        steps are merged into a single tensor. The final LSTM state is discarded.
        """
        outputs, _ = self.lstmcell.unroll(
            inputs=seq,
            length=self._seq_len,
            layout="NTC",
            merge_outputs=True,
        )
        return outputs
# Single source of truth for the export location. export() derives both file
# names from these values, and the import below must use the same names.
MODEL_PREFIX = "./model"
EPOCH = 0

block = MyBlock()
block.initialize()
block.hybridize()

# Shape (32, 10, 5) under the block's "NTC" layout. The time axis (10) must
# equal MyBlock's unroll length. Named `x` rather than `input` to avoid
# shadowing the builtin.
x = mx.nd.random_normal(shape=(32, 10, 5))
# Forward pass on the hybridized block. HybridBlock.export() relies on the
# graph cached by this call.
output = block(x)

# Writes f"{MODEL_PREFIX}-symbol.json" and f"{MODEL_PREFIX}-{EPOCH:04d}.params".
block.export(path=MODEL_PREFIX, epoch=EPOCH)

# Loads the exported graph back as a SymbolBlock (not a raw Symbol).
# input_names must match the input variable name in the exported JSON.
# "data" is assumed to be the name export() gave the single input; confirm
# against model-symbol.json.
symbol = mx.gluon.SymbolBlock.imports(
    symbol_file=f"{MODEL_PREFIX}-symbol.json",
    input_names=["data"],
    param_file=f"{MODEL_PREFIX}-{EPOCH:04d}.params",
    ctx=mx.context.current_context(),
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment