Last active
August 20, 2020 10:25
-
-
Save lozhnikov/aabb9231c0bb72528ff64a4f9bc19923 to your computer and use it in GitHub Desktop.
Transformer model draft
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <mlpack/methods/ann/ffn.hpp> | |
#include <mlpack/methods/ann/layer/layer.hpp> | |
using namespace mlpack::ann; | |
/* This layer accepts the output of the positional encoding or the output | |
* of the previous encoder block as the input. | |
* The output is one key/value matrix. | |
*/ | |
Sequential<>* CreateEncoderBlock() { | |
Sequential<>* encoderBlock = new Sequential<>(); | |
{ | |
/* Broadcast (query) into (query, key, value). */ | |
Concat<>* selfAttentionInput = new Concat<>(); | |
selfAttentionInput->Add<IdentityLayer<>>(); | |
selfAttentionInput->Add<IdentityLayer<>>(); | |
selfAttentionInput->Add<IdentityLayer<>>(); | |
/* Self attention layer. */ | |
Sequential<>* selfAttention = new Sequential<>(); | |
selfAttention->Add(selfAttentionInput); | |
selfAttention->Add<MultiheadAttention<>>(); | |
/* This layer adds a residual connection. */ | |
AddMerge<>* residualAddMerge = new AddMerge<>(); | |
residualAddMerge->Add(selfAttention); | |
residualAddMerge->Add<IdentityLayer<>>(); | |
encoderBlock->Add(residualAddMerge); | |
} | |
encoderBlock->Add<LayerNorm<>>(); | |
{ | |
/* Some FFN. It could contain a number of dense layers with activations. */ | |
Sequential<>* pointWiseFeedForwardNetwork = new Sequential<>(); | |
// pointWiseFeedForwardNetwork->Add(......); | |
// pointWiseFeedForwardNetwork->Add(......); | |
/* This layer adds a residual connection. */ | |
AddMerge<>* residualAddMerge = new AddMerge<>(); | |
residualAddMerge->Add(pointWiseFeedForwardNetwork); | |
residualAddMerge->Add<IdentityLayer<>>(); | |
encoderBlock->Add(residualAddMerge); | |
} | |
encoderBlock->Add<LayerNorm<>>(); | |
return encoderBlock; | |
} | |
/* | |
* A decoder block. The input of the block consists of 2 matrices: | |
* the first one is the output of the encoder, and the second one is the output | |
* of the positional encoder or the previous decoder block. | |
*/ | |
Sequential<>* CreateDecoderBlock() { | |
Sequential<>* decoderBlockBottom = new Sequential<>(); | |
/* Extract the output of the previous decoder block or the output of the | |
* positional encoder. */ | |
decoderBlockBottom->Add<Subview<>>(); | |
{ | |
/* Broadcast (query) into (query, key, value). */ | |
Concat<>* selfAttentionInput = new Concat<>(); | |
selfAttentionInput->Add<IdentityLayer<>>(); | |
selfAttentionInput->Add<IdentityLayer<>>(); | |
selfAttentionInput->Add<IdentityLayer<>>(); | |
/* Self attention layer. */ | |
Sequential<>* selfAttention = new Sequential<>(); | |
selfAttention->Add(selfAttentionInput); | |
selfAttention->Add<MultiheadAttention<>>(); | |
/* Residual connection. */ | |
AddMerge<>* residualAddMerge = new AddMerge<>(); | |
residualAddMerge->Add(selfAttention); | |
residualAddMerge->Add<IdentityLayer<>>(); | |
decoderBlockBottom->Add(residualAddMerge); | |
} | |
decoderBlockBottom->Add<LayerNorm<>>(); | |
Sequential<>* decoderBlock = new Sequential<>(); | |
{ | |
/* This layer extracts the output of the decoder and broadcasts it | |
* two times i.e. it extracts (key) and broadcasts it into | |
* (key, value). */ | |
Concat<>* broadcastEncoderOutput = new Concat<>(); | |
broadcastEncoderOutput->Add<Subview<>>(); | |
broadcastEncoderOutput->Add<Subview<>>(); | |
/* This layer concatenates the output of the block bottom (query) | |
* and the output of the encoder (key, value). */ | |
Concat<>* encoderDecoderAttentionInput = new Concat<>(); | |
encoderDecoderAttentionInput->Add(decoderBlockBottom); | |
encoderDecoderAttentionInput->Add(broadcastEncoderOutput); | |
/* Encoder-decoder attention. */ | |
Sequential<>* encoderDecoderAttention = new Sequential<>(); | |
encoderDecoderAttention->Add(encoderDecoderAttentionInput); | |
encoderDecoderAttention->Add<MultiheadAttention<>>(); | |
/* Residual connection. */ | |
AddMerge<>* residualAddMerge = new AddMerge<>(); | |
residualAddMerge->Add(encoderDecoderAttention); | |
residualAddMerge->Add<IdentityLayer<>>(); | |
decoderBlock->Add(residualAddMerge); | |
} | |
{ | |
/* Some FFN network. */ | |
Sequential<>* pointWiseFeedForwardNetwork = new Sequential<>(); | |
// pointWiseFeedForwardNetwork->Add(......); | |
// pointWiseFeedForwardNetwork->Add(......); | |
/* Residual connection. */ | |
AddMerge<>* residualAddMerge = new AddMerge<>(); | |
residualAddMerge->Add(pointWiseFeedForwardNetwork); | |
residualAddMerge->Add<IdentityLayer<>>(); | |
decoderBlock->Add(residualAddMerge); | |
} | |
decoderBlock->Add<LayerNorm<>>(); | |
return decoderBlock; | |
} | |
/* Assembles the full draft model: an encoder stack, a decoder stack wired
 * to replicated encoder outputs, and a linear + softmax head on top. */
FFN<> CreateModel() {
  const int numEncoders = 5;
  const int numDecoders = 5;

  FFN<> model;

  /* --- Encoder side. --- */
  Sequential<>* encoder = new Sequential<>();
  /* Pull the encoder tokens out of the dataset record, embed them and
   * add positional information. */
  encoder->Add<Subview<>>();
  encoder->Add<Lookup<>>();
  encoder->Add<PositionalEncoding<>>();
  for (int block = 0; block < numEncoders; ++block)
    encoder->Add(CreateEncoderBlock());
  {
    /* Replicate the encoder output numDecoders times: each decoder block
     * consumes its own copy of the encoder output (key). */
    Concat<>* encoderFanOut = new Concat<>();
    for (int copy = 0; copy < numDecoders; ++copy)
      encoderFanOut->Add<IdentityLayer<>>();
    encoder->Add(encoderFanOut);
  }

  /* --- Decoder input pipeline. --- */
  Sequential<>* decoder = new Sequential<>();
  /* Pull the decoder tokens out of the dataset record, embed them and
   * add positional information. */
  decoder->Add<Subview<>>();
  decoder->Add<Lookup<>>();
  decoder->Add<PositionalEncoding<>>();

  /* Stack the replicated encoder outputs (numDecoders key matrices) on
   * top of the decoder input and feed the combined matrix to the model. */
  Concat<>* encoderDecoderJoin = new Concat<>();
  encoderDecoderJoin->Add(encoder);
  encoderDecoderJoin->Add(decoder);
  model.Add(encoderDecoderJoin);

  /* --- Decoder stack. --- */
  for (int block = 0; block < numDecoders; ++block) {
    /* Select one key matrix together with the output of the previous
     * decoder block. */
    Concat<>* blockInput = new Concat<>();
    blockInput->Add<Subview<>>();
    blockInput->Add<Subview<>>();

    /* The decoder block itself. */
    Sequential<>* decoderBlock = new Sequential<>();
    decoderBlock->Add(blockInput);
    decoderBlock->Add(CreateDecoderBlock());

    /* Keep the key matrices later blocks still need; one fewer remains
     * after every iteration since each block consumes one. */
    Subview<>* unusedKeys = new Subview<>();

    /* Block output = untouched key matrices followed by this block's
     * own output. */
    Concat<>* blockOutput = new Concat<>();
    blockOutput->Add(unusedKeys);
    blockOutput->Add(decoderBlock);
    model.Add(blockOutput);
  }

  /* All the encoders and decoders are ready; add the top layers. */
  model.Add<Linear3D<>>();
  model.Add<Softmax<>>();

  return model;
}
/* Entry point: instantiates the Transformer draft to verify that the layer
 * graph can be constructed; no training or inference is performed. */
int main() {
  FFN<> model = CreateModel();
  (void)model;  // Silence the unused-variable warning; construction is the point.
  return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment