@MilesCranmer (last active April 18, 2023)
Minimal multi-headed self-attention
using Flux
using Fluxperimental: @compact
nf = 10        # number of input features
nb = 32        # batch size
nt = 100       # sequence length (number of time steps)
d_attn = 64    # key/query dimension
d_value = 128  # value dimension per head
d_head = 16    # number of attention heads (H / num_heads below)
d_out = 256    # output dimension
X = randn(Float32, nf, nt, nb)
# X: (feature, time, batch)
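# Flux's Dense acts on the leading (feature) dimension and treats the trailing
# (time, batch) dimensions as batch dimensions, which is why this layout works
# directly with the per-head projections below. A quick illustrative check:
@assert size(Dense(nf => d_attn)(X)) == (d_attn, nt, nb)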
# Minimal:
attn(i,o,a,dv,H)=@compact(out=Dense(H*dv=>o),heads=[@compact(D=l->Dense(i=>l),K=D(a),V=D(dv),Q=D(a))do x;k,v,q=K(x),V(x),Q(x);x=sum(k.*q,dims=1)./√a;softmax(x,dims=2).*v;end;for _∈1:H])do x;vcat([h(x) for h∈heads]...)|>out;end
# Expanded:
function multihead_self_attention(num_features, num_out, d_attn, d_value, num_heads)
    @compact(
        # Final projection from the concatenated head outputs to the output size:
        out = Dense(num_heads * d_value => num_out),
        # One independent key/value/query projection per head:
        heads = [
            @compact(
                K = Dense(num_features => d_attn),
                V = Dense(num_features => d_value),
                Q = Dense(num_features => d_attn)
            ) do x
                # x: (num_features, time, batch)
                k, v, q = K(x), V(x), Q(x)
                # Per-position scaled attention score, shape (1, time, batch):
                x = sum(k .* q; dims=1) ./ sqrt(d_attn)
                # Weight the values by the softmax over the time dimension:
                softmax(x; dims=2) .* v
            end for _ in 1:num_heads
        ]
    ) do x
        # Concatenate all head outputs along the feature dimension and project:
        out(vcat([h(x) for h in heads]...))
    end
end
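# Note: the attention score here is a single per-position value
# sum(k .* q) / sqrt(d_attn), softmaxed over time and used to rescale v
# elementwise; there is no full (time x time) attention matrix as in
# standard dot-product attention.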
attn(nf, d_out, d_attn, d_value, d_head)(X)
multihead_self_attention(nf, d_out, d_attn, d_value, d_head)(X)
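# Sanity check (illustrative): each head returns a (d_value, nt, nb) array, the
# d_head heads are concatenated to a (d_head * d_value, nt, nb) array, and `out`
# projects that down to d_out, so both versions return size (d_out, nt, nb):
@assert size(multihead_self_attention(nf, d_out, d_attn, d_value, d_head)(X)) == (d_out, nt, nb)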