Created
November 28, 2023 18:06
-
-
Save sevagh/b71d253a347a9b59c026580625452fc5 to your computer and use it in GitHub Desktop.
multihead attention in Eigen/C++
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
void demucscpp::common_encoder_layer( | |
Eigen::Tensor3dXf &q, // q = x = frequency | |
const Eigen::Tensor3dXf &k, // k = xt = time | |
const Eigen::Tensor1dXf &norm1_weight, const Eigen::Tensor1dXf &norm1_bias, | |
const Eigen::Tensor1dXf &norm2_weight, const Eigen::Tensor1dXf &norm2_bias, | |
const Eigen::MatrixXf &in_proj_weight, const Eigen::VectorXf &in_proj_bias, | |
const Eigen::MatrixXf &out_proj_weight, | |
const Eigen::VectorXf &out_proj_bias, const Eigen::VectorXf &gamma1_scale, | |
const Eigen::Tensor1dXf &norm3_weight, const Eigen::Tensor1dXf &norm3_bias, | |
const Eigen::MatrixXf &linear1_weight, const Eigen::VectorXf &linear1_bias, | |
const Eigen::MatrixXf &linear2_weight, const Eigen::VectorXf &linear2_bias, | |
const Eigen::VectorXf &gamma2_scale, | |
const Eigen::Tensor1dXf &norm_out_weight, | |
const Eigen::Tensor1dXf &norm_out_bias, float eps, const int num_heads) | |
{ | |
// Normalize x using the norm1 weights and biases | |
Eigen::Tensor3dXf q_norm = | |
demucscpp::layer_norm(q, norm1_weight, norm1_bias, eps); | |
Eigen::Tensor3dXf k_norm = | |
demucscpp::layer_norm(k, norm2_weight, norm2_bias, eps); | |
// Cross-attention block | |
// Compute Q, K, V matrices | |
int B = q.dimension(0); | |
int T = q.dimension(1); | |
int C = q.dimension(2); | |
int B_k = k.dimension(0); | |
int S = k.dimension(1); | |
int C_k = k.dimension(2); | |
demucscppdebug::assert_(B == B_k); | |
demucscppdebug::assert_(B == 1); | |
demucscppdebug::assert_(C == C_k); | |
// Reshape q, k to 2D matrix of dimensions (T*B, C) | |
// Use Eigen::Map to avoid manual loops for reshaping | |
Eigen::MatrixXf q_norm_2d = | |
Eigen::Map<const Eigen::MatrixXf>(q_norm.data(), T, C); | |
Eigen::MatrixXf k_norm_2d = | |
Eigen::Map<const Eigen::MatrixXf>(k_norm.data(), S, C); | |
// Compute Q, K, V matrices | |
Eigen::MatrixXf Q = | |
q_norm_2d * in_proj_weight.block(0, 0, C, C).transpose(); | |
Eigen::MatrixXf K = | |
k_norm_2d * in_proj_weight.block(C, 0, C, C).transpose(); | |
Eigen::MatrixXf V = | |
k_norm_2d * in_proj_weight.block(2 * C, 0, C, C).transpose(); | |
Eigen::VectorXf q_bias = in_proj_bias.segment(0, C); | |
Eigen::VectorXf k_bias = in_proj_bias.segment(C, C); | |
Eigen::VectorXf v_bias = in_proj_bias.segment(2 * C, C); | |
int head_split = C / num_heads; | |
Eigen::Tensor3dXf Q_heads(T, num_heads, head_split); | |
Eigen::Tensor3dXf K_heads(S, num_heads, head_split); | |
Eigen::MatrixXf V_heads_2d(S * num_heads, head_split); | |
// reverse loop order to combine with K and V | |
for (int d = 0; d < head_split; ++d) | |
{ | |
for (int h = 0; h < num_heads; ++h) | |
{ | |
for (int t = 0; t < T; ++t) | |
{ | |
Q_heads(t, h, d) = | |
Q(t, h * head_split + d) + q_bias(h * head_split + d); | |
} | |
for (int s = 0; s < S; ++s) | |
{ | |
K_heads(s, h, d) = | |
K(s, h * head_split + d) + k_bias(h * head_split + d); | |
V_heads_2d(s * num_heads + h, d) = | |
V(s, h * head_split + d) + v_bias(h * head_split + d); | |
} | |
} | |
} | |
// Compute cross-attention scores | |
Eigen::MatrixXf scores(num_heads * T, S); // Initialize to zeros | |
for (int h = 0; h < num_heads; ++h) | |
{ | |
// Extract the h-th head from Q_heads and K_heads | |
Eigen::Tensor<float, 2> Q_head_tensor = Q_heads.chip(h, 1); | |
Eigen::Tensor<float, 2> K_head_tensor = K_heads.chip(h, 1); | |
// Reshape the tensors to matrices | |
Eigen::Map<Eigen::MatrixXf> Q_head(Q_head_tensor.data(), T, head_split); | |
Eigen::Map<Eigen::MatrixXf> K_head(K_head_tensor.data(), S, head_split); | |
// Compute the dot product of Q_head and K_head | |
Eigen::MatrixXf dot_product = Q_head * K_head.transpose(); | |
// Store the result in scores | |
scores.block(h * T, 0, T, S) = dot_product / std::sqrt((float)head_split); | |
} | |
// Apply softmax to scores | |
Eigen::ArrayXf max_vals = scores.rowwise().maxCoeff(); | |
Eigen::MatrixXf max_vals_expanded = max_vals.replicate(1, scores.cols()); | |
scores = (scores - max_vals_expanded).array().exp().matrix(); | |
Eigen::VectorXf row_sums = scores.rowwise().sum(); | |
Eigen::MatrixXf divisor = row_sums.replicate(1, scores.cols()); | |
scores = (scores.array() / divisor.array()).matrix(); | |
// Compute cross-attention output | |
std::vector<Eigen::MatrixXf> cross_attn_out_3d; | |
std::vector<Eigen::MatrixXf> V_heads_3d; | |
std::vector<Eigen::MatrixXf> scores_3d; | |
for (int h = 0; h < num_heads; ++h) | |
{ | |
V_heads_3d.push_back(Eigen::MatrixXf(S, head_split)); | |
scores_3d.push_back(Eigen::MatrixXf(T, S)); | |
cross_attn_out_3d.push_back(Eigen::MatrixXf(T, head_split)); | |
} | |
// first copy V_heads_2d, scores into 3d tensors | |
for (int h = 0; h < num_heads; ++h) | |
{ | |
for (int s = 0; s < S; ++s) | |
{ | |
for (int t = 0; t < T; ++t) | |
{ | |
scores_3d[h](t, s) = scores(h * T + t, s); | |
} | |
for (int d = 0; d < head_split; ++d) | |
{ | |
V_heads_3d[h](s, d) = V_heads_2d(s * num_heads + h, d); | |
} | |
} | |
} | |
// now loop over 8 and do inner matmuls, assigning | |
// results to cross_attn_out_3d | |
for (int h = 0; h < num_heads; ++h) | |
{ | |
cross_attn_out_3d[h] = scores_3d[h] * V_heads_3d[h]; | |
} | |
Eigen::MatrixXf cross_attn_out(T, C); | |
// now copy cross_attn_out_3d into cross_attn_out | |
// from shape (8, T, 64) to (T, C) | |
for (int t = 0; t < T; ++t) | |
{ | |
for (int c = 0; c < C; ++c) | |
{ | |
int h = c / head_split; | |
int k_ = c % head_split; | |
cross_attn_out(t, c) = cross_attn_out_3d[h](t, k_); | |
} | |
} | |
// Apply output projection | |
Eigen::MatrixXf out_proj = cross_attn_out * out_proj_weight.transpose(); | |
out_proj.array().rowwise() += out_proj_bias.transpose().array(); | |
// now we need x = q + out_proj, but let's store that in 3d q | |
for (int t = 0; t < T; ++t) | |
{ | |
for (int c = 0; c < C; ++c) | |
{ | |
q(0, t, c) += out_proj(t, c) * gamma1_scale(c); | |
} | |
} | |
// copy q into x_2d | |
Eigen::MatrixXf q_2d(T, C); | |
q_2d = Eigen::Map<const Eigen::MatrixXf>(q.data(), T, C); | |
// before feedforward, apply norm3 to x i.e. q | |
q_norm = demucscpp::layer_norm(q, norm3_weight, norm3_bias, eps); | |
q_norm_2d = Eigen::Map<const Eigen::MatrixXf>(q_norm.data(), T, C); | |
// Feedforward block | |
// Linear layer 1 | |
Eigen::MatrixXf ff1 = q_norm_2d * linear1_weight.transpose(); | |
ff1.rowwise() += linear1_bias.transpose(); | |
ff1 = demucscpp::gelu(ff1); | |
// Linear layer 2 | |
Eigen::MatrixXf ff2 = ff1 * linear2_weight.transpose(); | |
ff2.rowwise() += linear2_bias.transpose(); | |
// Apply gamma_2 scale directly on 2D matrix | |
ff2 = ff2.array().rowwise() * gamma2_scale.transpose().array(); | |
// now x = x + self.gamma_2(self._ff_block(self.norm3(q)))) | |
q_2d += ff2; | |
// Map the 2D data back into a 3D tensor with dimensions (T, B, C) | |
q = Eigen::TensorMap<Eigen::Tensor3dXf>(q_2d.data(), T, B, C); | |
// Swap the first and last dimensions to get a tensor with dimensions (B, C, | |
// T) | |
Eigen::array<int, 3> permute_dims = {1, 2, 0}; | |
Eigen::Tensor3dXf q_shuf = q.shuffle(permute_dims); | |
// Normalize the output with norm_out/MyGroupNorm | |
q = demucscpp::group_norm(q_shuf, norm_out_weight, norm_out_bias, 1, eps); | |
Eigen::array<int, 3> permute_dims_2 = {0, 2, 1}; | |
q_shuf = q.shuffle(permute_dims_2); | |
q = q_shuf; | |
} | |
#include "crosstransformer.hpp" | |
#include "layers.hpp" | |
#include "model.hpp" | |
#include "tensor.hpp" | |
#include <Eigen/Dense> | |
#include <filesystem> | |
#include <fstream> | |
#include <iostream> | |
// Builds a 2D sinusoidal positional embedding of dimensions
// (d_model, height, width).
//
// The first d_model/2 channels encode the width position and are constant
// along the height axis; the remaining d_model/2 channels encode the height
// position and are constant along the width axis. Within each half, even
// channels hold the sine and odd channels the cosine of
// position * div_term(j).
//
// d_model must be divisible by 4; otherwise an error is printed to stderr
// and the process exits with status 1.
static Eigen::Tensor3dXf create_2d_sin_embedding(int d_model, int height,
                                                 int width,
                                                 float max_period = 10000.0)
{
    if (d_model % 4 != 0)
    {
        std::cerr << "Cannot use 2D sin/cos positional encoding: dimension "
                     "must be divisible by 4 (got dim="
                  << d_model << ")" << std::endl;
        std::exit(1);
    }

    Eigen::Tensor3dXf pe(d_model, height, width);

    // each spatial axis gets half of the channels, as sin/cos pairs
    const int half = d_model / 2;
    const int n_pairs = half / 2;

    // one frequency per sin/cos pair:
    // exp(2j * -log(max_period) / half), j = 0 .. n_pairs-1
    const Eigen::ArrayXf div_term =
        Eigen::exp(Eigen::ArrayXf::LinSpaced(n_pairs, 0, half - 2) *
                   (-std::log(max_period) / half));

    for (int j = 0; j < n_pairs; ++j)
    {
        const float freq = div_term(j);

        // width encoding: identical for every row of the grid
        for (int w = 0; w < width; ++w)
        {
            const float val_w = w * freq;
            const float sin_w = std::sin(val_w);
            const float cos_w = std::cos(val_w);

            for (int h = 0; h < height; ++h)
            {
                pe(j * 2, h, w) = sin_w;
                pe(j * 2 + 1, h, w) = cos_w;
            }
        }

        // height encoding: identical for every column of the grid
        for (int h = 0; h < height; ++h)
        {
            const float val_h = h * freq;
            const float sin_h = std::sin(val_h);
            const float cos_h = std::cos(val_h);

            for (int w = 0; w < width; ++w)
            {
                pe(half + j * 2, h, w) = sin_h;
                pe(half + j * 2 + 1, h, w) = cos_h;
            }
        }
    }

    return pe;
}
// Builds a 1D sinusoidal positional embedding of dimensions (1, length, dim).
//
// For position t (offset by shift) and index i in [0, dim/2):
//   phase           = (t + shift) / max_period^(i / (dim/2 - 1))
//   out(0, t, i)         = cos(phase)
//   out(0, t, i + dim/2) = sin(phase)
//
// When dim/2 == 1 the exponent is taken as 0 (avoiding 0/0), and for an odd
// dim the last channel, which has no sin/cos partner, is left at zero.
static Eigen::Tensor3dXf create_sin_embedding(int length, int dim,
                                              int shift = 0,
                                              float max_period = 10000.0f)
{
    Eigen::Tensor3dXf pos_emb(1, length, dim);
    // Eigen tensors are not zero-initialized; make sure a channel that is
    // not written below (odd dim) holds a defined value
    pos_emb.setZero();

    const int half_dim = dim / 2;

    // exponents 0, 1/(half_dim-1), ..., 1; a single frequency gets exponent 0
    Eigen::ArrayXf exponents = Eigen::ArrayXf::Zero(half_dim);
    if (half_dim > 1)
    {
        exponents = Eigen::ArrayXf::LinSpaced(half_dim, 0, half_dim - 1) /
                    static_cast<float>(half_dim - 1);
    }

    // the per-channel period does not depend on the position, so compute it
    // once instead of once per (position, channel) pair
    Eigen::ArrayXf period(half_dim);
    for (int i = 0; i < half_dim; ++i)
    {
        period(i) = std::pow(max_period, exponents(i));
    }

    for (int t = 0; t < length; ++t)
    {
        const float position = static_cast<float>(t) + shift;

        for (int i = 0; i < half_dim; ++i)
        {
            const float phase = position / period(i);
            pos_emb(0, t, i) = std::cos(phase);            // first half
            pos_emb(0, t, i + half_dim) = std::sin(phase); // second half
        }
    }

    return pos_emb;
}
static void | |
my_transformer_encoder_layer(struct demucscpp::demucs_model_4s &model, | |
Eigen::Tensor3dXf &x, int freq_or_time, | |
int weight_idx, float eps = 1e-5) | |
{ | |
demucscpp::common_encoder_layer( | |
x, // pass x as q | |
x, // pass x as k | |
model.crosstransformer_my_layers_norm1_weight[freq_or_time][weight_idx], | |
model.crosstransformer_my_layers_norm1_bias[freq_or_time][weight_idx], | |
model.crosstransformer_my_layers_norm1_weight[freq_or_time][weight_idx], | |
model.crosstransformer_my_layers_norm1_bias[freq_or_time][weight_idx], | |
model.crosstransformer_my_layers_self_attn_in_proj_weight[freq_or_time] | |
[weight_idx], | |
model.crosstransformer_my_layers_self_attn_in_proj_bias[freq_or_time] | |
[weight_idx], | |
model.crosstransformer_my_layers_self_attn_out_proj_weight[freq_or_time] | |
[weight_idx], | |
model.crosstransformer_my_layers_self_attn_out_proj_bias[freq_or_time] | |
[weight_idx], | |
model | |
.crosstransformer_my_layers_gamma_1_scale[freq_or_time][weight_idx], | |
model.crosstransformer_my_layers_norm2_weight[freq_or_time][weight_idx], | |
model.crosstransformer_my_layers_norm2_bias[freq_or_time][weight_idx], | |
model.crosstransformer_my_layers_linear1_weight[freq_or_time] | |
[weight_idx], | |
model.crosstransformer_my_layers_linear1_bias[freq_or_time][weight_idx], | |
model.crosstransformer_my_layers_linear2_weight[freq_or_time] | |
[weight_idx], | |
model.crosstransformer_my_layers_linear2_bias[freq_or_time][weight_idx], | |
model | |
.crosstransformer_my_layers_gamma_2_scale[freq_or_time][weight_idx], | |
model.crosstransformer_my_layers_norm_out_weight[freq_or_time] | |
[weight_idx], | |
model | |
.crosstransformer_my_layers_norm_out_bias[freq_or_time][weight_idx], | |
eps); | |
} | |
static void | |
cross_transformer_encoder_layer(struct demucscpp::demucs_model_4s &model, | |
Eigen::Tensor3dXf &q, // q = x = frequency | |
const Eigen::Tensor3dXf &k, // k = xt = time | |
int freq_or_time, int weight_idx, | |
float eps = 1e-5) | |
{ | |
demucscpp::common_encoder_layer( | |
q, k, | |
model.crosstransformer_cross_layers_norm1_weight[freq_or_time] | |
[weight_idx], | |
model | |
.crosstransformer_cross_layers_norm1_bias[freq_or_time][weight_idx], | |
model.crosstransformer_cross_layers_norm2_weight[freq_or_time] | |
[weight_idx], | |
model | |
.crosstransformer_cross_layers_norm2_bias[freq_or_time][weight_idx], | |
model.crosstransformer_cross_layers_cross_attn_in_proj_weight | |
[freq_or_time][weight_idx], | |
model | |
.crosstransformer_cross_layers_cross_attn_in_proj_bias[freq_or_time] | |
[weight_idx], | |
model.crosstransformer_cross_layers_cross_attn_out_proj_weight | |
[freq_or_time][weight_idx], | |
model.crosstransformer_cross_layers_cross_attn_out_proj_bias | |
[freq_or_time][weight_idx], | |
model.crosstransformer_cross_layers_gamma_1_scale[freq_or_time] | |
[weight_idx], | |
model.crosstransformer_cross_layers_norm3_weight[freq_or_time] | |
[weight_idx], | |
model | |
.crosstransformer_cross_layers_norm3_bias[freq_or_time][weight_idx], | |
model.crosstransformer_cross_layers_linear1_weight[freq_or_time] | |
[weight_idx], | |
model.crosstransformer_cross_layers_linear1_bias[freq_or_time] | |
[weight_idx], | |
model.crosstransformer_cross_layers_linear2_weight[freq_or_time] | |
[weight_idx], | |
model.crosstransformer_cross_layers_linear2_bias[freq_or_time] | |
[weight_idx], | |
model.crosstransformer_cross_layers_gamma_2_scale[freq_or_time] | |
[weight_idx], | |
model.crosstransformer_cross_layers_norm_out_weight[freq_or_time] | |
[weight_idx], | |
model.crosstransformer_cross_layers_norm_out_bias[freq_or_time] | |
[weight_idx], | |
eps); | |
} | |
// Runs the cross-transformer over the frequency branch x and the time branch
// xt, updating both in place.
//
// x is read as (channels, freq, frames) and flattened to a token sequence of
// dimensions (1, frames * freq, channels); xt is read as (1, channels, time)
// and permuted to (1, time, channels). Each branch is layer-normalized and
// gets a sinusoidal positional embedding added, then five layers run in the
// order self, cross, self, cross, self. On return the last two dimensions of
// both x and xt have been swapped back, so x has dimensions
// (1, channels, frames * freq); any reshape back to a 3D frequency layout is
// left to the caller.
void demucscpp::apply_crosstransformer(struct demucscpp::demucs_model_4s &model,
                                       Eigen::Tensor3dXf &x, // frequency branch
                                       Eigen::Tensor3dXf &xt // time branch
)
{
    std::cout << "apply_crosstransformer" << std::endl;

    const int n_chan = x.dimension(0);
    const int n_freq = x.dimension(1);
    const int n_frames = x.dimension(2);
    const int n_tokens = n_freq * n_frames;

    const Eigen::Tensor3dXf pos_grid =
        create_2d_sin_embedding(n_chan, n_freq, n_frames);

    Eigen::Tensor3dXf pos_embed_2d(1, n_tokens, n_chan);
    Eigen::Tensor3dXf x_tokens(1, n_tokens, n_chan);

    // x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
    // the positional grid is flattened in exactly the same way
    for (int frame = 0; frame < n_frames; ++frame)
    {
        for (int fr = 0; fr < n_freq; ++fr)
        {
            const int token = frame * n_freq + fr;

            for (int c = 0; c < n_chan; ++c)
            {
                pos_embed_2d(0, token, c) = pos_grid(c, fr, frame);
                x_tokens(0, token, c) = x(c, fr, frame);
            }
        }
    }

    x = x_tokens;

    const float eps = 1e-5;

    x = demucscpp::layer_norm(x, model.crosstransformer_norm_in_weight,
                              model.crosstransformer_norm_in_bias, eps) +
        pos_embed_2d;

    // (B, C, T2) = xt.shape
    const int C = xt.dimension(1);
    const int T2 = xt.dimension(2);

    const Eigen::Tensor3dXf pos_embed_1d = create_sin_embedding(T2, C);

    // move the channel dimension of xt last: (B, C, T2) -> (B, T2, C)
    const Eigen::Tensor3dXf xt_tokens =
        xt.shuffle(Eigen::array<int, 3>{0, 2, 1});

    xt = demucscpp::layer_norm(xt_tokens,
                               model.crosstransformer_norm_in_t_weight,
                               model.crosstransformer_norm_in_t_bias, eps) +
         pos_embed_1d;

    demucscppdebug::debug_tensor_3dxf(x, "x crosstransformer pre-layers");
    demucscppdebug::debug_tensor_3dxf(xt, "xt crosstransformer pre-tlayers");

    // The layers alternate self, cross, self, cross, self. The three
    // self-attention layers and the two cross-attention layers are stored in
    // separate weight arrays, hence the separate 0..2 and 0..1 indices.
    // First weight index: 0 = frequency branch, 1 = time branch.

    // self-attention on each branch independently
    const auto self_stage = [&](int layer)
    {
        my_transformer_encoder_layer(model, x, 0, layer);
        my_transformer_encoder_layer(model, xt, 1, layer);
    };

    // each branch attends to the other one; the time branch must see the
    // frequency branch as it was before this stage, so snapshot x first
    const auto cross_stage = [&](int layer)
    {
        const Eigen::Tensor3dXf x_before = x;
        cross_transformer_encoder_layer(model, x, xt, 0, layer);
        cross_transformer_encoder_layer(model, xt, x_before, 1, layer);
    };

    self_stage(0);
    demucscppdebug::debug_tensor_3dxf(x, "x crosstransformer post-layer-0");
    demucscppdebug::debug_tensor_3dxf(xt, "xt crosstransformer post-tlayer-0");

    cross_stage(0);
    demucscppdebug::debug_tensor_3dxf(x, "x crosstransformer post-layer-1");
    demucscppdebug::debug_tensor_3dxf(xt, "xt crosstransformer post-tlayer-1");

    self_stage(1);
    demucscppdebug::debug_tensor_3dxf(x, "x crosstransformer post-layer-2");
    demucscppdebug::debug_tensor_3dxf(xt, "xt crosstransformer post-tlayer-2");

    cross_stage(1);
    demucscppdebug::debug_tensor_3dxf(x, "x crosstransformer post-layer-3");
    demucscppdebug::debug_tensor_3dxf(xt, "xt crosstransformer post-tlayer-3");

    self_stage(2);
    demucscppdebug::debug_tensor_3dxf(x, "x crosstransformer post-layer-4");
    demucscppdebug::debug_tensor_3dxf(xt, "xt crosstransformer post-tlayer-4");

    // swap the last two dims of both branches back
    const Eigen::array<int, 3> swap_last_two = {0, 2, 1};

    const Eigen::Tensor3dXf xt_ret = xt.shuffle(swap_last_two);
    xt = xt_ret;

    const Eigen::Tensor3dXf x_ret = x.shuffle(swap_last_two);
    x = x_ret;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment