Last active: July 1, 2018 14:31
-
-
Save tanikawa04/e296ac9c4ff6eab84ff332fc602c62bb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import chainer.links as L | |
import chainer.functions as F | |
from chainer import Variable | |
def sequence_embed(embed, xs):
    """Embed a batch of variable-length ID sequences with one ``embed`` call.

    All sequences are concatenated along axis 0, passed through ``embed``
    once, and the result is split back into per-sequence pieces in the
    original order.

    Args:
        embed: Callable applied to the concatenated IDs (in this file an
            ``L.EmbedID`` link).
        xs: Iterable of ID sequences. Each element must support ``len()``
            and be accepted by ``F.concat``. One-shot iterables such as
            generators are accepted.

    Returns:
        Tuple with one embedded piece per element of ``xs``. An empty ``xs``
        yields an empty tuple.
    """
    # Materialize first: the length pass below would otherwise exhaust a
    # generator before F.concat ever sees it.
    xs = list(xs)
    x_len = [len(x) for x in xs]
    if not x_len:
        # F.concat cannot concatenate zero inputs; an empty batch embeds to
        # an empty result.
        return ()
    # Split points are the running totals of every length except the last.
    # With a single sequence this cumsum is empty, and numpy would default an
    # empty cumsum to float64 -- force an integer dtype for the indices.
    x_section = np.cumsum(x_len[:-1], dtype=np.int64)
    ex = embed(F.concat(xs, axis=0))
    exs = F.split_axis(ex, x_section, 0)
    return exs
# An arbitrary toy network, for demonstration only (a real model would not be
# this small). Per Chainer's positional signatures:
#   EmbedID(in_size=10, out_size=3)
#   NStepLSTM(n_layers=2, in_size=3, out_size=3, dropout=0.5)
embed = L.EmbedID(10, 3)
lstm = L.NStepLSTM(2, 3, 3, 0.5)

# Wrap each sentence's word-ID sequence (int32) in its own Variable.
xs = [
    Variable(np.array(word_ids, dtype=np.int32))
    for word_ids in ([0, 5, 7, 1], [0, 4, 3, 6, 7, 1])
]

# This case runs without error: embed every sentence, then feed the batch to
# the LSTM with no explicit initial hidden/cell state (None, None).
emb_xs = sequence_embed(embed, xs)
hy, cy, ys = lstm(None, None, emb_xs)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment