# Bidirectional LSTM in MXNet's symbolic API: an unrolled training symbol plus
# an inference model that feeds final states back in between calls.
# Assumes the LSTM helpers (lstm, LSTMState, LSTMParam) from MXNet's example
# lstm.py are importable; the module path below is an assumption.
import mxnet as mx
from lstm import lstm, LSTMState, LSTMParam

# Used for training.
def bi_lstm_unroll(seq_len, input_size, num_hidden, num_embed, num_label, dropout=0.):
    embed_weight = mx.sym.Variable("embed_weight")
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")

    # Initial cell/hidden states: layer 0 is the forward LSTM, layer 1 the backward one.
    last_states = []
    last_states.append(LSTMState(c=mx.sym.Variable("l0_init_c"), h=mx.sym.Variable("l0_init_h")))
    last_states.append(LSTMState(c=mx.sym.Variable("l1_init_c"), h=mx.sym.Variable("l1_init_h")))
    forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
                              i2h_bias=mx.sym.Variable("l0_i2h_bias"),
                              h2h_weight=mx.sym.Variable("l0_h2h_weight"),
                              h2h_bias=mx.sym.Variable("l0_h2h_bias"))
    backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"),
                               i2h_bias=mx.sym.Variable("l1_i2h_bias"),
                               h2h_weight=mx.sym.Variable("l1_h2h_weight"),
                               h2h_bias=mx.sym.Variable("l1_h2h_bias"))

    # Embed the token ids and slice the sequence into one symbol per time step.
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=input_size,
                             weight=embed_weight, output_dim=num_embed, name='embed')
    wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)

    # Forward pass over the sequence, left to right.
    forward_hidden = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        next_state = lstm(num_hidden, indata=hidden,
                          prev_state=last_states[0],
                          param=forward_param,
                          seqidx=seqidx, layeridx=0, dropout=dropout)
        hidden = next_state.h
        last_states[0] = next_state
        forward_hidden.append(hidden)

    # Backward pass, right to left; insert at the front so the list stays time-ordered.
    backward_hidden = []
    for seqidx in range(seq_len):
        k = seq_len - seqidx - 1
        hidden = wordvec[k]
        next_state = lstm(num_hidden, indata=hidden,
                          prev_state=last_states[1],
                          param=backward_param,
                          seqidx=k, layeridx=1, dropout=dropout)
        hidden = next_state.h
        last_states[1] = next_state
        backward_hidden.insert(0, hidden)

    # Concatenate forward and backward states per step, then stack the steps on axis 0.
    hidden_all = []
    for i in range(seq_len):
        hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1))
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)

    # Per-position label prediction; the label is flattened to match the stacked layout.
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
                                 weight=cls_weight, bias=cls_bias, name='pred')
    label = mx.sym.transpose(data=label)
    label = mx.sym.Reshape(data=label, target_shape=(0,))
    sm = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')
    return sm
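# --- Shape-check sketch (not part of the original gist): bind the training
# symbol with small dummy shapes to verify the graph wiring. All sizes here
# are arbitrary assumptions.
def _check_train_symbol():
    batch_size, seq_len, vocab = 4, 5, 100
    sym = bi_lstm_unroll(seq_len, vocab, num_hidden=20, num_embed=10,
                         num_label=vocab, dropout=0.2)
    init_states = [('l%d_init_%s' % (l, t), (batch_size, 20))
                   for l in range(2) for t in ('c', 'h')]
    shapes = dict(init_states + [('data', (batch_size, seq_len)),
                                 ('softmax_label', (batch_size, seq_len))])
    exe = sym.simple_bind(ctx=mx.cpu(), **shapes)
    print([o.shape for o in exe.outputs])  # expect [(seq_len * batch_size, vocab)]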
# Used for inference.
class BiLSTMInferenceModel(object):
    def __init__(self,
                 seq_len,
                 input_size,
                 num_hidden,
                 num_embed,
                 num_label,
                 arg_params,
                 ctx=mx.cpu(),
                 dropout=0.):
        self.sym = bi_lstm_inference_symbol(input_size, seq_len,
                                            num_hidden,
                                            num_embed,
                                            num_label,
                                            dropout)
        print("input size:", input_size)
        batch_size = 1
        init_c = [('l%d_init_c' % l, (batch_size, num_hidden)) for l in range(2)]
        init_h = [('l%d_init_h' % l, (batch_size, num_hidden)) for l in range(2)]
        # The symbol slices its input into seq_len steps, so "data" holds a
        # full sequence of token ids, not a single id.
        data_shape = [("data", (batch_size, seq_len))]
        input_shapes = dict(init_c + init_h + data_shape)
        print(input_shapes)
        self.executor = self.sym.simple_bind(ctx=ctx, **input_shapes)

        # Copy the trained parameters into the executor's argument arrays.
        for key in self.executor.arg_dict.keys():
            if key in arg_params:
                print(key, arg_params[key].shape, self.executor.arg_dict[key].shape)
                arg_params[key].copyto(self.executor.arg_dict[key])

        # executor.outputs[0] is the softmax; outputs[1:] are the final LSTM states.
        state_name = []
        for i in range(2):
            state_name.append("l%d_init_c" % i)
            state_name.append("l%d_init_h" % i)
        self.states_dict = dict(zip(state_name, self.executor.outputs[1:]))
        self.input_arr = mx.nd.zeros(data_shape[0][1])

    def forward(self, input_data, new_seq=False):
        # Zero the recurrent states when a new sequence starts.
        if new_seq:
            for key in self.states_dict.keys():
                self.executor.arg_dict[key][:] = 0.
        # Copy into the bound NDArray; rebinding arg_dict["data"] to a new
        # array would not update the executor's memory.
        input_data.copyto(self.executor.arg_dict["data"])
        self.executor.forward()
        # Feed the final states back in as the next call's initial states.
        for key in self.states_dict.keys():
            self.states_dict[key].copyto(self.executor.arg_dict[key])
        prob = self.executor.outputs[0].asnumpy()
        return prob
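# --- Usage sketch (not part of the original gist): restore trained parameters
# and label one sequence. The checkpoint prefix "sort" and epoch 10 are
# hypothetical; mx.model.load_checkpoint returns (symbol, arg_params, aux_params).
def _run_inference_example():
    _, arg_params, _ = mx.model.load_checkpoint("sort", 10)
    model = BiLSTMInferenceModel(seq_len=5, input_size=100,
                                 num_hidden=20, num_embed=10, num_label=100,
                                 arg_params=arg_params)
    data = mx.nd.array([[12, 7, 41, 3, 99]])  # one batch of token ids, shape (1, seq_len)
    prob = model.forward(data, new_seq=True)  # softmax scores, shape (seq_len, num_label)
    print(prob.argmax(axis=1))                # predicted label per position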
def bi_lstm_inference_symbol(input_size, seq_len,
                             num_hidden, num_embed, num_label, dropout=0.):
    embed_weight = mx.sym.Variable("embed_weight")
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")
    last_states = [LSTMState(c=mx.sym.Variable("l0_init_c"), h=mx.sym.Variable("l0_init_h")),
                   LSTMState(c=mx.sym.Variable("l1_init_c"), h=mx.sym.Variable("l1_init_h"))]
    forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
                              i2h_bias=mx.sym.Variable("l0_i2h_bias"),
                              h2h_weight=mx.sym.Variable("l0_h2h_weight"),
                              h2h_bias=mx.sym.Variable("l0_h2h_bias"))
    backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"),
                               i2h_bias=mx.sym.Variable("l1_i2h_bias"),
                               h2h_weight=mx.sym.Variable("l1_h2h_weight"),
                               h2h_bias=mx.sym.Variable("l1_h2h_bias"))
    data = mx.sym.Variable("data")
    embed = mx.sym.Embedding(data=data, input_dim=input_size,
                             weight=embed_weight, output_dim=num_embed, name='embed')
    wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)

    # Forward pass; dropout stays off at inference time.
    forward_hidden = []
    for seqidx in range(seq_len):
        next_state = lstm(num_hidden, indata=wordvec[seqidx],
                          prev_state=last_states[0],
                          param=forward_param,
                          seqidx=seqidx, layeridx=0, dropout=0.0)
        hidden = next_state.h
        last_states[0] = next_state
        forward_hidden.append(hidden)

    # Backward pass, right to left.
    backward_hidden = []
    for seqidx in range(seq_len):
        k = seq_len - seqidx - 1
        next_state = lstm(num_hidden, indata=wordvec[k],
                          prev_state=last_states[1],
                          param=backward_param,
                          seqidx=k, layeridx=1, dropout=0.0)
        hidden = next_state.h
        last_states[1] = next_state
        backward_hidden.insert(0, hidden)

    hidden_all = []
    for i in range(seq_len):
        hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1))
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
                               weight=cls_weight, bias=cls_bias, name='pred')
    sm = mx.sym.SoftmaxOutput(data=fc, name='softmax')

    # Group the softmax with the final states so callers can read them back.
    output = [sm]
    for state in last_states:
        output.append(state.c)
        output.append(state.h)
    return mx.sym.Group(output)
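# --- Inspection sketch (not part of the original gist): the grouped inference
# symbol exposes five outputs, the softmax first and then (c, h) for each
# direction; exact output names depend on the lstm() helper and MXNet version.
def _inspect_inference_outputs():
    sym = bi_lstm_inference_symbol(input_size=100, seq_len=5,
                                   num_hidden=20, num_embed=10, num_label=100)
    print(sym.list_outputs())  # 5 entries: softmax, then c/h per direction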