Skip to content

Instantly share code, notes, and snippets.

@lozhnikov
Last active August 20, 2020 10:25
Show Gist options
  • Save lozhnikov/aabb9231c0bb72528ff64a4f9bc19923 to your computer and use it in GitHub Desktop.
Save lozhnikov/aabb9231c0bb72528ff64a4f9bc19923 to your computer and use it in GitHub Desktop.
Transformer model draft
#include <cstddef>
#include <stdexcept>

#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
using namespace mlpack::ann;
/*
 * Builds one encoder block.
 *
 * Input:  the output of the positional encoding, or the output of the
 *         previous encoder block.
 * Output: one key/value matrix.
 *
 * Each of the two sub-layers is wrapped in a residual connection and
 * followed by a LayerNorm:
 *   x   = LayerNorm(input + SelfAttention(input))
 *   out = LayerNorm(x + PointWiseFFN(x))
 */
Sequential<>* CreateEncoderBlock() {
  Sequential<>* block = new Sequential<>();

  // Sub-layer 1: self-attention. The single input (query) is copied three
  // times so MultiheadAttention receives (query, key, value).
  Concat<>* qkv = new Concat<>();
  for (int copy = 0; copy < 3; ++copy)
    qkv->Add<IdentityLayer<>>();

  Sequential<>* attention = new Sequential<>();
  attention->Add(qkv);
  attention->Add<MultiheadAttention<>>();

  // Residual connection: attention(input) + input.
  AddMerge<>* attentionResidual = new AddMerge<>();
  attentionResidual->Add(attention);
  attentionResidual->Add<IdentityLayer<>>();

  block->Add(attentionResidual);
  block->Add<LayerNorm<>>();

  // Sub-layer 2: point-wise feed-forward network. It is still empty in this
  // draft. TODO: add dense layers with activations, e.g. feedForward->Add(...).
  Sequential<>* feedForward = new Sequential<>();

  // Residual connection: feedForward(x) + x.
  AddMerge<>* feedForwardResidual = new AddMerge<>();
  feedForwardResidual->Add(feedForward);
  feedForwardResidual->Add<IdentityLayer<>>();

  block->Add(feedForwardResidual);
  block->Add<LayerNorm<>>();

  return block;
}
/*
 * Builds one decoder block.
 *
 * Input: 2 matrices. The first one is the output of the encoder (key); the
 * second one is the output of the positional encoder or of the previous
 * decoder block.
 *
 * Each of the three sub-layers is wrapped in a residual connection and
 * followed by a LayerNorm, matching CreateEncoderBlock():
 *   query = LayerNorm(prev + SelfAttention(prev))
 *   x     = LayerNorm(query + EncoderDecoderAttention(query, key, key))
 *   out   = LayerNorm(x + PointWiseFFN(x))
 *
 * The Subview ranges are left unspecified in this draft (every Subview<> is
 * default-constructed); the comments state which matrix each one is meant to
 * extract.
 */
Sequential<>* CreateDecoderBlock() {
  /* Bottom half: self-attention over the previous decoder output. */
  Sequential<>* decoderBlockBottom = new Sequential<>();
  /* Extract the output of the previous decoder block or the output of the
   * positional encoder. */
  decoderBlockBottom->Add<Subview<>>();
  {
    /* Broadcast (query) into (query, key, value). */
    Concat<>* selfAttentionInput = new Concat<>();
    selfAttentionInput->Add<IdentityLayer<>>();
    selfAttentionInput->Add<IdentityLayer<>>();
    selfAttentionInput->Add<IdentityLayer<>>();
    /* Self attention layer.
     * NOTE(review): decoder self-attention usually needs a causal
     * (look-ahead) mask; confirm MultiheadAttention is configured with one. */
    Sequential<>* selfAttention = new Sequential<>();
    selfAttention->Add(selfAttentionInput);
    selfAttention->Add<MultiheadAttention<>>();
    /* Residual connection. */
    AddMerge<>* residualAddMerge = new AddMerge<>();
    residualAddMerge->Add(selfAttention);
    residualAddMerge->Add<IdentityLayer<>>();
    decoderBlockBottom->Add(residualAddMerge);
  }
  decoderBlockBottom->Add<LayerNorm<>>();

  Sequential<>* decoderBlock = new Sequential<>();
  {
    /* Run the bottom half on the block input (query) and extract the
     * encoder output (key) next to it, producing (query, key). The query is
     * kept as its own matrix so that the residual connection below can add
     * exactly the query rather than the whole block input. */
    Concat<>* queryAndKey = new Concat<>();
    queryAndKey->Add(decoderBlockBottom);
    queryAndKey->Add<Subview<>>();
    decoderBlock->Add(queryAndKey);
  }
  {
    /* Rearrange (query, key) into (query, key, value), where value is a
     * second copy of the encoder output. */
    Concat<>* encoderDecoderAttentionInput = new Concat<>();
    encoderDecoderAttentionInput->Add<Subview<>>();  // query
    encoderDecoderAttentionInput->Add<Subview<>>();  // key
    encoderDecoderAttentionInput->Add<Subview<>>();  // value (== key)
    /* Encoder-decoder attention. */
    Sequential<>* encoderDecoderAttention = new Sequential<>();
    encoderDecoderAttention->Add(encoderDecoderAttentionInput);
    encoderDecoderAttention->Add<MultiheadAttention<>>();
    /* Residual connection: add the attention output to the query only. An
     * IdentityLayer here would add the whole (query, key) input instead. */
    AddMerge<>* residualAddMerge = new AddMerge<>();
    residualAddMerge->Add(encoderDecoderAttention);
    residualAddMerge->Add<Subview<>>();  // query
    decoderBlock->Add(residualAddMerge);
  }
  /* Add & Norm after the encoder-decoder attention sub-layer. */
  decoderBlock->Add<LayerNorm<>>();
  {
    /* Point-wise feed-forward network. It is still empty in this draft.
     * TODO: add dense layers with activations,
     * e.g. pointWiseFeedForwardNetwork->Add(...). */
    Sequential<>* pointWiseFeedForwardNetwork = new Sequential<>();
    /* Residual connection. */
    AddMerge<>* residualAddMerge = new AddMerge<>();
    residualAddMerge->Add(pointWiseFeedForwardNetwork);
    residualAddMerge->Add<IdentityLayer<>>();
    decoderBlock->Add(residualAddMerge);
  }
  decoderBlock->Add<LayerNorm<>>();
  return decoderBlock;
}
/*
 * Assembles the full encoder-decoder Transformer draft.
 *
 * Layout:
 *   1. Concat(encoder stack, decoder embedding). The encoder stack ends by
 *      broadcasting its output numDecoders times, so the model input is turned
 *      into (key_1, ..., key_numDecoders, decoderInput).
 *   2. numDecoders decoder stages. Each stage consumes one key matrix plus
 *      the previous decoder output, and passes the remaining key matrices
 *      through.
 *   3. Linear3D + Softmax on top.
 *
 * The Subview ranges are left unspecified in this draft (all Subview<> are
 * default-constructed).
 *
 * @param numEncoders Number of stacked encoder blocks.
 * @param numDecoders Number of stacked decoder blocks. Must be at least 1,
 *     because the encoder output is broadcast once per decoder block.
 * @return The assembled (untrained) model.
 * @throws std::invalid_argument if numDecoders == 0.
 */
FFN<> CreateModel(const size_t numEncoders = 5, const size_t numDecoders = 5) {
  // Validate before allocating any layer so that nothing leaks on failure.
  if (numDecoders == 0)
    throw std::invalid_argument("CreateModel(): numDecoders must be positive");

  FFN<> model;

  Sequential<>* encoder = new Sequential<>();
  /* Extract the encoder input from dataset records. */
  encoder->Add<Subview<>>();
  encoder->Add<Lookup<>>();
  encoder->Add<PositionalEncoding<>>();
  for (size_t i = 0; i < numEncoders; ++i)
    encoder->Add(CreateEncoderBlock());
  {
    /* Broadcast the encoder output numDecoders times. Each decoder block
     * consumes its own copy (key). */
    Concat<>* broadcastEncoderOutput = new Concat<>();
    for (size_t i = 0; i < numDecoders; ++i)
      broadcastEncoderOutput->Add<IdentityLayer<>>();
    encoder->Add(broadcastEncoderOutput);
  }

  Sequential<>* decoder = new Sequential<>();
  /* Extract the decoder input from dataset records. */
  decoder->Add<Subview<>>();
  decoder->Add<Lookup<>>();
  decoder->Add<PositionalEncoding<>>();
  {
    /* Concatenate the encoder output (numDecoders key matrices) and the
     * embedded, positionally encoded decoder input. */
    Concat<>* encoderDecoderConcat = new Concat<>();
    encoderDecoderConcat->Add(encoder);
    encoderDecoderConcat->Add(decoder);
    model.Add(encoderDecoderConcat);
  }

  for (size_t i = 0; i < numDecoders; ++i) {
    /* Extract one key matrix and the output of the previous decoder stage
     * (or the decoder embedding, for the first stage). */
    Concat<>* decoderBlockInput = new Concat<>();
    decoderBlockInput->Add<Subview<>>();
    decoderBlockInput->Add<Subview<>>();
    /* The decoder block itself. */
    Sequential<>* decoderBlock = new Sequential<>();
    decoderBlock->Add(decoderBlockInput);
    decoderBlock->Add(CreateDecoderBlock());
    /* Extract the key matrices not consumed yet. Their count shrinks by one
     * per stage, since each decoder block takes one of them. */
    Subview<>* remainingEncoderOutputs = new Subview<>();
    /* Concatenate the unused key matrices and the decoder block output. */
    Concat<>* decoderBlockOutput = new Concat<>();
    decoderBlockOutput->Add(remainingEncoderOutputs);
    decoderBlockOutput->Add(decoderBlock);
    model.Add(decoderBlockOutput);
  }

  /* All the encoders and decoders are ready; add the output layers. */
  model.Add<Linear3D<>>();
  model.Add<Softmax<>>();
  return model;
}
int main() {
FFN<> model = CreateModel();
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment