
@euske
Last active June 4, 2023 05:36

Intuitive Explanation of Transformer (without math)

Summary: Transformer (as introduced in "Attention Is All You Need") is a complex model built upon several important ideas. In this article, we explain these ideas in terms of traditional programming concepts, without using math.

Prerequisites: a basic understanding of neural networks (NN), recurrent neural networks (RNN), and Python.

What's wrong with RNN?

Here's the general seq2seq translation framework:

"King Midas has donkey ears" --[encode]-->  [17 29 54] --[decode]--> "王様 の 耳 は ロバ の 耳"
                                          |<--fixed-->|

In a typical RNN, the input sequence is processed one element at a time. It is "compressed" (encoded) into an intermediate representation (a fixed-length vector), which is then decoded into the target language.

This model has one obvious bottleneck: the intermediate vector has a fixed length, so it can only hold so much information. Naturally, this model is known to struggle with long input sequences.
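To make the bottleneck concrete, here is a toy RNN-style encoder. It is a hypothetical sketch, not a real model: the weights in W are made up rather than learned. Each step mixes the next element into a state of fixed size, so however long the input is, the result is always the same four numbers.

```python
import math

# Made-up "weights" for illustration only; a real RNN learns these.
W = [0.1, -0.2, 0.3, -0.4]

def rnn_encode(tokens):
    """Fold a token sequence, one element at a time, into a fixed-size state."""
    state = [0.0] * len(W)
    for x in tokens:
        # Each step needs the previous state, so the steps must run in order.
        state = [math.tanh(0.5 * h + w * x) for h, w in zip(state, W)]
    return state

short = rnn_encode([17, 29, 54])
long_state = rnn_encode(list(range(1000)))
print(len(short), len(long_state))  # 4 4
```

Three elements and a thousand elements end up squeezed into vectors of exactly the same size, which is the compression problem described above.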

How is Transformer different?

Transformer fundamentally resolves the above problem by using an intermediate representation ("memory") that has the same length as the input sequence:

"King Midas has donkey ears"
   |    |    |    |     |  (encode)
   v    v    v    v     v
 [17   39   27    53   76] (memory)
|<- same length as input ->|
   |    |    |    |     |  (decode)
   v    v    v    v     v
  "王様 の 耳 は ロバ の 耳"

Note that the intermediate vector means something very different in RNN and Transformer. In RNN, information is accumulated in the vector as it goes through each element, so the longer the input sequence gets, the more compression occurs. This means that, in RNN, a vector from a short sequence can't be directly compared with a vector from a longer sequence. In Transformer, on the other hand, each element can carry an equal amount of information, and the representation of a short sequence naturally extends to a longer sequence. This is good for training.

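Furthermore, unlike RNN, which needs to process one element at a time, Transformer can process all the elements of a sequence in parallel (during training, and in the encoder during inference; the decoder still generates its output one element at a time, as the transformer() function in the next section shows). So intuitively it has a good chance (at least in theory) of beating traditional RNN models.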
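The key difference is a dependency: an RNN step needs the previous step's result, while a Transformer position within a layer does not need the output of any other position. The sketch below illustrates this with a hypothetical per-position encoder (encode_position and encode_all are names invented for this example). Each position may read the whole, fixed input, loosely like self-attention does, but never another position's output, so the positions can be handed to workers independently.

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def encode_position(i, seq):
    # Hypothetical per-position encoder: position i may read the whole input,
    # but it never needs the *output* computed for another position.
    own = seq[i]
    context = sum(seq) - own  # a crude summary of "everything else"
    return (own, context)

def encode_all(seq):
    # No position waits for another, so they can be computed concurrently.
    # Executor.map still returns the results in input order.
    with ThreadPoolExecutor() as pool:
        return list(pool.map(partial(encode_position, seq=seq), range(len(seq))))

memory = encode_all([17, 29, 54])
print(memory)  # [(17, 83), (29, 71), (54, 46)]
```

Python threads won't make this toy CPU-bound work faster; the point is only that the positions are independent, which is what lets real implementations compute all of them at once with matrix operations. Note also that the memory has one entry per input element.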

How does Transformer work?

Now let's take a look at how the algorithm works overall:

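def transformer(src):
    memory = encoder(src)            # encode the whole input sequence at once
    output = [BOS]                   # start with the begin-of-sequence marker
    while True:
        # generate the next element from the memory and everything generated so far
        output.append( decoder(memory, output) )
        if output[-1] == EOS: break  # stop at the end-of-sequence marker
    return output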

Seems pretty straightforward, right? First, the input sequence is converted to "memory" (whatever it is). Then, the output sequence is generated one element at a time.
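One practical detail the loop above glosses over: if the decoder never emits EOS, the loop runs forever. A common safeguard is to cap the output length. Below is a sketch of that idea; decode_with_limit and max_len are names made up for this example, and the decoder is passed in as a parameter so the loop can be tried with simple stand-in functions instead of a learned model.

```python
def decode_with_limit(decoder, memory, bos, eos, max_len=50):
    """Generate one element at a time until EOS, or until max_len elements exist."""
    output = [bos]
    while len(output) < max_len:
        output.append(decoder(memory, output))
        if output[-1] == eos:
            break
    return output

BOS, EOS = 0, 999

# Stand-in decoders for illustration only (not learned models).
def count_to_3(memory, out):
    return out[-1] + 1 if out[-1] < 3 else EOS

def never_stops(memory, out):
    return 7

print(decode_with_limit(count_to_3, None, BOS, EOS))              # [0, 1, 2, 3, 999]
print(decode_with_limit(never_stops, None, BOS, EOS, max_len=5))  # [0, 7, 7, 7, 7]
```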

Now, what is this "memory" thing?

Intuitively speaking, this is an associative array, hash table, or Python dictionary. The encoder constructs the dictionary, and the decoder looks it up.

Here's the intuitive Python version:

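def encoder(src):
    # build two lookup tables ("heads") from the input elements
    h1 = { key1(x):value1(x) for x in src }
    h2 = { key2(x):value2(x) for x in src }
    memory = (h1,h2)
    return memory

def decoder(memory, target):
    (h1,h2) = memory
    # look up each table with queries computed from the output so far
    # (.get() returns None when there is no matching key)
    v1 = [ h1.get(query1(x)) for x in target ]
    v2 = [ h2.get(query2(x)) for x in target ]
    # combine the lookup results into the next output element
    return ff(v1,v2)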

Here, key1(), value1(), key2(), value2(), query1(), query2() and ff() are all learnable functions. Note that the real Transformer does not use a dictionary, because dictionary lookups are not differentiable, and differentiability is a vital requirement for any neural network model. Instead, the real Transformer uses vectors and matrices to perform a comparable operation.
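Here is a rough idea of that "comparable operation", sketched in plain Python under simplifying assumptions: keys, values and the query are already vectors (in the paper they come from learned projections), and there is a single query. The hard lookup is replaced by scaled dot-product attention: the query is compared with every key, the scores become weights via a softmax, and the result is a weighted average of the values. soft_get is a name made up here, and the numbers are invented for illustration.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max so math.exp can't overflow
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def soft_get(query, keys, values):
    """A differentiable stand-in for dict.get(query).

    Every value contributes, weighted by how well its key matches the query
    (dot product scaled by sqrt(d), then softmax), so there is no hard miss.
    """
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

# Three made-up keys and the values stored under them.
keys = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
values = [[1.0], [2.0], [3.0]]

print(soft_get([10.0, 0.0, 0.0], keys, values))   # close to [1.0]: one key matches
print(soft_get([10.0, 10.0, 0.0], keys, values))  # close to [1.5]: a blend of two
```

Where dict.get either finds a key or returns None, this version always returns a blend that leans toward the best-matching keys. Because every step is smooth arithmetic, gradients can flow back into whatever produced the keys, values and queries.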

Also note that the above encoder/decoder uses only two dictionaries. In the original paper, they used eight dictionary-like things (called "heads" in the paper) but for simplicity we just use two of them here.

Here is a simple model that uses the above framework. For an input of consecutive integers starting from 1, like the one below, it simply returns a sequence identical to the input:

BOS = 0
EOS = 999

# Let's say these are learned functions.
def key1(x): return x
def value1(x): return x
def key2(x): return x
def value2(x): return 1
def query1(x): return x
def query2(x): return x+1
def ff(v1,v2):
    x1 = v1[-1]
    x2 = v2[-1]
    if x2 is None:
        return EOS
    else:
        return x1+1

print(transformer([BOS,1,2,3,4,5,EOS])) # [0, 1, 2, 3, 4, 5, 999], i.e. [BOS,1,2,3,4,5,EOS]

This model is still rather limited because it can only perform simple element-wise operations. Specifically, it can't take the following information into account, which is crucial for most sequence problems:

  • a. Relationships between multiple elements.
  • b. Order of each element.

Considering the relationship of multiple elements (Self-Attention)

Now let's take the first problem. How can Transformer gather information from multiple elements? It uses a mechanism called "Self-Attention". This is the key idea of the paper (hence the title) and one of the biggest breakthroughs in recent NN research.

However, I find the term "Self-Attention" rather confusing. To me, it could be more aptly called "inner relationship", because what it does is figure out the relationships between the elements within an input sequence. Something like this:

             +---object--+
             |           |
  +-subject--+           |
  |          |           |
  +-adj-+    |     +-adj-+
  |     |    |     |     |
"King Midas has donkey ears"

Each relationship consists of a source, a target and the type of relationship, so together they form a kind of graph. What's great about Transformer is that it can discover these kinds of relationships automatically! Note that the actual relationships ("attentions") in Transformer aren't this clear-cut, and we often don't know what kind of relationship is actually expressed, but it's nonetheless "some" relationship. This is analogous to the visual features learned in CNN layers: it's hard to know what kind of features are actually captured at each Conv layer, but they're nonetheless some features!
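As a small illustration of "source, target and type", the relationships in the sentence diagram can be written down as graph edges. This is hand-written data for illustration, not something a model produced. The subject edge is drawn from the "King Midas" group, and grouping the edges by type mirrors, in an idealized way, the idea that each dictionary ("head") tracks one type of relationship.

```python
from collections import defaultdict

# The relationships from the sentence diagram as (source, target, type) edges.
edges = [
    ("King", "Midas", "adj"),
    ("King Midas", "has", "subject"),
    ("has", "ears", "object"),
    ("donkey", "ears", "adj"),
]

# Adjacency list: source -> [(target, type), ...]
graph = defaultdict(list)
for src, dst, kind in edges:
    graph[src].append((dst, kind))

# Group the edges by relationship type (roughly one "head" per type).
by_type = defaultdict(list)
for src, dst, kind in edges:
    by_type[kind].append((src, dst))

print(graph["has"])    # [('ears', 'object')]
print(by_type["adj"])  # [('King', 'Midas'), ('donkey', 'ears')]
```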

Here's a Python version of this relationship-building process:

def self_attn(seq):
    h1 = { sa_key1(x):sa_value1(x) for x in seq }
    h2 = { sa_key2(x):sa_value2(x) for x in seq }
    a1 = [ h1.get(sa_query1(x)) for x in seq ]
    a2 = [ h2.get(sa_query2(x)) for x in seq ]
    return [ aa(y1,y2) for (y1,y2) in zip(a1,a2) ]

Here, "seq" is the input sequence in question. sa_key1(), sa_value1(), sa_key2(), sa_value2(), sa_query1(), sa_query2() and aa() are all learnable functions. We first create two dictionaries, h1 and h2, which correspond to two types of "relationships" (again, these are called "heads"). Each key/value pair in a dictionary is computed from an input element. Then, we create two sequences by looking up each dictionary with queries computed from the same input sequence. This effectively compares every input element with every other one. Finally, we combine these two sequences element by element to produce the output sequence, which reflects both types of relationship. In the example below, each element of the output sequence is 1 if the input sequence contains half or double of that element's value, and 0 otherwise.

def sa_key1(w): return w
def sa_value1(w): return w
def sa_key2(w): return w
def sa_value2(w): return w
def sa_query1(w): return w*2
def sa_query2(w): return w/2
def aa(v1,v2):
    if v1 is None and v2 is None: return 0
    return 1

print(self_attn([BOS,1,2,3,4,5,8,EOS])) # [1,1,1,0,1,0,1,0]

In Transformer, eight types of information are carried separately in the output vector, so it can hold multiple relationships at once (this is called "Multi-Head Attention"). Again, the real implementation does not use a dictionary; it does a comparable thing by computing dot products and a softmax. Furthermore, the paper stacks six of these layers. This lets the model use shallow relationships to discover deeper, more complex relationships between elements. Self-Attention is really the key to handling the complexity of input sequences.
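For the curious, here is a plain-Python sketch of multi-head self-attention built from dot products and a softmax, under heavy simplifying assumptions: two heads instead of eight, tiny hand-picked projection matrices (HEADS is invented for this example) instead of learned ones, and no output projection, residual connections, layer normalization or feed-forward sublayer, all of which the paper's layers include. Calling the function again on its own output mimics stacking layers.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def project(W, x):
    # Matrix-vector product: each row of W produces one output number.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attention_head(seq, Wq, Wk, Wv):
    """One head: every element queries every element of the same sequence."""
    qs = [project(Wq, x) for x in seq]
    ks = [project(Wk, x) for x in seq]
    vs = [project(Wv, x) for x in seq]
    scale = math.sqrt(len(ks[0]))
    out = []
    for q in qs:
        weights = softmax([sum(a * b for a, b in zip(q, k)) / scale for k in ks])
        out.append([sum(w * v[j] for w, v in zip(weights, vs)) for j in range(len(vs[0]))])
    return out

def multi_head_self_attention(seq, heads):
    """Run every head on the same input and concatenate their outputs per element."""
    per_head = [attention_head(seq, Wq, Wk, Wv) for (Wq, Wk, Wv) in heads]
    return [[x for head_out in per_head for x in head_out[i]] for i in range(len(seq))]

# Two made-up heads over 2-dimensional elements: head 1 only looks at the
# first coordinate, head 2 only at the second (real heads learn their weights).
HEADS = [
    ([[1.0, 0.0]], [[1.0, 0.0]], [[1.0, 0.0]]),
    ([[0.0, 1.0]], [[0.0, 1.0]], [[0.0, 1.0]]),
]

seq = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
layer1 = multi_head_self_attention(seq, HEADS)
layer2 = multi_head_self_attention(layer1, HEADS)  # "stacking": feed the output back in
print(len(layer2), [len(v) for v in layer2])  # 3 [2, 2, 2]
```

Each output element is a mix of information from every input element, one mix per head, and the output keeps the input's shape, which is what makes stacking possible.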

Considering the element order (Positional Encoding)

The second problem we need to address is the element order. There are a number of ways to handle it. The most straightforward way is to add to each element an extra field containing its position number. However, this increases the amount of information in each input and adds an extra burden to the network. A better way is to superimpose some kind of "watermark" on each element so that you can tell whether an element comes before or after another element, and that's what the Transformer paper does. This method is called "Positional Encoding".

Here's a very simplistic example of positional encoding:

def add_positional(seq):
    return [ i*1000+x for (i, x) in enumerate(seq) ]

print(add_positional([BOS,2,5,7,9,20,EOS])) # [0, 1002, 2005, 3007, 4009, 5020, 6999]

This simple implementation has a shortcoming: once encoded values are added together, the system can't always tell them apart (e.g. 1002 + 2005 = 3007, which is indistinguishable from the encoding of 7 at position 3). The real Positional Encoding is a little more sophisticated, so this kind of collision is much less of a problem, but the basic idea is the same.
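For reference, the paper's positional encoding uses sine and cosine waves of different frequencies, PE(pos, 2k) = sin(pos / 10000^(2k/d_model)) and PE(pos, 2k+1) = cos(pos / 10000^(2k/d_model)), and the resulting vectors are added to the element vectors, which is the "watermark" idea above. The following is a plain-Python sketch of that formula; the function names are made up here, and real implementations compute the same table with array operations.

```python
import math

def sinusoidal_positional_encoding(num_positions, d_model):
    """Sinusoidal position "watermarks": sin on even dimensions, cos on odd ones."""
    table = []
    for pos in range(num_positions):
        row = []
        for i in range(d_model):
            k2 = i - i % 2  # dimensions 2k and 2k+1 share one frequency
            angle = pos / (10000 ** (k2 / d_model))
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table

def add_positional_vectors(embeddings):
    """Superimpose (add) the position watermark onto each element's vector."""
    pe = sinusoidal_positional_encoding(len(embeddings), len(embeddings[0]))
    return [[e + p for e, p in zip(emb, pos_vec)] for emb, pos_vec in zip(embeddings, pe)]

pe = sinusoidal_positional_encoding(4, 4)
print(pe[0])  # [0.0, 1.0, 0.0, 1.0]: sin(0) and cos(0) in every pair

# Two identical elements at different positions now get different vectors.
twins = add_positional_vectors([[0.5] * 4, [0.5] * 4])
print(twins[0] != twins[1])  # True
```

Unlike the i*1000 trick, every watermark value stays between -1 and 1, so the position never overwhelms the element's own information.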

Combining this with Self-Attention, we can now detect if the same element appears twice consecutively:

def sa_key1(w): return w // 1000
def sa_value1(w): return w % 1000
def sa_key2(w): return w // 1000
def sa_value2(w): return w % 1000
def sa_query1(w): return w // 1000
def sa_query2(w): return (w // 1000)-1
def aa(v1,v2):
    if v1 != v2: return 0
    return 1

print(self_attn(add_positional([BOS,1,1,5,5,2,EOS]))) # [0, 0, 1, 0, 1, 0, 0]

And that's it! Now you understand the basic mechanism of Transformer without any math.
