Created
July 14, 2018 14:14
-
-
Save ra1u/e3a5f7b6708cf3534b0918866eac55f3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
void TransformIn(const size_t batch_size, const float *input, | |
const size_t channels, float *output) | |
{ | |
static const size_t Par = 16; | |
float x[Par][kWinogradAlpha][kWinogradAlpha]; | |
float T1[Par][kWinogradAlpha][kWinogradAlpha]; | |
for (size_t batch_index = 0; batch_index < batch_size; batch_index++) { | |
const float *input_batch = | |
input + batch_index * kWidth * kHeight * channels; | |
float *V_batch = &output[channels * kTiles * batch_index]; | |
size_t channel_step = Par; | |
for (size_t channel_long = 0; channel_long < channels; | |
channel_long += channel_step) { | |
channel_step = Par; | |
for (int block_y = 0; block_y < kWtiles; block_y++) { | |
for (int block_x = 0; block_x < kWtiles; block_x++) { | |
// Tiles overlap by 2 | |
const int yin = 2 * block_y - 1; | |
const int xin = 2 * block_x - 1; | |
for (size_t ch = 0; ch < channel_step; ++ch) { | |
const size_t channel = channel_long + ch; | |
float *V_channel = V_batch + channel; | |
const float *input_channel = | |
input_batch + channel * (kWidth * kHeight); | |
for (int i = 0; i < kWinogradAlpha; i++) { | |
for (int j = 0; j < kWinogradAlpha; j++) { | |
if ((yin + i) >= 0 && (xin + j) >= 0 && | |
(yin + i) < kHeight && (xin + j) < kWidth) { | |
x[ch][i][j] = input_channel[(yin + i) * kWidth + | |
(xin + j)]; | |
} | |
else { | |
x[ch][i][j] = 0.0f; | |
} | |
} | |
} | |
// Calculates transpose(B).x.B | |
// B = [[ 1.0, 0.0, 0.0, 0.0], | |
// [ 0.0, 1.0, -1.0, 1.0], | |
// [-1.0, 1.0, 1.0, 0.0], | |
// [ 0.0, 0.0, 0.0, -1.0]] | |
// WinogradTile T1, T2; | |
T1[ch][0][0] = x[ch][0][0] - x[ch][2][0]; | |
T1[ch][0][1] = x[ch][0][1] - x[ch][2][1]; | |
T1[ch][0][2] = x[ch][0][2] - x[ch][2][2]; | |
T1[ch][0][3] = x[ch][0][3] - x[ch][2][3]; | |
T1[ch][1][0] = x[ch][1][0] + x[ch][2][0]; | |
T1[ch][1][1] = x[ch][1][1] + x[ch][2][1]; | |
T1[ch][1][2] = x[ch][1][2] + x[ch][2][2]; | |
T1[ch][1][3] = x[ch][1][3] + x[ch][2][3]; | |
T1[ch][2][0] = x[ch][2][0] - x[ch][1][0]; | |
T1[ch][2][1] = x[ch][2][1] - x[ch][1][1]; | |
T1[ch][2][2] = x[ch][2][2] - x[ch][1][2]; | |
T1[ch][2][3] = x[ch][2][3] - x[ch][1][3]; | |
T1[ch][3][0] = x[ch][1][0] - x[ch][3][0]; | |
T1[ch][3][1] = x[ch][1][1] - x[ch][3][1]; | |
T1[ch][3][2] = x[ch][1][2] - x[ch][3][2]; | |
T1[ch][3][3] = x[ch][1][3] - x[ch][3][3]; | |
} | |
const auto V_incr = channels * kTiles * batch_size; | |
const size_t channel = channel_long; | |
float *V_channel = V_batch + channel; | |
float *wTile_V = | |
V_channel + channels * (block_y * kWtiles + block_x); | |
#define M(a0, a1, op, b1, b2) \ | |
do { \ | |
const size_t idx = channel_step; \ | |
for (size_t i = 0; i < idx; ++i) { \ | |
wTile_V[i] = T1[i][a0][a1] op T1[i][b1][b2]; \ | |
}; \ | |
wTile_V += V_incr; \ | |
} while (0) | |
M(0, 0, -, 0, 2); | |
M(0, 1, +, 0, 2); | |
M(0, 2, -, 0, 1); | |
M(0, 1, -, 0, 3); | |
M(1, 0, -, 1, 2); | |
M(1, 1, +, 1, 2); | |
M(1, 2, -, 1, 1); | |
M(1, 1, -, 1, 3); | |
M(2, 0, -, 2, 2); | |
M(2, 1, +, 2, 2); | |
M(2, 2, -, 2, 1); | |
M(2, 1, -, 2, 3); | |
M(3, 0, -, 3, 2); | |
M(3, 1, +, 3, 2); | |
M(3, 2, -, 3, 1); | |
M(3, 1, -, 3, 3); | |
} | |
} | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
That's very smart indeed!