# Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. "Attention is all you need." In Advances in neural information processing systems, pp. 5998-6008. 2017.
# [1] https://papers.nips.cc/paper/7181-attention-is-all-you-need/, https://arxiv.org/abs/1706.03762 (reference paper)
# [2] https://github.com/harvardnlp/annotated-transformer, http://nlp.seas.harvard.edu/2018/04/03/attention.html (reference implementation)
# [3] https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/multi_headed_attn.py
# [4] https://github.com/tensorflow/tensor2tensor
# TODO: LabelSmoothing?
# include("debug.jl")
include("mtdata.jl")
using IterTools, Base.Iterators, Printf, LinearAlgebra
using Knet: param, param0, atype, bmm, softmax, dropout, nll, relu
using Statistics: mean, std
## Transformer
struct Transformer; srcvocab; tgtvocab; srcembed; tgtembed; encoder; decoder; generator; end
# TODO: Transformer options = atype, winit, binit, init?, separate attention_dropout?
function Transformer(srcvocab, tgtvocab; dmodel=512, dff=2048, nheads=8, nlayers=6, maxlen=5000, dropout=0)
posembed, droplayer = PositionalEncoding(dmodel, maxlen), Dropout(dropout)
srcembed = Chain(Embed(length(srcvocab), dmodel), posembed, droplayer)
tgtembed = Chain(Embed(length(tgtvocab), dmodel), posembed, droplayer)
encoder = Chain([EncoderLayer(dmodel, dff, nheads, dropout) for n=1:nlayers]..., LayerNorm(dmodel))
decoder = Chain([DecoderLayer(dmodel, dff, nheads, dropout) for n=1:nlayers]..., LayerNorm(dmodel))
generator = Linear(dmodel, length(tgtvocab))
Transformer(srcvocab, tgtvocab, srcembed, tgtembed, encoder, decoder, generator)
end
function (t::Transformer)(src, tgt; average=true) # (T1,B) = size(src); (T2,B2) = size(tgt); @assert B == B2; (X,V) = size(t.tgtembed.layers[1].w)
enc = t.encoder(t.srcembed(src)) # @size enc (X,T1,B)
tgt1,tgt2 = tgt[1:end-1,:], tgt[2:end,:] # @size tgt1 (T2-1,B); @size tgt2 (T2-1,B)
dec = t.decoder(t.tgtembed(tgt1), enc) # @size dec (X,T2-1,B)
gen = t.generator(dec) # @size gen (V,T2-1,B)
(sumloss,numwords) = nll(gen, tgt2, average=false) # TODO: handle nll difference for different batchsizes.
average ? sumloss/numwords : (sumloss,numwords) # TODO: loss per word or per sequence?
end
## Chain
struct Chain; layers; end
function Chain(layer1, layer2, layers...)
Chain((layer1, layer2, layers...))
end
function (l::Chain)(x, o...)
for layer in l.layers
x = layer(x, o...)
end
return x
end
## EncoderLayer
struct EncoderLayer; selfattn; feedforw; end
function EncoderLayer(dmodel::Int, dff::Int, nheads::Int, dropout)
selfattn = MultiHeadAttention(dmodel, nheads, dropout)
selfattn = SubLayer(selfattn, dmodel, dropout)
feedforw = FeedForward(dmodel, dff, dropout)
feedforw = SubLayer(feedforw, dmodel, dropout)
EncoderLayer(selfattn, feedforw)
end
function (l::EncoderLayer)(x)
l.feedforw(l.selfattn(x))
end
## DecoderLayer
struct DecoderLayer; selfattn; srcattn; feedforw; end
function DecoderLayer(dmodel::Int, dff::Int, nheads::Int, dropout)
selfattn = MultiHeadAttention(dmodel, nheads, dropout, selfmask=true)
selfattn = SubLayer(selfattn, dmodel, dropout)
srcattn = MultiHeadAttention(dmodel, nheads, dropout)
srcattn = SubLayer(srcattn, dmodel, dropout)
feedforw = FeedForward(dmodel, dff, dropout)
feedforw = SubLayer(feedforw, dmodel, dropout)
DecoderLayer(selfattn, srcattn, feedforw)
end
function (l::DecoderLayer)(y,x)
l.feedforw(l.srcattn(l.selfattn(y), x))
end
## MultiHeadAttention
struct MultiHeadAttention; q; k; v; o; dropout; scale; selfmask; end
function MultiHeadAttention(dmodel::Int, nheads::Int, dropout; selfmask=false, scale=1/sqrt(dmodel÷nheads))
@assert dmodel % nheads == 0
dk = dmodel ÷ nheads
q = Linear(dmodel,dk,nheads)
k = Linear(dmodel,dk,nheads)
v = Linear(dmodel,dk,nheads)
o = Linear(dmodel,dmodel)
MultiHeadAttention(q,k,v,o,dropout,scale,selfmask)
end
function (l::MultiHeadAttention)(q,k,v; keymask=nothing) # inputs all batch-major
# (Q1,K1,V1) = size.((q,k,v),(1,)); (Q2,K2,V2) = size.((l.q.w,l.k.w,l.v.w),(1,)); (H,T1,T2) = size.((l.q.w,k,q),(2,)); B = size(q,3); (O2,O1) = size(l.o.w)
# @size q (Q1,T2,B); @size k (K1,T1,B); @size v (V1,T1,B)
# @size l.q.w (Q2,H,Q1); @size l.k.w (K2,H,K1); @size l.v.w (V2,H,V1); @assert K2 == Q2; @assert O1 == V2*H
# query, keys and values:
q,k,v = l.q(q),l.k(k),l.v(v) # @size q (Q2,H,T2,B); @size k (K2,H,T1,B); @size v (V2,H,T1,B)
q,v = permutedims.((q,v), ((1,3,2,4),)) # @size q (Q2,T2,H,B); @size v (V2,T1,H,B)
k = permutedims(k, (3,1,2,4)) # @size k (T1,K2,H,B)
# scores:
s = bmm(k,q) # @size s (T1,T2,H,B)
s = s * eltype(s)(l.scale) # @size s (T1,T2,H,B)
s = attnmask(s, keymask, l.selfmask) # @size s (T1,T2,H,B)
s = softmax(s, dims=1) # @size s (T1,T2,H,B)
s = dropout(s, l.dropout) # This is where all implementations put attention_dropout.
# context:
c = bmm(v,s) # @size c (V2,T2,H,B)
c = permutedims(c, (1,3,2,4)) # @size c (V2,H,T2,B)
c = reshape(c, :, size(c,3), size(c,4)) # @size c (O1,T2,B)
o = l.o(c) # @size o (O2,T2,B)
return o
end
function (l::MultiHeadAttention)(q::MaskedArray,k::MaskedArray,v::MaskedArray)
# We turn (Q,Tq,B),(K,Tk,B),(V,Tk,B) -> (V,Tq,B)
# The query mask is applied to the output mask; it does not affect the attention calculation.
# Only the key/value mask affects the inner score calculation.
# Target-time masking requires a (T,T) mask, which is not compatible with the input sizes and has to be generated inside.
@assert k.mask == v.mask
@assert size(q.mask,1) == 1
a = l(q.array, k.array, v.array, keymask = k.mask)
MaskedArray(a, q.mask)
end
(l::MultiHeadAttention)(x)=l(x,x,x)
(l::MultiHeadAttention)(y,x)=l(y,x,x)
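# A small shape-check sketch (not part of the model; the name and sample sizes are arbitrary):
# with the base configuration dmodel=512, nheads=8, attention maps a (dmodel,T,B) input back to
# a (dmodel,T,B) output, per the @size comments above. Assumes the same Knet setup the rest of
# this file needs.
function check_mha_shapes(; dmodel=512, nheads=8, T=10, B=2)
l = MultiHeadAttention(dmodel, nheads, 0)
x = atype()(randn(Float32, dmodel, T, B))     # a (dmodel,T,B) batch of activations
@assert size(l(x)) == (dmodel, T, B)          # self-attention form: q = k = v = x
@assert size(l(x, x)) == (dmodel, T, B)       # source-attention form: (query, memory)
end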
function attnmask(s, keymask, do_selfmask) # s=(Tk,Tq,H,B) keymask=(1,Tk,B) selfmask=(T,T,1,1)
mask = nothing
if keymask !== nothing
@assert size(keymask) == (1, size(s,1), size(s,4))
mask = reshape(keymask, size(s,1), 1, 1, size(s,4))
end
if do_selfmask
@assert size(s,1) == size(s,2)
T = size(s,1)
sm = [ key <= qry for key in 1:T, qry in 1:T ] # qry should see up to its own position but no further
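# e.g. for T=3: sm = [1 1 1; 0 1 1; 0 0 1] with keys along rows and queries along columns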
if mask === nothing
mask = reshape(sm, T, T, 1, 1)
else
mask = mask .& sm
end
end
if mask === nothing
return s
else
return s .+ oftype(s, -1e9 * .!mask)
end
end
## FeedForward
function FeedForward(dmodel, dff, dropout)
Chain(Linear(dmodel,dff), Relu(), Dropout(dropout), Linear(dff,dmodel))
end
## SubLayer
struct SubLayer; layer; norm; dropout; end
function SubLayer(layer, dmodel::Int, dropout::Number)
SubLayer(layer, LayerNorm(dmodel), Dropout(dropout))
end
# The paper suggests l.norm(x+l.dropout(l.layer(x))), however x + l.dropout(l.layer(l.norm(x)))
# is the default implementation in the code, see discussion on "LAYER NORMALIZATION" below.
function (l::SubLayer)(x, xs...)
x .+ l.dropout(l.layer(l.norm(x), xs...))
end
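# For comparison, a minimal sketch of the ordering suggested by the paper (post-norm, see the
# LAYER NORMALIZATION notes below). PostNormSubLayer is a hypothetical alternative and is not
# used by the layers above.
struct PostNormSubLayer; layer; norm; dropout; end
function PostNormSubLayer(layer, dmodel::Int, dropout::Number)
PostNormSubLayer(layer, LayerNorm(dmodel), Dropout(dropout))
end
function (l::PostNormSubLayer)(x, xs...)
l.norm(x .+ l.dropout(l.layer(x, xs...)))     # LayerNorm(x + Sublayer(x)) as in [1]
end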
## LayerNorm: https://arxiv.org/abs/1607.06450: Layer Normalization
# TODO: this is slow, need a kernel, maybe https://github.com/tensorflow/tensorflow/pull/6205/files
struct LayerNorm; a; b; ϵ; end
function LayerNorm(dmodel; eps=1e-6)
a = param(dmodel; init=ones)
b = param(dmodel; init=zeros)
LayerNorm(a, b, eps)
end
function (l::LayerNorm)(x, o...)
μ = mean(x,dims=1)
σ = std(x,mean=μ,dims=1)
ϵ = eltype(x)(l.ϵ)
l.a .* (x .- μ) ./ (σ .+ ϵ) .+ l.b # TODO: doing x .- μ twice?
end
function (l::LayerNorm)(x::MaskedArray, o...)
MaskedArray(l(x.array), x.mask) # TODO: shouldn't normalization ignore masked values?
end
## Dropout
struct Dropout; p; end
function (l::Dropout)(x)
dropout(x, l.p) # TODO: dropout normalization does not depend on masks?
end
## Relu
struct Relu end
function (l::Relu)(x)
relu.(x)
end
## Embed
struct Embed; w; end
function Embed(vocabsize,embedsize)
Embed(param(embedsize,vocabsize))
end
function (l::Embed)(x)
l.w[:,x]
end
function (l::Embed)(x::MaskedArray)
a = l(x.array)
m = (x.mask === nothing ? nothing : reshape(x.mask, 1, size(x.mask)...))
MaskedArray(a, m)
end
## PositionalEncoding
struct PositionalEncoding; w; end
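# Computes pe[2i+1,pos+1] = sin(pos/λ^(2i/dmodel)) and pe[2i+2,pos+1] = cos(pos/λ^(2i/dmodel))
# for pos = 0..maxlen-1, i.e. the sinusoidal encoding of [1] Section 3.5 with λ = 10000.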
function PositionalEncoding(dmodel, maxlen; λ=10000, atype=atype())
x = exp.((0:2:dmodel-1) .* -(log(λ)/dmodel)) * (0:maxlen-1)'
pe = zeros(dmodel, maxlen)
pe[1:2:end,:] = sin.(x)
pe[2:2:end,:] = cos.(x)
PositionalEncoding(atype(pe))
end
function (l::PositionalEncoding)(x)
x .+ l.w[:,1:size(x,2)]
end
## Linear: generalizes mmul to more than 2 dims: (A...,B) x (B,C...) => (A...,C...)
struct Linear; w; b; end
function Linear(input::Int,outputs...; bias=true)
Linear(param(outputs...,input),
bias ? param0(outputs...) : nothing)
end
function (l::Linear)(x)
W1,W2,X1,X2 = size(l.w)[1:end-1], size(l.w)[end], size(x,1), size(x)[2:end]; @assert W2===X1
y = reshape(l.w,:,W2) * reshape(x,X1,:)
y = reshape(y, W1..., X2...)
if l.b !== nothing; y = y .+ l.b; end
return y
end
function (l::Linear)(x::MaskedArray)
(a,m) = (x.array, x.mask)
@assert m===nothing || all(size(m,i) == 1 || size(m,i) == size(a,i) for i in 1:ndims(a))
if m === nothing
return MaskedArray(l(a), nothing)
elseif size(m,1) == 1 # avoid mask multiplication if possible
b = l(a)
if ndims(b) > ndims(m)
m = reshape(m, ntuple(i->1, ndims(b)-ndims(m))..., size(m)...)
end
return MaskedArray(b, m)
else
return MaskedArray(l(a .* oftype(a,m)), nothing)
end
end
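# A small shape-check sketch of the generalization described above (names and sizes arbitrary):
# a Linear(dmodel,dk,nheads) has a (dk,nheads,dmodel) weight, so applying it to a (dmodel,T,B)
# input yields a (dk,nheads,T,B) output, which is how the per-head projections in
# MultiHeadAttention are computed.
function check_linear_shapes(; dmodel=512, dk=64, nheads=8, T=10, B=2)
l = Linear(dmodel, dk, nheads)
x = atype()(randn(Float32, dmodel, T, B))
@assert size(l(x)) == (dk, nheads, T, B)
end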
## TODO:
# AVERAGING: For the base models, we used a single model obtained by averaging the last 5 checkpoints, which were written at 10-minute intervals. For the big models, we averaged the last 20 checkpoints.
# GENERATION: We used beam search with a beam size of 4 and length penalty α = 0.6 [38]. These hyperparameters were chosen after experimentation on the development set. We set the maximum output length during inference to input length + 50, but terminate early when possible [38].
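# Two sketches for the AVERAGING and GENERATION notes above; neither is used elsewhere in this
# file, and the function names and token-id arguments are hypothetical.
# Checkpoint averaging: assumes the caller has already loaded structurally identical models
# from the saved checkpoints (e.g. with Knet.load); the first model is overwritten in place.
using Knet: params, value
function average_checkpoints(models)
for ps in zip((params(m) for m in models)...)          # matching Params across checkpoints
copyto!(value(ps[1]), mean(value(p) for p in ps))
end
return models[1]
end
# Greedy decoding: a greedy stand-in for the beam search of [1], assuming src is a plain (T,B)
# matrix of token ids without padding and bosid/eosid are the sentence-start/end token ids.
# It recomputes the whole prefix at every step (no caching, see CACHING below).
function greedy_decode(t::Transformer, src; bosid, eosid, maxlen=size(src,1)+50)
enc = t.encoder(t.srcembed(src))
tgt = fill(bosid, 1, size(src,2))
for i in 1:maxlen
gen = t.generator(t.decoder(t.tgtembed(tgt), enc))     # (V,i,B) scores
scores = Array(gen)[:, end, :]                         # scores for the newest position
next = [argmax(scores[:, b]) for b in 1:size(src,2)]'  # best token id per sentence
tgt = vcat(tgt, next)
all(next .== eosid) && break
end
return tgt
end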
## CACHING: OpenNMT uses layer_cache to avoid extra computation:
# https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/multi_headed_attn.py
# Different optimizations for when attn_type == 'self' vs 'context'
## SHARING_EMBEDDINGS: In our model, we share the same weight matrix between the two embedding layers and the pre-softmax linear transformation, similar to (cite). In the embedding layers, we multiply those weights by √dmodel.
# https://arxiv.org/abs/1608.05859: Using the Output Embedding to Improve Language Models
# For now I will use 3 independent embedding layers, as does harvard, so there is no need for scaling with √dmodel.
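# A minimal sketch of that sharing (hypothetical names, not used above): a single (dmodel,vocab)
# weight serves as the embedding, scaled by √dmodel, and transposed as the pre-softmax generator;
# a sharing Transformer constructor would pass the same Param to both.
struct SharedEmbed; w; scale; end
SharedEmbed(vocabsize, dmodel) = SharedEmbed(param(dmodel, vocabsize), sqrt(dmodel))
(l::SharedEmbed)(x) = l.w[:,x] * eltype(l.w)(l.scale)   # (dmodel,T,B), scaled by √dmodel
struct SharedGenerator; w; end                          # holds the same Param as the SharedEmbed
function (l::SharedGenerator)(x)                        # (dmodel,T,B) -> (vocab,T,B)
y = permutedims(l.w, (2,1)) * reshape(x, size(x,1), :)
reshape(y, size(l.w,2), size(x)[2:end]...)
end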
## DROPOUT:
# [1] Residual Dropout: We apply dropout [33] to the output of each sub-layer, before it is added to the sub-layer input and normalized. In addition, we apply dropout to the sums of the embeddings and the positional encodings in both the encoder and decoder stacks. For the base model, we use a rate of Pdrop = 0.1.
# [2] SublayerConnection: x + self.dropout(sublayer(self.norm(x)))
# PositionalEncoding: after the sum, at the end
# MultiHeadedAttention: after softmax (frozen by bug)
# FeedForward: self.w_2(self.dropout(F.relu(self.w_1(x))))
# [3] https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/encoders/transformer.py (separate attn_dropout from dropout)
# after attention before residual
# [4] https://github.com/tensorflow/tensor2tensor/blob/43be271c8a3fa06cb06b5147f044cbdc8bb77535/tensor2tensor/layers/common_attention.py#L1588 #L1747 #L1934 #L2106 #L2266 (in attention after softmax)
# https://github.com/tensorflow/tensor2tensor/blob/43be271c8a3fa06cb06b5147f044cbdc8bb77535/tensor2tensor/layers/common_layers.py#L1358 (FeedForward between two layers, following relu)
# https://github.com/tensorflow/tensor2tensor/blob/dfcf88cb9b2ac695b1ca6be46b4ec29190d093b7/tensor2tensor/trax/models/transformer.py#L24 (FeedForward after both layers)
# https://github.com/tensorflow/tensor2tensor/blob/dfcf88cb9b2ac695b1ca6be46b4ec29190d093b7/tensor2tensor/trax/models/transformer.py#L57 (attention also has dropout at the end)
# https://github.com/tensorflow/tensor2tensor/blob/dfcf88cb9b2ac695b1ca6be46b4ec29190d093b7/tensor2tensor/trax/models/transformer.py#L98 (between embedding and positional encoding)
# [2],[3] have dropout after posemb, no normalization
# [2],[3] have no dropout before posemb, no dropout after final_norm in encoder or decoder, no dropout in generator
## LAYER NORMALIZATION:
# [1] https://arxiv.org/abs/1706.03762
# The output of each sub-layer is LayerNorm(x + Sublayer(x)). We apply dropout [33] to the output of each sub-layer, before it is added to the sub-layer input and normalized.
# [2] http://disq.us/p/1s2bpmf (discussion)
# https://github.com/harvardnlp/annotated-transformer/blob/master/The%20Annotated%20Transformer.ipynb
# x + dropout(sublayer(norm(x)))
# also additional LayerNorm at the end of encoder and decoder
# [3] https://github.com/OpenNMT/OpenNMT-py/issues/770 (discussion suggests this is better)
# https://github.com/OpenNMT/OpenNMT-py/blob/cd29c1dbfb35f4a2701ff52a1bf4e5bdcf02802e/onmt/decoders/transformer.py#L73
# https://github.com/OpenNMT/OpenNMT-py/blob/cd29c1dbfb35f4a2701ff52a1bf4e5bdcf02802e/onmt/encoders/transformer.py#L48
# x + dropout(sublayer(norm(x)))
# also additional LayerNorm at the end of encoder and decoder
# [4] https://github.com/tensorflow/tensor2tensor/blob/43be271c8a3fa06cb06b5147f044cbdc8bb77535/tensor2tensor/layers/common_layers.py#L862 (configurable implementation)
# https://github.com/tensorflow/tensor2tensor/blob/49e279eb6c871fbebc137d6f598758a275f521c3/tensor2tensor/layers/common_hparams.py#L110-L112 (note suggests layer+"dan" is published and default but "n"+layer+"da" is better.)
# https://github.com/tensorflow/tensor2tensor/blob/49e279eb6c871fbebc137d6f598758a275f521c3/tensor2tensor/models/transformer.py#L1133-L1134 (transformer_base_v2 config makes this the default)
## LABEL SMOOTHING:
# [1] Label Smoothing During training, we employed label smoothing of value ls = 0.1 [36]. This hurts perplexity, as the model learns to be more unsure, but improves accuracy and BLEU score.
# [2]
# [3]
# [4]
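# A minimal label-smoothing sketch for the ls = 0.1 of [1] above. It spreads ε of the target
# probability mass uniformly over the vocabulary and computes the mean smoothed cross entropy.
# For illustration only: unlike nll above it does not skip padded positions, smoothloss is a
# hypothetical name, and log.(softmax(...)) stands in for a numerically safer fused log-softmax.
function smoothloss(scores, labels; ε=0.1)
# scores: (V,N) unnormalized scores, labels: length-N vector of correct class indices
V, N = size(scores)
logp = log.(softmax(scores, dims=1))
correct = sum(logp[labels .+ V .* (0:N-1)])   # Σ_n log p(y_n) via linear indices
total = sum(logp)                             # Σ_n Σ_k log p(k)
-((1-ε) * correct + (ε/V) * total) / N        # -Σ_k q_k log p_k with q = (1-ε)⋅onehot + ε/V
end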
## MASKING:
# EncoderDecoder gets src_mask and tgt_mask
# Encoder gets src_mask
# Decoder gets both
# Encoder passes src_mask to EncoderLayer (not final norm)
# EncoderLayer passes it to self_attn (but not feedforw)
# Decoder passes both to DecoderLayer (not final norm)
# DecoderLayer passes src_mask to src_attn, tgt_mask to self_attn
# attention replaces masked positions with -1e9 before softmax.
# src_mask is constructed using only pads.
# tgt_mask is a combination of pads and subsequent_mask.
# src_mask = (src != SRC.stoi["<blank>"]).unsqueeze(-2)
# decoder takes tgt[1:end-1] (trg) as input and uses tgt[2:end] (trg_y) as output.
# how does loss_compute take tgt_mask into account?
# how does greedy_decode take the masks into account?
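# A minimal sketch of how a source-side mask could be built from a padded (T,B) batch of token
# ids, mirroring the src_mask construction quoted above. padid is hypothetical; the real masks
# come from mtdata.jl as part of MaskedArray, where true marks real tokens and false marks pads.
masked_batch(x, padid) = MaskedArray(x, any(x .== padid) ? (x .!= padid) : nothing)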