Skip to content

Instantly share code, notes, and snippets.

@YashasSamaga
Last active June 14, 2020 08:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YashasSamaga/8ad0cd3b30dbd0eb588c1f4c035db28c to your computer and use it in GitHub Desktop.
Save YashasSamaga/8ad0cd3b30dbd0eb588c1f4c035db28c to your computer and use it in GitHub Desktop.
Performance comparison of different mish implementations
#include "mish.hpp"
#include <cuda_runtime.h>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>
// Applies `Activation` element-wise: output[i] = Activation{}(input[i]) for
// every i in [0, n).
//
// Launch layout: 1D grid of 1D blocks (only the .x components are used). The
// grid-stride loop makes the kernel correct for any grid/block size.
//
// Activation must be default-constructible in device code and provide
// `__device__ float operator()(float)`.
//
// output/input must each hold at least n floats and must not alias
// (both are declared __restrict__).
template <class Activation>
__global__ void activate_vec1(float* __restrict__ output, const float* __restrict__ input, int n)
{
    Activation activation;

    // 64-bit index and stride: with an `int` index, `i += stride` can wrap to a
    // negative value when n is close to INT_MAX, which would still satisfy
    // `i < n` and index out of bounds.
    const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
    const long long first = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;

    for (long long i = first; i < n; i += stride)
    {
        output[i] = activation(input[i]);
    }
}
// Applies `Activation` to both lanes of every float2 element:
// output[i] = { Activation{}(input[i].x), Activation{}(input[i].y) } for every
// i in [0, n).
//
// `n` is the number of float2 elements (not the number of floats).
//
// Launch layout: 1D grid of 1D blocks (only the .x components are used). The
// grid-stride loop makes the kernel correct for any grid/block size.
//
// Activation must be default-constructible in device code and provide
// `__device__ float operator()(float)`.
//
// output/input must each hold at least n float2 values and must not alias
// (both are declared __restrict__).
template <class Activation>
__global__ void activate_vec2(float2* __restrict__ output, const float2* __restrict__ input, int n)
{
    Activation activation;

    // 64-bit index and stride: with an `int` index, `i += stride` can wrap to a
    // negative value when n is close to INT_MAX, which would still satisfy
    // `i < n` and index out of bounds.
    const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
    const long long first = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;

    for (long long i = first; i < n; i += stride)
    {
        // One vectorized load, two activations in registers, one vectorized store.
        float2 temp = input[i];
        temp.x = activation(temp.x);
        temp.y = activation(temp.y);
        output[i] = temp;
    }
}
// Applies `Activation` to all four lanes of every float4 element:
// each of w, x, y, z of input[i] is passed through Activation{} and the
// resulting float4 is written to output[i], for every i in [0, n).
//
// `n` is the number of float4 elements (not the number of floats).
//
// Launch layout: 1D grid of 1D blocks (only the .x components are used). The
// grid-stride loop makes the kernel correct for any grid/block size.
//
// Activation must be default-constructible in device code and provide
// `__device__ float operator()(float)`.
//
// output/input must each hold at least n float4 values and must not alias
// (both are declared __restrict__).
template <class Activation>
__global__ void activate_vec4(float4* __restrict__ output, const float4* __restrict__ input, int n)
{
    Activation activation;

    // 64-bit index and stride: with an `int` index, `i += stride` can wrap to a
    // negative value when n is close to INT_MAX, which would still satisfy
    // `i < n` and index out of bounds.
    const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
    const long long first = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;

    for (long long i = first; i < n; i += stride)
    {
        // One vectorized load, four activations in registers, one vectorized store.
        float4 temp = input[i];
        temp.w = activation(temp.w);
        temp.x = activation(temp.x);
        temp.y = activation(temp.y);
        temp.z = activation(temp.z);
        output[i] = temp;
    }
}
// Evaluates a CUDA runtime call and terminates the program with a diagnostic
// (file, line, error string) if it did not return cudaSuccess. Without this,
// a failed allocation/copy/launch would silently produce garbage norms.
#define CUDA_CHECK(call)                                                       \
    do {                                                                       \
        const cudaError_t cuda_status_ = (call);                               \
        if (cuda_status_ != cudaSuccess) {                                     \
            std::cerr << "CUDA error at " << __FILE__ << ':' << __LINE__       \
                      << ": " << cudaGetErrorString(cuda_status_) << '\n';     \
            std::exit(EXIT_FAILURE);                                           \
        }                                                                      \
    } while (0)

// Launch configuration used for every kernel in this benchmark.
static constexpr int kBlocks = 10;
static constexpr int kThreadsPerBlock = 1024;

// Non-owning bundle of the buffers shared by all benchmark runs.
struct BenchBuffers
{
    float* input_d;          // device input, n floats
    float* output_d;         // device output, n floats
    const float* input_h;    // host copy of the input, n floats
    float* output_h;         // host staging buffer for results, n floats
    const float* output_ref; // host reference results, n floats
    int n;                   // number of floats in every buffer
};

// Returns sqrt(sum_i (y[i] - x[i])^2), accumulated in double precision.
static double l2norm(const float* x, const float* y, int n)
{
    double sqr_sum = 0.0;
    for (int i = 0; i < n; i++)
    {
        const double diff = y[i] - x[i];
        sqr_sum += diff * diff;
    }
    return std::sqrt(sqr_sum);
}

// Copies the host input to the device input buffer.
static void upload_input(const BenchBuffers& b)
{
    CUDA_CHECK(cudaMemcpy(b.input_d, b.input_h, static_cast<size_t>(b.n) * sizeof(float), cudaMemcpyHostToDevice));
}

// Waits for the device, fetches the output and prints `label` followed by the
// L2 distance between the fetched output and the reference.
static void report(const BenchBuffers& b, const char* label)
{
    CUDA_CHECK(cudaDeviceSynchronize());
    CUDA_CHECK(cudaMemcpy(b.output_h, b.output_d, static_cast<size_t>(b.n) * sizeof(float), cudaMemcpyDeviceToHost));
    std::cout << label << l2norm(b.output_ref, b.output_h, b.n) << '\n';
}

// Re-uploads the input and launches the scalar kernel with `Activation`.
template <class Activation>
static void launch_vec1(const BenchBuffers& b)
{
    upload_input(b);
    activate_vec1<Activation><<<kBlocks, kThreadsPerBlock>>>(b.output_d, b.input_d, b.n);
    CUDA_CHECK(cudaGetLastError()); // launch-configuration errors are not reported by the launch itself
}

// Re-uploads the input and launches the float4 kernel with `Activation`.
// Requires b.n to be a multiple of 4 and the device buffers to be 16-byte aligned.
template <class Activation>
static void launch_vec4(const BenchBuffers& b)
{
    upload_input(b);
    activate_vec4<Activation><<<kBlocks, kThreadsPerBlock>>>(
        reinterpret_cast<float4*>(b.output_d),
        reinterpret_cast<const float4*>(b.input_d),
        b.n / 4);
    CUDA_CHECK(cudaGetLastError()); // launch-configuration errors are not reported by the launch itself
}

// Runs one scalar benchmark and prints its error against the reference.
template <class Activation>
static void bench_vec1(const BenchBuffers& b, const char* label)
{
    launch_vec1<Activation>(b);
    report(b, label);
}

// Runs one float4 benchmark and prints its error against the reference.
template <class Activation>
static void bench_vec4(const BenchBuffers& b, const char* label)
{
    launch_vec4<Activation>(b);
    report(b, label);
}

// Runs every mish functor through the scalar and the float4 kernel on
// 16M uniformly distributed inputs in [-50, 50) and prints, per variant, the
// L2 distance to a double-precision host reference of x * tanh(log1p(exp(x))).
int main ()
{
    constexpr int N = 1024 * 1024 * 16;
    static_assert(N % 4 == 0, "N must be divisible by 4 for the float4 kernels");

    float* input = nullptr;
    float* output = nullptr;
    CUDA_CHECK(cudaMalloc(&input, N * sizeof(float)));
    CUDA_CHECK(cudaMalloc(&output, N * sizeof(float)));

    // Host buffers are owned by vectors so they are released on every exit path.
    std::vector<float> input_h(N);
    std::vector<float> output_h(N);
    std::vector<float> output_ref(N);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> dis(-50, 50);
    for (int i = 0; i < N; i++)
    {
        // The reference is evaluated in double and rounded to float on store.
        double x = dis(gen);
        input_h[i] = x;
        output_ref[i] = x * std::tanh(std::log1p(std::exp(x)));
    }

    const BenchBuffers buffers{input, output, input_h.data(), output_h.data(), output_ref.data(), N};

    // vec1
    launch_vec1<functors::relu>(buffers); // untimed/unreported first launch, result is discarded
    bench_vec1<functors::mish_tb>(buffers, "[vec1] mish_tb: ");
    bench_vec1<functors::mish_rw>(buffers, "[vec1] mish_rw: ");
    bench_vec1<functors::mish_njuffa1>(buffers, "[vec1] mish_njuffa1: ");
    bench_vec1<functors::mish_njuffa2>(buffers, "[vec1] mish_njuffa2: ");
    bench_vec1<functors::mish_njuffa3>(buffers, "[vec1] mish_njuffa3: ");
    bench_vec1<functors::mish_aj1>(buffers, "[vec1] mish_aj1: ");
    bench_vec1<functors::mish_aj2>(buffers, "[vec1] mish_aj2: ");
    bench_vec1<functors::mish_aj2_fastdiv>(buffers, "[vec1] mish_aj2_fastdiv: ");
    bench_vec1<functors::mish_dlib>(buffers, "[vec1] mish_dlib: ");
    bench_vec1<functors::mish_ocv>(buffers, "[vec1] mish_ocv: ");

    // vec4
    launch_vec4<functors::relu>(buffers); // untimed/unreported first launch, result is discarded
    bench_vec4<functors::mish_tb>(buffers, "[vec4] mish_tb: ");
    bench_vec4<functors::mish_rw>(buffers, "[vec4] mish_rw: ");
    bench_vec4<functors::mish_njuffa1>(buffers, "[vec4] mish_njuffa1: ");
    bench_vec4<functors::mish_njuffa2>(buffers, "[vec4] mish_njuffa2: ");
    bench_vec4<functors::mish_njuffa3>(buffers, "[vec4] mish_njuffa3: ");
    bench_vec4<functors::mish_aj1>(buffers, "[vec4] mish_aj1: ");
    bench_vec4<functors::mish_aj2>(buffers, "[vec4] mish_aj2: ");
    bench_vec4<functors::mish_aj2_fastdiv>(buffers, "[vec4] mish_aj2_fastdiv: ");
    bench_vec4<functors::mish_dlib>(buffers, "[vec4] mish_dlib: ");
    bench_vec4<functors::mish_ocv>(buffers, "[vec4] mish_ocv: ");

    CUDA_CHECK(cudaDeviceSynchronize());
    CUDA_CHECK(cudaFree(input));
    CUDA_CHECK(cudaFree(output));
    return 0;
}
#pragma once
#include <cuda_runtime.h>
namespace functors
{
// Rectified linear unit: returns the larger of x and zero.
struct relu
{
    __device__ float operator()(float x)
    {
        // CUDA's float overload of max() forwards to fmaxf, so this is the
        // same operation spelled explicitly.
        return fmaxf(x, 0.0f);
    }
};
// Mish computed directly as x * tanh(softplus(x)), where softplus(x) is
// log1p(exp(x)) below the threshold 20 and x itself otherwise.
struct mish_tb
{
    __device__ float operator()(float x)
    {
        // Start from the large-x shortcut and refine only below the threshold.
        float softplus = x;
        if (x < 20)
        {
            softplus = log1pf(expf(x));
        }
        return x * tanhf(softplus);
    }
};
// Mish computed as x * tanh(softplus(x)) with a three-range softplus.
struct mish_rw
{
    // softplus(x) = log(exp(x) + 1), with shortcuts outside [-20, 20]:
    //   x >  20 -> x
    //   x < -20 -> exp(x)
    __device__ float softplus(float x)
    {
        const float limit = 20;
        if (x > limit)
        {
            return x; // too large
        }
        if (x < -limit)
        {
            return expf(x); // too small
        }
        // Single-precision log, same as the float overload of log().
        return logf(expf(x) + 1.0f);
    }

    __device__ float operator()(float x)
    {
        const float sp = softplus(x);
        return x * tanhf(sp);
    }
};
// Mish without tanh/log: with e = exp(x),
//   tanh(log(1 + e)) = 1 + 1 / (-0.5*e*e - e - 1)
// so the result is x + x / (-0.5*e*e - e - 1), evaluated with fused
// multiply-adds.
struct mish_njuffa1
{
    __device__ float operator()(float x)
    {
        const float e = expf (x);
        // denom = (-0.5*e - 1)*e - 1 = -0.5*e*e - e - 1
        const float denom = fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
        const float inv = 1.0f / denom;
        // inv*x + x
        return fmaf (inv, x, x);
    }
};
// Piecewise mish evaluation.
//
//   x >= -1 : closed form x + x / (-0.5*e*e - e - 1) with e = expf(x), which
//             equals x * tanh(log(1 + e)).
//   x <  -1 : x * exp(x) * p(x) / q(x), where p and q are degree-6
//             polynomials in x; for x < -15 the ratio p/q is replaced by 1.
//
// NOTE(review): the polynomial coefficients are taken as given; their
// derivation and accuracy are not visible in this file.
struct mish_njuffa2
{
__device__ float operator()(float x)
{
float r;
if (x >= -1.0f) {
float e = expf (x);
// (-0.5*e - 1)*e - 1 = -0.5*e*e - e - 1
r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
// r*x + x
r = fmaf (r, x, x);
} else {
// exp(x) is applied as eh * eh at the end rather than as a single expf(x).
// NOTE(review): presumably this delays underflow for very negative x - confirm.
float eh = expf (0.5f * x);
// Numerator polynomial p(x), evaluated by Horner's scheme with fmaf.
float p = 1.03628484e-3f; // 0x1.0fa7e6p-10
p = fmaf (p, x, -7.28869531e-3f); // -0x1.ddac04p-8
p = fmaf (p, x, 3.47027816e-2f); // 0x1.1c4902p-5
p = fmaf (p, x, -3.54762226e-1f); // -0x1.6b46cap-2
p = fmaf (p, x, 8.58785570e-1f); // 0x1.b7b2bep-1
p = fmaf (p, x, -1.38065982e+0f); // -0x1.6172ecp+0
p = fmaf (p, x, 5.97694337e-1f); // 0x1.3204fep-1
// Denominator polynomial q(x), evaluated by Horner's scheme with fmaf.
float q = 1.03527203e-3f; // 0x1.0f63eep-10
q = fmaf (q, x, -7.35638570e-3f); // -0x1.e21bacp-8
q = fmaf (q, x, 3.28683928e-2f); // 0x1.0d4204p-5
q = fmaf (q, x, -3.79927397e-1f); // -0x1.850bb0p-2
q = fmaf (q, x, 6.86127126e-1f); // 0x1.5f4c0ep-1
q = fmaf (q, x, -1.81509292e+0f); // -0x1.d0a9eep+0
q = fmaf (q, x, 1.00000000e+0f); // 0x1.000000p+0
// Rational correction factor p/q.
r = (1.0f / q) * p;
// Far tail: drop the correction, leaving x * eh * eh.
if (x < -15.0f) r = 1.0f;
// Evaluated left to right: ((r * x) * eh) * eh.
r = r * x * eh * eh;
}
return r;
}
};
// Two-range mish evaluation with e = exp(x):
//   x >= -6.0625 : x + x / (-0.5*e*e - e - 1)
//   otherwise    : (1 - 0.5*e) * x * e
struct mish_njuffa3
{
    __device__ float operator()(float x)
    {
        const float e = expf (x);
        if (x >= -6.0625f)
        {
            // denom = (-0.5*e - 1)*e - 1
            const float denom = fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
            const float inv = 1.0f / denom;
            return fmaf (inv, x, x);
        }
        // Tail branch: factor = -0.5*e + 1, result = (factor * x) * e.
        const float factor = fmaf (-0.5f, e, 1.0f);
        return factor * x * e;
    }
};
// Mish written as x / (1 + 2 / psi) with psi = e * (2 + e) and e = exp(x),
// using the fast-math __expf intrinsic for the exponential.
struct mish_aj1
{
    __device__ float operator()(float x)
    {
        const float e = __expf(x);
        const float psi = e * (2.0f + e);
        return x / (1.0f + 2.0f / psi);
    }
};
// Mish written as x * psi / (psi + 2) with psi = e * (2 + e) and e = exp(x),
// using the fast-math __expf intrinsic for the exponential.
struct mish_aj2
{
    __device__ float operator()(float x)
    {
        // Clamp the exponent argument. Unclamped, psi overflows to +inf once
        // exp(x)^2 exceeds FLT_MAX (x above roughly 44.4) and psi / (2 + psi)
        // becomes inf / inf = NaN. At x = 20 the ratio already rounds to
        // exactly 1.0f (2 is far below one ulp of psi), so every x >= 20 yields
        // x, which is what mish converges to. NaN inputs still propagate
        // through the final multiplication by x.
        float expx = __expf(fminf(x, 20.0f));
        float psi = expx * (2.0f + expx);
        return x * (psi / (2.0f + psi));
    }
};
// Mish written as x * psi / (psi + 2) with psi = e * (2 + e) and e = exp(x),
// using the fast-math __expf and __fdividef intrinsics.
struct mish_aj2_fastdiv
{
    __device__ float operator()(float x)
    {
        // Clamp the exponent argument. Unclamped there are two failure ranges:
        //  * __fdividef returns 0 when the divisor magnitude lies in
        //    (2^126, 2^128), so x around 43.7..44.4 would produce 0;
        //  * above that psi overflows to +inf and inf / inf is NaN.
        // At x = 20 psi is about 2.4e17, far below 2^126, and psi / (2 + psi)
        // is 1 to within the intrinsic's accuracy, so x >= 20 yields x (up to
        // that rounding), which is what mish converges to. NaN inputs still
        // propagate through the final multiplication by x.
        float expx = __expf(fminf(x, 20.0f));
        float psi = expx * (2.0f + expx);
        return x * (__fdividef(psi, (2.0f + psi)));
    }
};
// Mish written as x - 2x / (e*e + 2e + 2) with e = exp(x).
struct mish_dlib
{
    __device__ float operator()(float x)
    {
        const auto ex = std::exp(x);
        // Denominator e^2 + 2e + 2, kept in the original summation order.
        const auto denom = 2 * ex + ex * ex + 2;
        return x - 2 * x/denom;
    }
};
// Mish with n = e*e + 2e and e = exp(value), using the fast-math __expf and
// __fdividef intrinsics. Two algebraically equivalent forms are selected by
// the input:
//   value <= -0.6 : value * n / (n + 2)
//   otherwise     : value - 2 * value / (n + 2)
struct mish_ocv
{
    __device__ float operator()(float value)
    {
        auto e = __expf(value);
        auto n = e * e + 2 * e;
        const auto denom = n + 2;
        if (value <= -0.6f)
        {
            return value * __fdividef(n, denom);
        }
        else
        {
            return value - 2 * __fdividef(value, denom);
        }
    }
};
} /* namespace functors */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment