
@hccho2
Last active July 25, 2019 06:49
Matrix multiplication in the attention computation: projecting the encoder and decoder states separately and summing gives the same result as concatenating them first and applying a single combined projection.
import numpy as np

N = 2   # batch size
T = 20  # encoder time length
D1 = 30 # encoder hidden dim
D2 = 6  # decoder hidden dim
D3 = 11 # attention dim
h = np.random.randn(N, T, D1) # all encoder hidden states
s = np.random.randn(N, D2)    # decoder hidden state at one time step
Wm = np.random.randn(D1, D3)  # projection for encoder states
Wq = np.random.randn(D2, D3)  # projection for decoder state

# Variant 1: project separately, then add (s is broadcast over the T axis)
A = np.matmul(h, Wm) + np.expand_dims(np.matmul(s, Wq), axis=1)

# Variant 2: concatenate inputs and weights, then do a single matmul
hs = np.concatenate([h, np.tile(np.expand_dims(s, 1), (1, T, 1))], axis=-1)
Wmq = np.concatenate([Wm, Wq], axis=0)
B = np.matmul(hs, Wmq)

print(np.allclose(A, B))  # True: [h, s] @ [[Wm], [Wq]] == h @ Wm + s @ Wq
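The snippet stops at the projected energies `A`. As a sketch of how those energies might be used in Bahdanau-style additive attention, the following continues the computation: a `tanh` nonlinearity, a score vector `v`, a softmax over the time axis, and a weighted sum of the encoder states. Note that `v`, the softmax, and the context step are assumptions of mine and are not part of the original gist.

```python
import numpy as np

N, T, D1, D2, D3 = 2, 20, 30, 6, 11
h = np.random.randn(N, T, D1)   # encoder hidden states
s = np.random.randn(N, D2)      # decoder hidden state
Wm = np.random.randn(D1, D3)
Wq = np.random.randn(D2, D3)
v = np.random.randn(D3)         # hypothetical score vector

# Projected energies, as in the gist: (N, T, D3)
A = np.matmul(h, Wm) + np.expand_dims(np.matmul(s, Wq), axis=1)

e = np.tanh(A) @ v                                     # scores, (N, T)
w = np.exp(e) / np.exp(e).sum(axis=1, keepdims=True)   # attention weights
context = np.einsum('nt,ntd->nd', w, h)                # context vector, (N, D1)

print(w.shape, context.shape)
```

Each row of `w` sums to 1, and `context` is the attention-weighted average of the encoder states for that batch element.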