Created November 21, 2023 12:46
-
-
Save sevagh/3b699fe58adca0aba578ddd50453fec8 to your computer and use it in GitHub Desktop.
My naive (and incorrect) code for computing torch.nn.LayerNorm with Eigen3.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Layer normalization over the last dimension of a rank-3 tensor, intended to
/// mirror torch.nn.LayerNorm(normalized_shape=C).
///
/// For every (d0, d1) position, the length-C vector x(d0, d1, :) is normalized
/// and then passed through a per-channel affine transform:
///
///     y(d0, d1, c) = (x(d0, d1, c) - mean) / sqrt(var + eps) * weight(c) + bias(c)
///
/// `mean` and `var` are taken over the last dimension only. `var` is the biased
/// (population) variance, i.e. the mean of squared deviations (divided by C,
/// not C - 1).
///
/// Example shapes: x = (A, B, C), weight = (C), bias = (C)  ->  y = (A, B, C)
///
/// @param x      input tensor; normalization runs along dimension 2
/// @param weight per-channel scale, one entry per element of x's last dimension
/// @param bias   per-channel shift, one entry per element of x's last dimension
/// @param eps    added to the variance before the square root so a constant
///               row does not divide by zero (typically 1e-5)
/// @return       a new tensor with the same dimensions as `x`
Eigen::Tensor3dXf demucscpp::layer_norm_last_dim(const Eigen::Tensor3dXf &x,
                                                 const Eigen::Tensor1dXf &weight,
                                                 const Eigen::Tensor1dXf &bias,
                                                 const float eps=1e-5) {
    // Tensor dimensions are Eigen::Index (a signed 64-bit type on common
    // platforms); keep that type instead of narrowing to int.
    const Eigen::Index dim0 = x.dimension(0);
    const Eigen::Index dim1 = x.dimension(1);
    const Eigen::Index dim2 = x.dimension(2);

    // weight/bias are indexed by the last dimension of x below. A size mismatch
    // would read out of bounds because Tensor element access is not
    // bounds-checked in ordinary builds. eigen_assert catches it in debug
    // builds; it compiles away under NDEBUG / EIGEN_NO_DEBUG.
    eigen_assert(weight.dimension(0) == dim2);
    eigen_assert(bias.dimension(0) == dim2);

    Eigen::Tensor3dXf y_out(x.dimensions());

    for (Eigen::Index d0 = 0; d0 < dim0; ++d0) {
        for (Eigen::Index d1 = 0; d1 < dim1; ++d1) {
            // 1-D view of x(d0, d1, :), the vector being normalized.
            const auto row = x.chip(d0, 0).chip(d1, 0);

            // Full reductions yield rank-0 tensors; extract the scalar.
            const Eigen::Tensor<float, 0> mean_t = row.mean();
            const float mean = mean_t(0);

            // Biased variance: mean of squared deviations from the row mean.
            const Eigen::Tensor<float, 0> var_t = (row - mean).square().mean();
            const float var = var_t(0);

            // Loop-invariant for this row. Computing it once instead of per
            // element keeps results bit-identical (same division) while
            // avoiding C redundant sqrt calls.
            const float std_dev = std::sqrt(var + eps);

            for (Eigen::Index d2 = 0; d2 < dim2; ++d2) {
                const float norm_val = (x(d0, d1, d2) - mean) / std_dev;
                y_out(d0, d1, d2) = norm_val * static_cast<float>(weight(d2)) +
                                    static_cast<float>(bias(d2));
            }
        }
    }
    return y_out;
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.