
@sevagh
Created November 21, 2023 12:46
my naive (and incorrect) code for computing torch.nn.LayerNorm with Eigen3
Eigen::Tensor3dXf demucscpp::layer_norm_last_dim(const Eigen::Tensor3dXf &x,
                                                 const Eigen::Tensor1dXf &weight,
                                                 const Eigen::Tensor1dXf &bias,
                                                 const float eps = 1e-5)
{
    // compute layer norm across the last dimension
    // eps is typically 1e-5
    //
    // e.g. x = (A, B, C)
    //      w = (C)
    //      b = (C)
    //      ->
    //      y = (A, B, C)
    const int dim0 = x.dimension(0);
    const int dim1 = x.dimension(1);
    const int dim2 = x.dimension(2);

    Eigen::Tensor3dXf y_out(x.dimensions());

    for (int d0 = 0; d0 < dim0; ++d0)
    {
        for (int d1 = 0; d1 < dim1; ++d1)
        {
            // 1D slice of length dim2: x[d0, d1, :]
            auto slice = x.chip(d0, 0).chip(d1, 0);

            Eigen::Tensor<float, 0> mean_tensor = slice.mean();
            float mean = mean_tensor(0);

            // biased variance (divide by N, not N-1), as torch.nn.LayerNorm uses
            Eigen::Tensor<float, 0> var_tensor = (slice - mean).square().mean();
            float var = var_tensor(0);

            // normalize, then apply the elementwise affine parameters
            for (int d2 = 0; d2 < dim2; ++d2)
            {
                float norm_val = (x(d0, d1, d2) - mean) / std::sqrt(var + eps);
                y_out(d0, d1, d2) = norm_val * weight(d2) + bias(d2);
            }
        }
    }

    return y_out;
}