Skip to content

Instantly share code, notes, and snippets.

@ra1u
Created July 14, 2018 14:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ra1u/e3a5f7b6708cf3534b0918866eac55f3 to your computer and use it in GitHub Desktop.
Save ra1u/e3a5f7b6708cf3534b0918866eac55f3 to your computer and use it in GitHub Desktop.
/**
 * Winograd input transform: computes transpose(B).x.B for every 4x4 input
 * tile of every channel of every batch element.
 *
 * Input element (batch b, channel c, row y, column x) is read from
 *   input[(b * channels + c) * kWidth * kHeight + y * kWidth + x].
 * Tiles start at (2 * block_y - 1, 2 * block_x - 1), so neighbouring tiles
 * overlap by 2 and positions outside the board are read as 0.
 *
 * Coefficient k = 4 * row + col of the transformed tile is written to
 *   output[k * (channels * kTiles * batch_size)
 *          + b * channels * kTiles
 *          + (block_y * kWtiles + block_x) * channels
 *          + c].
 *
 * Channels are processed in groups of up to Par so that the final store
 * loops run over contiguous floats. The last group is shortened when
 * `channels` is not a multiple of Par, so no channel beyond `channels` is
 * ever read or written.
 *
 * @param batch_size number of batch elements in `input`
 * @param input      source planes, laid out as described above
 * @param channels   number of channels per batch element
 * @param output     destination buffer, laid out as described above
 */
void TransformIn(const size_t batch_size, const float *input,
                 const size_t channels, float *output)
{
    // Maximum number of channels transformed together.
    static const size_t Par = 16;

    // First-stage result transpose(B).x for each channel of the group.
    float T1[Par][kWinogradAlpha][kWinogradAlpha];

    // Distance, in floats, between consecutive coefficient planes in output.
    const size_t V_incr = channels * kTiles * batch_size;

    for (size_t batch_index = 0; batch_index < batch_size; batch_index++) {
        const float *input_batch =
            input + batch_index * kWidth * kHeight * channels;
        float *V_batch = &output[channels * kTiles * batch_index];

        for (size_t channel_long = 0; channel_long < channels;
             channel_long += Par) {
            // Clamp the final group so a channel count that is not a
            // multiple of Par does not run past the input/output buffers.
            const size_t remaining = channels - channel_long;
            const size_t channel_step = remaining < Par ? remaining : Par;

            for (int block_y = 0; block_y < kWtiles; block_y++) {
                for (int block_x = 0; block_x < kWtiles; block_x++) {
                    // Tiles overlap by 2
                    const int yin = 2 * block_y - 1;
                    const int xin = 2 * block_x - 1;

                    for (size_t ch = 0; ch < channel_step; ++ch) {
                        const float *input_channel =
                            input_batch +
                            (channel_long + ch) * (kWidth * kHeight);

                        // Gather the tile, zero-padding outside the board.
                        float x[kWinogradAlpha][kWinogradAlpha];
                        for (int i = 0; i < kWinogradAlpha; i++) {
                            for (int j = 0; j < kWinogradAlpha; j++) {
                                const int y_pos = yin + i;
                                const int x_pos = xin + j;
                                if (y_pos >= 0 && x_pos >= 0 &&
                                    y_pos < kHeight && x_pos < kWidth) {
                                    x[i][j] =
                                        input_channel[y_pos * kWidth + x_pos];
                                } else {
                                    x[i][j] = 0.0f;
                                }
                            }
                        }

                        // Calculates transpose(B).x.B
                        // B = [[ 1.0,  0.0,  0.0,  0.0],
                        //      [ 0.0,  1.0, -1.0,  1.0],
                        //      [-1.0,  1.0,  1.0,  0.0],
                        //      [ 0.0,  0.0,  0.0, -1.0]]
                        // First stage: T1 = transpose(B).x, column by column.
                        for (int j = 0; j < 4; ++j) {
                            T1[ch][0][j] = x[0][j] - x[2][j];
                            T1[ch][1][j] = x[1][j] + x[2][j];
                            T1[ch][2][j] = x[2][j] - x[1][j];
                            T1[ch][3][j] = x[1][j] - x[3][j];
                        }
                    }

                    // Second stage: T1.B. Each coefficient plane is written
                    // with its own loop over the channel group so the
                    // stores stay contiguous.
                    float *wTile_V = V_batch + channel_long +
                                     channels * (block_y * kWtiles + block_x);
                    for (int row = 0; row < 4; ++row) {
                        float *out0 = wTile_V + (size_t)(4 * row) * V_incr;
                        float *out1 = out0 + V_incr;
                        float *out2 = out1 + V_incr;
                        float *out3 = out2 + V_incr;

                        for (size_t ch = 0; ch < channel_step; ++ch) {
                            out0[ch] = T1[ch][row][0] - T1[ch][row][2];
                        }
                        for (size_t ch = 0; ch < channel_step; ++ch) {
                            out1[ch] = T1[ch][row][1] + T1[ch][row][2];
                        }
                        for (size_t ch = 0; ch < channel_step; ++ch) {
                            out2[ch] = T1[ch][row][2] - T1[ch][row][1];
                        }
                        for (size_t ch = 0; ch < channel_step; ++ch) {
                            out3[ch] = T1[ch][row][1] - T1[ch][row][3];
                        }
                    }
                }
            }
        }
    }
}
@frpays
Copy link

frpays commented Jul 14, 2018

That's very smart indeed!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment