Skip to content

Instantly share code, notes, and snippets.

@l-bat
Last active March 19, 2020 08:33
Show Gist options
  • Save l-bat/bb2f74a7452bdb0687f76ed72a58bc6b to your computer and use it in GitHub Desktop.
Save l-bat/bb2f74a7452bdb0687f76ed72a58bc6b to your computer and use it in GitHub Desktop.
// Two-input (batched) matrix multiplication on UMats.
// The last two axes of inputs[0] are viewed as an m x n matrix and the last two
// axes of inputs[1] as an n x k matrix; all leading axes are flattened into
// `rows` batches, and each batch computes C = A * B into the matching slice of
// outputs[0].
// NOTE(review): assumes both inputs have the output's rank, the same number of
// batches, and that inputs[1]'s second-to-last axis equals n -- the reshapes
// below require total() to match exactly.
if (inputs.size() == 2)
{
    const int dims = outputs[0].dims;
    const int m = inputs[0].size[dims - 2];
    const int n = inputs[0].size[dims - 1];
    const int k = inputs[1].size[dims - 1];
    const int rows = static_cast<int>(inputs[0].total() / (m * n));

    // View every tensor as a 2-D matrix holding one flattened batch per row.
    MatShape sh_A = shape(rows, m * n);
    MatShape sh_B = shape(rows, n * k);
    MatShape sh_C = shape(rows, m * k);
    UMat inp = inputs[0].reshape(1, static_cast<int>(sh_A.size()), &sh_A[0]);
    UMat weight = inputs[1].reshape(1, static_cast<int>(sh_B.size()), &sh_B[0]);
    UMat out = outputs[0].reshape(1, static_cast<int>(sh_C.size()), &sh_C[0]);

    // FP32 scratch buffers for the use_half path. They are hoisted out of the
    // loop so convertFp16/gemm can reuse their allocations: every batch has the
    // same shapes.
    UMat A_fp32, B_fp32, C_fp32;
    for (int i = 0; i < rows; ++i)
    {
        // Per-batch views into the reshaped tensors (headers only, no copy).
        UMat A = inp.row(i).reshape(1, m);     // m x n
        UMat B = weight.row(i).reshape(1, n);  // n x k
        UMat C = out.row(i).reshape(1, m);     // m x k
        if (use_half)
        {
            // Widen only the operands. With beta == 0 the previous contents
            // of C never contribute, so C is not converted beforehand.
            convertFp16(A, A_fp32);
            convertFp16(B, B_fp32);
            cv::gemm(A_fp32, B_fp32, 1, noArray(), 0, C_fp32);
            // Narrow only the result back into the output slice. The inputs
            // are left untouched: gemm does not modify A_fp32/B_fp32.
            convertFp16(C_fp32, C);
        }
        else
        {
            // FP32 path: C is a view of outputs[0] with the right size/type,
            // so gemm writes the product directly into the output.
            cv::gemm(A, B, 1, noArray(), 0, C);
        }
    }
    return true;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment