@sevagh
Created November 28, 2023 18:06
multihead attention in Eigen/C++
void demucscpp::common_encoder_layer(
Eigen::Tensor3dXf &q, // q = x = frequency
const Eigen::Tensor3dXf &k, // k = xt = time
const Eigen::Tensor1dXf &norm1_weight, const Eigen::Tensor1dXf &norm1_bias,
const Eigen::Tensor1dXf &norm2_weight, const Eigen::Tensor1dXf &norm2_bias,
const Eigen::MatrixXf &in_proj_weight, const Eigen::VectorXf &in_proj_bias,
const Eigen::MatrixXf &out_proj_weight,
const Eigen::VectorXf &out_proj_bias, const Eigen::VectorXf &gamma1_scale,
const Eigen::Tensor1dXf &norm3_weight, const Eigen::Tensor1dXf &norm3_bias,
const Eigen::MatrixXf &linear1_weight, const Eigen::VectorXf &linear1_bias,
const Eigen::MatrixXf &linear2_weight, const Eigen::VectorXf &linear2_bias,
const Eigen::VectorXf &gamma2_scale,
const Eigen::Tensor1dXf &norm_out_weight,
const Eigen::Tensor1dXf &norm_out_bias, float eps, const int num_heads)
{
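// One pre-norm transformer encoder layer, shared by the self- and
// cross-attention paths:
//   x = x + gamma_1 * MHA(norm1(x), norm2(k))   (q == k for self-attention)
//   x = x + gamma_2 * Linear2(GELU(Linear1(norm3(x))))
//   x = group_norm(x)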
// Normalize q (frequency) with the norm1 weights and k (time) with norm2
Eigen::Tensor3dXf q_norm =
demucscpp::layer_norm(q, norm1_weight, norm1_bias, eps);
Eigen::Tensor3dXf k_norm =
demucscpp::layer_norm(k, norm2_weight, norm2_bias, eps);
// Cross-attention block
// Compute Q, K, V matrices
int B = q.dimension(0);
int T = q.dimension(1);
int C = q.dimension(2);
int B_k = k.dimension(0);
int S = k.dimension(1);
int C_k = k.dimension(2);
demucscppdebug::assert_(B == B_k);
demucscppdebug::assert_(B == 1);
demucscppdebug::assert_(C == C_k);
// Reshape q and k to 2D matrices of shape (T, C) and (S, C) (B == 1)
// Use Eigen::Map to avoid manual loops for reshaping
Eigen::MatrixXf q_norm_2d =
Eigen::Map<const Eigen::MatrixXf>(q_norm.data(), T, C);
Eigen::MatrixXf k_norm_2d =
Eigen::Map<const Eigen::MatrixXf>(k_norm.data(), S, C);
// Compute Q, K, V matrices
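// in_proj_weight follows PyTorch nn.MultiheadAttention's packed layout:
// a (3C, C) matrix whose row blocks are [W_q; W_k; W_v], with in_proj_bias
// packed the same way -- hence the block()/segment() slicing below.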
Eigen::MatrixXf Q =
q_norm_2d * in_proj_weight.block(0, 0, C, C).transpose();
Eigen::MatrixXf K =
k_norm_2d * in_proj_weight.block(C, 0, C, C).transpose();
Eigen::MatrixXf V =
k_norm_2d * in_proj_weight.block(2 * C, 0, C, C).transpose();
Eigen::VectorXf q_bias = in_proj_bias.segment(0, C);
Eigen::VectorXf k_bias = in_proj_bias.segment(C, C);
Eigen::VectorXf v_bias = in_proj_bias.segment(2 * C, C);
int head_split = C / num_heads;
Eigen::Tensor3dXf Q_heads(T, num_heads, head_split);
Eigen::Tensor3dXf K_heads(S, num_heads, head_split);
Eigen::MatrixXf V_heads_2d(S * num_heads, head_split);
// Split Q, K, V into heads and add the projection biases; the loop order
// (d, h outermost) lets the K and V fills share the inner loop over S
for (int d = 0; d < head_split; ++d)
{
for (int h = 0; h < num_heads; ++h)
{
for (int t = 0; t < T; ++t)
{
Q_heads(t, h, d) =
Q(t, h * head_split + d) + q_bias(h * head_split + d);
}
for (int s = 0; s < S; ++s)
{
K_heads(s, h, d) =
K(s, h * head_split + d) + k_bias(h * head_split + d);
V_heads_2d(s * num_heads + h, d) =
V(s, h * head_split + d) + v_bias(h * head_split + d);
}
}
}
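// Q_heads and K_heads now hold the bias-added projections split into
// num_heads heads of head_split = C / num_heads channels each; V stays in
// a 2D (S * num_heads, head_split) layout until the per-head copy below.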
// Compute cross-attention scores
Eigen::MatrixXf scores(num_heads * T, S); // per-head blocks filled below
for (int h = 0; h < num_heads; ++h)
{
// Extract the h-th head from Q_heads and K_heads
Eigen::Tensor<float, 2> Q_head_tensor = Q_heads.chip(h, 1);
Eigen::Tensor<float, 2> K_head_tensor = K_heads.chip(h, 1);
// Reshape the tensors to matrices
Eigen::Map<Eigen::MatrixXf> Q_head(Q_head_tensor.data(), T, head_split);
Eigen::Map<Eigen::MatrixXf> K_head(K_head_tensor.data(), S, head_split);
// Compute the dot product of Q_head and K_head
Eigen::MatrixXf dot_product = Q_head * K_head.transpose();
// Store the result in scores
scores.block(h * T, 0, T, S) = dot_product / std::sqrt((float)head_split);
}
// Apply softmax to scores
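// Softmax runs row-wise over the S key positions. Subtracting each row's
// max before exponentiating is the standard numerical-stability trick,
// since softmax(x) == softmax(x - max(x)).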
Eigen::ArrayXf max_vals = scores.rowwise().maxCoeff();
Eigen::MatrixXf max_vals_expanded = max_vals.replicate(1, scores.cols());
scores = (scores - max_vals_expanded).array().exp().matrix();
Eigen::VectorXf row_sums = scores.rowwise().sum();
Eigen::MatrixXf divisor = row_sums.replicate(1, scores.cols());
scores = (scores.array() / divisor.array()).matrix();
// Compute cross-attention output
std::vector<Eigen::MatrixXf> cross_attn_out_3d;
std::vector<Eigen::MatrixXf> V_heads_3d;
std::vector<Eigen::MatrixXf> scores_3d;
for (int h = 0; h < num_heads; ++h)
{
V_heads_3d.push_back(Eigen::MatrixXf(S, head_split));
scores_3d.push_back(Eigen::MatrixXf(T, S));
cross_attn_out_3d.push_back(Eigen::MatrixXf(T, head_split));
}
// first copy V_heads_2d and scores into the per-head matrices
for (int h = 0; h < num_heads; ++h)
{
for (int s = 0; s < S; ++s)
{
for (int t = 0; t < T; ++t)
{
scores_3d[h](t, s) = scores(h * T + t, s);
}
for (int d = 0; d < head_split; ++d)
{
V_heads_3d[h](s, d) = V_heads_2d(s * num_heads + h, d);
}
}
}
// now loop over the heads and do the inner matmuls, assigning
// results to cross_attn_out_3d
for (int h = 0; h < num_heads; ++h)
{
cross_attn_out_3d[h] = scores_3d[h] * V_heads_3d[h];
}
Eigen::MatrixXf cross_attn_out(T, C);
// now copy cross_attn_out_3d into cross_attn_out,
// from shape (num_heads, T, head_split) to (T, C)
for (int t = 0; t < T; ++t)
{
for (int c = 0; c < C; ++c)
{
int h = c / head_split;
int k_ = c % head_split;
cross_attn_out(t, c) = cross_attn_out_3d[h](t, k_);
}
}
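// This interleaved copy is the "concat heads" step: head h's (T, head_split)
// output lands in columns [h * head_split, (h + 1) * head_split) of the
// (T, C) matrix.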
// Apply output projection
Eigen::MatrixXf out_proj = cross_attn_out * out_proj_weight.transpose();
out_proj.array().rowwise() += out_proj_bias.transpose().array();
// now we need x = q + gamma_1 * out_proj; store the result in the 3D q in place
for (int t = 0; t < T; ++t)
{
for (int c = 0; c < C; ++c)
{
q(0, t, c) += out_proj(t, c) * gamma1_scale(c);
}
}
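// gamma_1 here (and gamma_2 below) act as per-channel LayerScale-style
// parameters: the residual branch is scaled elementwise before being added
// back to the skip connection.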
// copy q into q_2d
Eigen::MatrixXf q_2d(T, C);
q_2d = Eigen::Map<const Eigen::MatrixXf>(q.data(), T, C);
// before feedforward, apply norm3 to x i.e. q
q_norm = demucscpp::layer_norm(q, norm3_weight, norm3_bias, eps);
q_norm_2d = Eigen::Map<const Eigen::MatrixXf>(q_norm.data(), T, C);
// Feedforward block
// Linear layer 1
Eigen::MatrixXf ff1 = q_norm_2d * linear1_weight.transpose();
ff1.rowwise() += linear1_bias.transpose();
ff1 = demucscpp::gelu(ff1);
// Linear layer 2
Eigen::MatrixXf ff2 = ff1 * linear2_weight.transpose();
ff2.rowwise() += linear2_bias.transpose();
// Apply gamma_2 scale directly on 2D matrix
ff2 = ff2.array().rowwise() * gamma2_scale.transpose().array();
// now x = x + self.gamma_2(self._ff_block(self.norm3(x)))
q_2d += ff2;
// Map the 2D data back into a 3D tensor with dimensions (T, B, C)
q = Eigen::TensorMap<Eigen::Tensor3dXf>(q_2d.data(), T, B, C);
// Swap the first and last dimensions to get a tensor with dimensions (B, C,
// T)
Eigen::array<int, 3> permute_dims = {1, 2, 0};
Eigen::Tensor3dXf q_shuf = q.shuffle(permute_dims);
// Normalize the output with norm_out/MyGroupNorm
q = demucscpp::group_norm(q_shuf, norm_out_weight, norm_out_bias, 1, eps);
Eigen::array<int, 3> permute_dims_2 = {0, 2, 1};
q_shuf = q.shuffle(permute_dims_2);
q = q_shuf;
}
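common_encoder_layer leans on a few helpers defined elsewhere in demucs.cpp (demucscpp::layer_norm, demucscpp::gelu, demucscpp::group_norm, plus the Eigen::Tensor3dXf/Tensor1dXf aliases). Below is a minimal sketch of plausible definitions, assuming Tensor3dXf is Eigen::Tensor<float, 3> and that layer_norm normalizes over the last (channel) dimension the way PyTorch's nn.LayerNorm(C) does; the real demucs.cpp implementations may differ.

#include <unsupported/Eigen/CXX11/Tensor>
#include <Eigen/Dense>
#include <cmath>

// Hedged sketch of the helpers assumed by this gist; not the real demucs.cpp code.
namespace demucscpp_sketch
{
using Tensor3dXf = Eigen::Tensor<float, 3>;
using Tensor1dXf = Eigen::Tensor<float, 1>;

// LayerNorm over the last (channel) dim of a (B, T, C) tensor.
inline Tensor3dXf layer_norm(const Tensor3dXf &x, const Tensor1dXf &weight,
                             const Tensor1dXf &bias, float eps)
{
    Tensor3dXf out(x.dimensions());
    const int B = x.dimension(0), T = x.dimension(1), C = x.dimension(2);
    for (int b = 0; b < B; ++b)
    {
        for (int t = 0; t < T; ++t)
        {
            // per-(b, t) mean and biased variance over channels,
            // matching PyTorch's nn.LayerNorm (divide by C, not C - 1)
            float mean = 0.0f;
            for (int c = 0; c < C; ++c)
                mean += x(b, t, c);
            mean /= static_cast<float>(C);
            float var = 0.0f;
            for (int c = 0; c < C; ++c)
            {
                const float d = x(b, t, c) - mean;
                var += d * d;
            }
            var /= static_cast<float>(C);
            const float inv_std = 1.0f / std::sqrt(var + eps);
            for (int c = 0; c < C; ++c)
                out(b, t, c) =
                    (x(b, t, c) - mean) * inv_std * weight(c) + bias(c);
        }
    }
    return out;
}

// Exact (erf-based) GELU, applied elementwise.
inline Eigen::MatrixXf gelu(const Eigen::MatrixXf &x)
{
    return x.unaryExpr([](float v) {
        return 0.5f * v * (1.0f + std::erf(v / std::sqrt(2.0f)));
    });
}
} // namespace demucscpp_sketch

The second half of the gist wires these layers into the full crosstransformer: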
#include "crosstransformer.hpp"
#include "layers.hpp"
#include "model.hpp"
#include "tensor.hpp"
#include <Eigen/Dense>
#include <filesystem>
#include <fstream>
#include <iostream>
static Eigen::Tensor3dXf create_2d_sin_embedding(int d_model, int height,
int width,
float max_period = 10000.0)
{
if (d_model % 4 != 0)
{
std::cerr << "Cannot use sin/cos positional encoding with dimension not divisible by 4" << std::endl;
std::exit(1);
}
Eigen::Tensor3dXf pe(d_model, height, width);
d_model /= 2;
Eigen::ArrayXf div_term =
Eigen::exp(Eigen::ArrayXf::LinSpaced(d_model / 2, 0, d_model - 2) *
(-std::log(max_period) / d_model));
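// After the d_model /= 2 halving above, the first half of the output
// channels encodes the width index and the second half the height index;
// div_term is the geometric frequency ladder of the standard transformer
// sin/cos positional encoding, exp(-ln(max_period) * 2j / d_model).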
for (int i = 0; i < width; ++i)
{
for (int j = 0; j < d_model / 2; ++j)
{
float val_w = i * div_term(j);
pe.slice(Eigen::array<int, 3>({j * 2, 0, i}),
Eigen::array<int, 3>({1, height, 1}))
.setConstant(std::sin(val_w));
pe.slice(Eigen::array<int, 3>({j * 2 + 1, 0, i}),
Eigen::array<int, 3>({1, height, 1}))
.setConstant(std::cos(val_w));
}
}
for (int i = 0; i < height; ++i)
{
for (int j = 0; j < d_model / 2; ++j)
{
float val_h = i * div_term(j);
pe.slice(Eigen::array<int, 3>({d_model + j * 2, i, 0}),
Eigen::array<int, 3>({1, 1, width}))
.setConstant(std::sin(val_h));
pe.slice(Eigen::array<int, 3>({d_model + j * 2 + 1, i, 0}),
Eigen::array<int, 3>({1, 1, width}))
.setConstant(std::cos(val_h));
}
}
return pe;
}
static Eigen::Tensor3dXf create_sin_embedding(int length, int dim,
int shift = 0,
float max_period = 10000.0f)
{
Eigen::Tensor3dXf pos_emb(1, length, dim);
int half_dim = dim / 2;
Eigen::ArrayXf div_term =
Eigen::ArrayXf::LinSpaced(half_dim, 0, half_dim - 1) / (half_dim - 1);
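// Unlike the interleaved 2D variant above, this layout puts all cosines in
// the first half_dim channels and all sines in the second; div_term here is
// a normalized exponent in [0, 1], so the phase divisor max_period^div_term
// sweeps from 1 up to max_period.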
for (int t = 0; t < length; ++t)
{
float position = static_cast<float>(t) + shift;
for (int i = 0; i < half_dim; ++i)
{
float phase = position / std::pow(max_period, div_term(i));
pos_emb(0, t, i) = std::cos(phase); // assign to first half
pos_emb(0, t, i + half_dim) =
std::sin(phase); // assign to second half
}
}
return pos_emb;
}
static void
my_transformer_encoder_layer(struct demucscpp::demucs_model_4s &model,
Eigen::Tensor3dXf &x, int freq_or_time,
int weight_idx, float eps = 1e-5)
{
demucscpp::common_encoder_layer(
x, // pass x as q
x, // pass x as k
model.crosstransformer_my_layers_norm1_weight[freq_or_time][weight_idx],
model.crosstransformer_my_layers_norm1_bias[freq_or_time][weight_idx],
model.crosstransformer_my_layers_norm1_weight[freq_or_time][weight_idx],
model.crosstransformer_my_layers_norm1_bias[freq_or_time][weight_idx],
model.crosstransformer_my_layers_self_attn_in_proj_weight[freq_or_time]
[weight_idx],
model.crosstransformer_my_layers_self_attn_in_proj_bias[freq_or_time]
[weight_idx],
model.crosstransformer_my_layers_self_attn_out_proj_weight[freq_or_time]
[weight_idx],
model.crosstransformer_my_layers_self_attn_out_proj_bias[freq_or_time]
[weight_idx],
model
.crosstransformer_my_layers_gamma_1_scale[freq_or_time][weight_idx],
model.crosstransformer_my_layers_norm2_weight[freq_or_time][weight_idx],
model.crosstransformer_my_layers_norm2_bias[freq_or_time][weight_idx],
model.crosstransformer_my_layers_linear1_weight[freq_or_time]
[weight_idx],
model.crosstransformer_my_layers_linear1_bias[freq_or_time][weight_idx],
model.crosstransformer_my_layers_linear2_weight[freq_or_time]
[weight_idx],
model.crosstransformer_my_layers_linear2_bias[freq_or_time][weight_idx],
model
.crosstransformer_my_layers_gamma_2_scale[freq_or_time][weight_idx],
model.crosstransformer_my_layers_norm_out_weight[freq_or_time]
[weight_idx],
model
.crosstransformer_my_layers_norm_out_bias[freq_or_time][weight_idx],
eps);
}
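// Note: for self-attention q == k, so the norm1 weights are passed for both
// the norm1 and norm2 slots of common_encoder_layer, and this layer's norm2
// weights fill the norm3 (pre-feedforward) slot.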
static void
cross_transformer_encoder_layer(struct demucscpp::demucs_model_4s &model,
Eigen::Tensor3dXf &q, // q = x = frequency
const Eigen::Tensor3dXf &k, // k = xt = time
int freq_or_time, int weight_idx,
float eps = 1e-5)
{
demucscpp::common_encoder_layer(
q, k,
model.crosstransformer_cross_layers_norm1_weight[freq_or_time]
[weight_idx],
model
.crosstransformer_cross_layers_norm1_bias[freq_or_time][weight_idx],
model.crosstransformer_cross_layers_norm2_weight[freq_or_time]
[weight_idx],
model
.crosstransformer_cross_layers_norm2_bias[freq_or_time][weight_idx],
model.crosstransformer_cross_layers_cross_attn_in_proj_weight
[freq_or_time][weight_idx],
model
.crosstransformer_cross_layers_cross_attn_in_proj_bias[freq_or_time]
[weight_idx],
model.crosstransformer_cross_layers_cross_attn_out_proj_weight
[freq_or_time][weight_idx],
model.crosstransformer_cross_layers_cross_attn_out_proj_bias
[freq_or_time][weight_idx],
model.crosstransformer_cross_layers_gamma_1_scale[freq_or_time]
[weight_idx],
model.crosstransformer_cross_layers_norm3_weight[freq_or_time]
[weight_idx],
model
.crosstransformer_cross_layers_norm3_bias[freq_or_time][weight_idx],
model.crosstransformer_cross_layers_linear1_weight[freq_or_time]
[weight_idx],
model.crosstransformer_cross_layers_linear1_bias[freq_or_time]
[weight_idx],
model.crosstransformer_cross_layers_linear2_weight[freq_or_time]
[weight_idx],
model.crosstransformer_cross_layers_linear2_bias[freq_or_time]
[weight_idx],
model.crosstransformer_cross_layers_gamma_2_scale[freq_or_time]
[weight_idx],
model.crosstransformer_cross_layers_norm_out_weight[freq_or_time]
[weight_idx],
model.crosstransformer_cross_layers_norm_out_bias[freq_or_time]
[weight_idx],
eps);
}
void demucscpp::apply_crosstransformer(struct demucscpp::demucs_model_4s &model,
Eigen::Tensor3dXf &x, // frequency branch
Eigen::Tensor3dXf &xt // time branch
)
{
std::cout << "apply_crosstransformer" << std::endl;
Eigen::Tensor3dXf pos_embed_2d_pre_reshape =
create_2d_sin_embedding(x.dimension(0), x.dimension(1), x.dimension(2));
Eigen::Tensor3dXf pos_embed_2d(1, x.dimension(1) * x.dimension(2),
x.dimension(0));
Eigen::Tensor3dXf x_reshape(1, x.dimension(1) * x.dimension(2),
x.dimension(0));
// x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
// implemented with Eigen for loops; the 2D positional embedding is
// rearranged to the same layout in the same pass
for (int i = 0; i < x.dimension(1); ++i)
{
for (int j = 0; j < x.dimension(2); ++j)
{
for (int k = 0; k < x.dimension(0); ++k)
{
pos_embed_2d(0, j * x.dimension(1) + i, k) =
pos_embed_2d_pre_reshape(k, i, j);
x_reshape(0, j * x.dimension(1) + i, k) = x(k, i, j);
}
}
}
x = x_reshape;
float eps = 1e-5;
x = demucscpp::layer_norm(x, model.crosstransformer_norm_in_weight,
model.crosstransformer_norm_in_bias, eps) +
pos_embed_2d;
// (B, C, T2) = xt.shape
int C = xt.dimension(1);
int T2 = xt.dimension(2);
Eigen::Tensor3dXf pos_embed_1d = create_sin_embedding(T2, C);
// shuffle axes of xt from 0,1,2 to 0,2,1
Eigen::Tensor3dXf xt_shuf = xt.shuffle(Eigen::array<int, 3>{0, 2, 1});
xt = demucscpp::layer_norm(xt_shuf, model.crosstransformer_norm_in_t_weight,
model.crosstransformer_norm_in_t_bias, eps) +
pos_embed_1d;
demucscppdebug::debug_tensor_3dxf(x, "x crosstransformer pre-layers");
demucscppdebug::debug_tensor_3dxf(xt, "xt crosstransformer pre-tlayers");
// actual crosstransformer layers here
// layer 0 for freq and time is the first MyTransformerEncoderLayer;
// the last argument in each call is the weight index. PyTorch layer
// indices 0,2,4 map to 0,1,2 in this C++ code because the 3
// MyTransformerEncoderLayers are stored in one array and the 2
// CrossTransformerEncoderLayers (PyTorch indices 1,3 -> 0,1) in another
// x = self.layers[0](x)
// xt = self.layers_t[0](xt)
my_transformer_encoder_layer(model, x, 0, 0);
my_transformer_encoder_layer(model, xt, 1, 0);
demucscppdebug::debug_tensor_3dxf(x, "x crosstransformer post-layer-0");
demucscppdebug::debug_tensor_3dxf(xt, "xt crosstransformer post-tlayer-0");
// make a copy of x
Eigen::Tensor3dXf old_x = x;
// x is modified in-place and is the final value of x
// xt is not modified (const)
cross_transformer_encoder_layer(model, x, xt, 0, 0);
// xt is modified in-place and is the final value of xt
cross_transformer_encoder_layer(model, xt, old_x, 1, 0);
demucscppdebug::debug_tensor_3dxf(x, "x crosstransformer post-layer-1");
demucscppdebug::debug_tensor_3dxf(xt, "xt crosstransformer post-tlayer-1");
my_transformer_encoder_layer(model, x, 0, 1);
my_transformer_encoder_layer(model, xt, 1, 1);
demucscppdebug::debug_tensor_3dxf(x, "x crosstransformer post-layer-2");
demucscppdebug::debug_tensor_3dxf(xt, "xt crosstransformer post-tlayer-2");
// make a copy of x
old_x = x;
// x is modified in-place and is the final value of x
cross_transformer_encoder_layer(model, x, xt, 0, 1);
// xt is modified in-place and is the final value of xt
cross_transformer_encoder_layer(model, xt, old_x, 1, 1);
demucscppdebug::debug_tensor_3dxf(x, "x crosstransformer post-layer-3");
demucscppdebug::debug_tensor_3dxf(xt, "xt crosstransformer post-tlayer-3");
my_transformer_encoder_layer(model, x, 0, 2);
my_transformer_encoder_layer(model, xt, 1, 2);
demucscppdebug::debug_tensor_3dxf(x, "x crosstransformer post-layer-4");
demucscppdebug::debug_tensor_3dxf(xt, "xt crosstransformer post-tlayer-4");
// permute last two dims of xt
Eigen::array<int, 3> permute_dims = {0, 2, 1};
Eigen::Tensor3dXf xt_ret = xt.shuffle(permute_dims);
xt = xt_ret;
// for x, the caller reshapes from (1, 2688, 512) back to
// (512, 8, 336); here we first permute x to (1, 512, 2688)
Eigen::Tensor3dXf x_shuf = x.shuffle(permute_dims);
x = x_shuf;
}
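For the shapes noted in the comments above: x enters apply_crosstransformer as (C=512, Fr=8, T1=336), is flattened to (1, 2688, 512) for the transformer layers, and leaves as (1, 512, 2688) after the final permute, with the caller expected to reshape it back to (512, 8, 336). xt enters as (1, C, T2), is permuted to (1, T2, C) for the layers, and is permuted back to (1, C, T2) on the way out.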