Last active
July 25, 2019 06:49
Matrix multiplication in the attention computation
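import numpy as np

N = 2    # batch size
T = 20   # encoder time length
D1 = 30  # encoder hidden dim
D2 = 6   # decoder hidden dim
D3 = 11  # attention dim

h = np.random.randn(N, T, D1)  # all encoder hidden states
s = np.random.randn(N, D2)     # decoder hidden state at one time step
Wm = np.random.randn(D1, D3)   # projects encoder states to the attention dim
Wq = np.random.randn(D2, D3)   # projects the decoder state to the attention dim

# Method A: project separately, then broadcast-add over the time axis.
A = np.matmul(h, Wm) + np.expand_dims(np.matmul(s, Wq), axis=1)  # (N, T, D3)

# Method B: concatenate inputs and weights, then do a single matmul.
hs = np.concatenate([h, np.tile(np.expand_dims(s, 1), (1, T, 1))], axis=-1)  # (N, T, D1+D2)
Wmq = np.concatenate([Wm, Wq], axis=0)  # (D1+D2, D3)
B = np.matmul(hs, Wmq)  # (N, T, D3)

print(np.allclose(A, B))  # True: both methods give the same result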
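
The equality checked above follows from block matrix multiplication: stacking the inputs side by side and the weights on top of each other gives `[h, s] @ [[Wm], [Wq]] = h @ Wm + s @ Wq`. The following is a minimal sketch of that identity for a single time step. The small arrays and the names `h_t`, `s_t`, `Wm_small`, `Wq_small` are hypothetical values chosen for illustration and are not part of the original snippet.

```python
import numpy as np

# Hypothetical small values, chosen only so the arithmetic is easy to follow.
h_t = np.array([[1.0, 2.0]])             # one encoder state, shape (1, 2)
s_t = np.array([[3.0]])                  # one decoder state, shape (1, 1)
Wm_small = np.array([[1.0, 0.0, 2.0],
                     [0.0, 1.0, 1.0]])   # shape (2, 3)
Wq_small = np.array([[1.0, 1.0, 1.0]])   # shape (1, 3)

# Two separate projections, then a sum.
separate = h_t @ Wm_small + s_t @ Wq_small

# One projection of the concatenated input with the stacked weights.
stacked_input = np.concatenate([h_t, s_t], axis=1)             # shape (1, 3)
stacked_weight = np.concatenate([Wm_small, Wq_small], axis=0)  # shape (3, 3)
stacked = stacked_input @ stacked_weight

print(separate)  # [[4. 5. 7.]]
print(stacked)   # [[4. 5. 7.]]
```

Both forms give the same values. The separate form avoids tiling `s` across all `T` time steps, while the stacked form needs only one matrix multiplication.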