Notes about attention and the Transformer

Transformer notes

  • current models have trouble learning dependencies between distant positions (i.e. between characters/words); the number of operations needed to relate two positions grows with the distance between them, O(n) or O(log n).

  • in the Transformer, the number of operations needed to relate any two positions is constant, O(1).

  • encoder-decoder architecture with residual connections; the encoder and decoder each stack N identical layers (a minimal residual-stacking sketch follows the mask example below).

  • We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This masking, combined with the fact that the output embeddings are offset by one position, **ensures that the predictions for position i can depend only on the known outputs at positions less than i**.

import numpy as np
import torch

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    # Ones strictly above the diagonal mark the future positions each query must not see.
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    # Invert: 1 where attention is allowed, 0 where it is masked.
    return torch.from_numpy(subsequent_mask) == 0

Calling subsequent_mask(10) returns:

tensor([[[ 1,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 1,  1,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 1,  1,  1,  0,  0,  0,  0,  0,  0,  0],
         [ 1,  1,  1,  1,  0,  0,  0,  0,  0,  0],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0],
         [ 1,  1,  1,  1,  1,  1,  0,  0,  0,  0],
         [ 1,  1,  1,  1,  1,  1,  1,  0,  0,  0],
         [ 1,  1,  1,  1,  1,  1,  1,  1,  0,  0],
         [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  0],
         [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1]]], dtype=torch.uint8)
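
As referenced in the architecture bullet above, here is a minimal sketch of the residual stacking idea; the helper names clones and ResidualSublayer are illustrative assumptions, not code from these notes. Each sub-layer is wrapped in a residual connection followed by layer normalization, and the resulting layer is cloned N times to form the encoder or decoder stack.

import copy
import torch.nn as nn

def clones(module, N):
    "Produce N identical (independently parameterized) copies of a layer."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class ResidualSublayer(nn.Module):
    "Residual connection followed by layer norm: LayerNorm(x + dropout(sublayer(x)))."
    def __init__(self, size, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # Add the sub-layer output back onto its input, then normalize.
        return self.norm(x + self.dropout(sublayer(x)))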

Attention

the attention inputs are represented as vectors, but packed into matrices for batched computation:

query (Q)
keys (K)
values (V)

scaled dot-product attention: Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V. The dot products are divided by sqrt(d_k) because, for large d_k, the unscaled dot products have mean 0 and variance d_k, pushing the softmax into regions with very small gradients.
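
As a quick numerical illustration of the variance argument (a sketch, not part of the original notes): dot products of unit-variance random vectors have variance roughly d_k, and scaling by sqrt(d_k) brings it back to roughly 1.

import torch

d_k = 512
q = torch.randn(10000, d_k)
k = torch.randn(10000, d_k)

dots = (q * k).sum(dim=-1)           # unscaled dot products
print(dots.var())                    # roughly d_k (~512)
print((dots / d_k ** 0.5).var())     # roughly 1 after scaling by sqrt(d_k)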

import math

import torch
import torch.nn.functional as F

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"

    d_k = query.size(-1)
    # Dot products between queries and keys, scaled by sqrt(d_k).
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Masked positions get a large negative score so softmax assigns them ~0 weight.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    p_attn = F.softmax(scores, dim=-1)

    if dropout is not None:
        p_attn = dropout(p_attn)

    # Weighted sum of the values, plus the attention weights for inspection.
    return torch.matmul(p_attn, value), p_attn
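
A quick usage check combining attention with the mask above (the batch size and dimensions are illustrative, not from the original notes):

import torch

batch, seq_len, d_k = 2, 10, 64
q = torch.randn(batch, seq_len, d_k)
k = torch.randn(batch, seq_len, d_k)
v = torch.randn(batch, seq_len, d_k)

mask = subsequent_mask(seq_len)      # (1, 10, 10), broadcasts over the batch
out, weights = attention(q, k, v, mask=mask)
print(out.shape)        # torch.Size([2, 10, 64])
print(weights.shape)    # torch.Size([2, 10, 10])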

multi-head attention:

head_i = attention(Q * W_i^{Q}, K * W_i^{K}, V * W_i^{V})
multihead_attention = concat(head_0, ..., head_{n-1}) * W^{O}

the per-head projection size shrinks with the number of heads (d_k = d_v = d_model / h), so the total computational cost of multi-head attention stays within a small constant of single-head attention with full dimensionality.
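
A minimal sketch of a multi-head attention module matching the formulas above; it reuses the attention() function defined earlier, and the class name, default dropout, and packing of all heads into single linear layers are illustrative choices, not a definitive implementation.

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    "Project Q, K, V into h heads, run scaled dot-product attention per head, concat, project."
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h                    # per-head size, d_model / h
        self.h = h
        self.w_q = nn.Linear(d_model, d_model)     # packs all W_i^{Q}
        self.w_k = nn.Linear(d_model, d_model)     # packs all W_i^{K}
        self.w_v = nn.Linear(d_model, d_model)     # packs all W_i^{V}
        self.w_o = nn.Linear(d_model, d_model)     # W^{O}
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)               # same mask for every head
        nbatches = query.size(0)

        # (batch, seq, d_model) -> (batch, h, seq, d_k)
        def split_heads(x):
            return x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)

        q = split_heads(self.w_q(query))
        k = split_heads(self.w_k(key))
        v = split_heads(self.w_v(value))

        # Scaled dot-product attention applied to every head in parallel.
        x, _ = attention(q, k, v, mask=mask, dropout=self.dropout)

        # Concatenate heads back: (batch, h, seq, d_k) -> (batch, seq, d_model)
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.w_o(x)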
