Skip to content

Instantly share code, notes, and snippets.

@YashasSamaga
Created June 16, 2020 13:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YashasSamaga/c3ee66732ff3c2b07cd48ea5bd7fb4e1 to your computer and use it in GitHub Desktop.
#include <cuda_runtime.h>
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>
// Derivative of ReLU used as a cheap baseline functor:
// 1 for strictly positive inputs, 0 otherwise (x == 0 and NaN both give 0).
struct relu_grad
{
    __device__ float operator()(float x) { return (x > 0) ? 1.0f : 0.0f; }
};
// Mish gradient, "dn" variant (labelled "darknet" by the plotting script).
// d/dx [x * tanh(softplus(x))] = tanh(sp) + x * (1 - tanh(sp)^2) * sp'(x),
// with sp'(x) computed as -expm1(-sp), i.e. 1 - e^{-sp}.
struct mish_grad_dn
{
    // Thresholded softplus log(1 + e^x): returns x itself above +threshold,
    // e^x below -threshold, and log1pf(expf(x)) in between.
    __device__ float softplus_kernel(float x, float threshold = 20)
    {
        return (x > threshold)  ? x
             : (x < -threshold) ? expf(x)
                                : log1pf(expf(x));
    }

    __device__ float operator()(float x)
    {
        const float MISH_THRESHOLD = 20.0f;
        const float softplus = softplus_kernel(x, MISH_THRESHOLD);
        const float d_softplus = -expm1f(-softplus);
        const float tanh_sp = tanh(softplus);
        const float d_tanh_sp = (1 - tanh_sp*tanh_sp) * d_softplus;
        return x * d_tanh_sp + tanh_sp;
    }
};
// Mish gradient, "tb" variant: softplus is evaluated as log1p(expf(x)) only
// for x < 20 (otherwise softplus(x) is taken to be x), and its derivative is
// formed directly as 1 - exp(-softplus).
struct mish_grad_tb
{
    __device__ float operator()(float x)
    {
        const float THRESHOLD = 20.0f;
        float softplus = x;
        if (x < THRESHOLD)
            softplus = log1p(expf(x));
        const float d_softplus = 1 - exp(-softplus);
        const float tanh_sp = tanh(softplus);
        const float d_tanh_sp = (1 - tanh_sp*tanh_sp) * d_softplus;
        return x * d_tanh_sp + tanh_sp;
    }
};
// Same as the "tb" Mish gradient, except that the softplus derivative
// 1 - exp(-softplus) is computed as -expm1(-softplus) to avoid cancellation.
struct mish_grad_tb_expm1
{
    __device__ float operator()(float x)
    {
        const float THRESHOLD = 20.0f;
        float softplus = x;
        if (x < THRESHOLD)
            softplus = log1p(expf(x));
        const float d_softplus = -expm1(-softplus);
        const float tanh_sp = tanh(softplus);
        const float d_tanh_sp = (1 - tanh_sp*tanh_sp) * d_softplus;
        return x * d_tanh_sp + tanh_sp;
    }
};
// Fast Mish gradient built on the __expf / __fdividef intrinsics.
// With e = e^x and w = e^2 + 2e = (1 + e)^2 - 1:
//   tanh(softplus(x)) = w / (w + 2) = 1 - 2 / (w + 2)
//   softplus'(x)      = e / (e + 1)
// The two algebraically equal tanh forms are selected at x = -0.6
// (presumably for accuracy). Inputs above 10.5 return exactly 1; NOTE(review):
// presumably because the float gradient already rounds to 1 there.
struct mish_grad_fast
{
    __device__ float operator()(float x)
    {
        const float e = __expf(x);
        const float w = e * e + 2 * e;
        const float tanh_sp = (x <= -0.6f) ? __fdividef(w, w + 2)
                                           : 1 - __fdividef(2, w + 2);
        const float d_softplus = __fdividef(e, e + 1);
        const float d_tanh_sp = (1 - tanh_sp*tanh_sp) * d_softplus;
        const float result = x * d_tanh_sp + tanh_sp;
        if (x > 10.5f)
            return 1.0f;
        return result;
    }
};
// Mish gradient evaluated in double precision on the device and rounded to
// float on return; used as the most accurate device-side variant.
struct mish_grad_double
{
    __device__ float operator()(float x)
    {
        // Promote BEFORE calling the math functions. CUDA C++ provides float
        // overloads of exp/log1p, so log1p(exp(x)) with a float x previously ran
        // entirely in single precision, defeating the point of this variant.
        const double xd = x;
        const double sp = log1p(exp(xd));
        const double grad_sp = -expm1(-sp);          // softplus'(x) = 1 - e^{-sp}
        const double tsp = tanh(sp);
        const double grad_tsp = (1 - tsp*tsp) * grad_sp;
        const double grad = xd * grad_tsp + tsp;
        return static_cast<float>(grad);
    }
};
// Scales the incoming gradient in place by an elementwise activation gradient:
//   dz[i] *= GradientFunc()(input[i])   for every i in [0, n).
// Uses a grid-stride loop, so any grid/block configuration covers all n elements.
template <class GradientFunc>
__global__ void grad_vec1(float* __restrict__ dz, const float* __restrict__ input, int n)
{
    GradientFunc activation_grad;
    const unsigned int stride = blockDim.x * gridDim.x;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    while (idx < n)
    {
        dz[idx] *= activation_grad(input[idx]);
        idx += stride;
    }
}
// float4-vectorised version of grad_vec1: n counts float4 elements, and each
// of the four lanes of dz[i] is multiplied by GradientFunc() of the matching
// lane of input[i]. Grid-stride loop; both pointers must be 16-byte aligned.
template <class GradientFunc>
__global__ void grad_vec4(float4* __restrict__ dz, const float4* __restrict__ input, int n)
{
    GradientFunc activation_grad;
    const unsigned int stride = blockDim.x * gridDim.x;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    while (idx < n)
    {
        const float4 in = input[idx];
        float4 out = dz[idx];
        out.x *= activation_grad(in.x);
        out.y *= activation_grad(in.y);
        out.z *= activation_grad(in.z);
        out.w *= activation_grad(in.w);
        dz[idx] = out;
        idx += stride;
    }
}
// Memory-traffic baseline: dz[i] += input[i], i.e. two loads and one store per
// element with trivial arithmetic (the name presumably means "2 loads, 1 store").
// Grid-stride loop over n floats.
__global__ void limit_2L1S_v1(float * __restrict__ dz, const float * __restrict__ input, int n)
{
    const unsigned int stride = blockDim.x * gridDim.x;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    while (idx < n)
    {
        dz[idx] += input[idx];
        idx += stride;
    }
}
// float4 version of limit_2L1S_v1: n counts float4 elements and every lane of
// dz[i] is incremented by the matching lane of input[i]. Grid-stride loop.
__global__ void limit_2L1S_v4(float4 * __restrict__ dz, const float4 * __restrict__ input, int n)
{
    const unsigned int stride = blockDim.x * gridDim.x;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    while (idx < n)
    {
        const float4 addend = input[idx];
        float4 acc = dz[idx];
        acc.x += addend.x;
        acc.y += addend.y;
        acc.z += addend.z;
        acc.w += addend.w;
        dz[idx] = acc;
        idx += stride;
    }
}
// Prints "x grad(x)" pairs, one per line, for x = -100 + k * 1e-4 with
// k = 0 .. 1199999 (i.e. x in [-100, 20)), for offline accuracy analysis.
// Intended for a <<<1, 1>>> launch: every launched thread prints the full table.
// About 1.2M lines are produced. That is far more than the default 1 MB device
// printf FIFO, so the host should raise cudaLimitPrintfFifoSize before launching.
template <class GradientFunc>
__global__ void dump()
{
    GradientFunc grad;
    constexpr int num_samples = 1200000;   // (20 - (-100)) / 1e-4
    for (int k = 0; k < num_samples; k++)
    {
        // Derive each sample from an integer counter. Repeatedly adding 0.0001
        // to a float accumulated rounding error (a float ulp near -100 is about
        // 7.6e-6), so the sample grid drifted.
        const float x = static_cast<float>(-100.0 + k * 1e-4);
        // 9 significant digits round-trip a float exactly. The reader therefore
        // re-evaluates its reference at precisely the x that was sampled.
        printf("%.9g %.8e\n", x, grad(x));
    }
}
// Aborts the program with a diagnostic if a CUDA runtime call did not succeed.
static void check_cuda(cudaError_t status, const char* what)
{
    if (status != cudaSuccess)
    {
        std::fprintf(stderr, "CUDA error (%s): %s\n", what, cudaGetErrorString(status));
        std::exit(EXIT_FAILURE);
    }
}

// Driver with two modes.
// - dump_only == true: print mish_grad_tb samples via dump<> for the plotting
//   script, then exit.
// - Otherwise: run every kernel variant on N inputs drawn uniformly from
//   [-50, 50) with an upstream gradient of 1. For each Mish variant, report the
//   L2 and L-infinity distance to a long double host reference.
int main ()
{
    // Toggle between the two modes (true reproduces the original early return).
    constexpr bool dump_only = true;
    if (dump_only)
    {
        // dump<> emits about 1.2M printf records. The device printf FIFO
        // defaults to 1 MB and is circular, so without enlarging it most
        // samples are overwritten before the host prints them.
        constexpr std::size_t printf_fifo_bytes = std::size_t(256) << 20;
        check_cuda(cudaDeviceSetLimit(cudaLimitPrintfFifoSize, printf_fifo_bytes), "cudaDeviceSetLimit");
        dump<mish_grad_tb><<<1, 1>>>();
        check_cuda(cudaGetLastError(), "dump launch");
        check_cuda(cudaDeviceSynchronize(), "dump");
        return 0;
    }

    constexpr int N = 1024 * 1024 * 16;
    static_assert(N % 4 == 0, "the float4 kernels process N / 4 vectors");
    constexpr std::size_t bytes = std::size_t(N) * sizeof(float);

    // std::vector releases the host buffers automatically (they previously leaked).
    std::vector<float> input_activation_h(N);
    std::vector<float> grad_h(N);
    std::vector<float> output_h(N);
    std::vector<float> output_ref(N);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> in_dis(-50, 50);
    for (int i = 0; i < N; i++)
    {
        // Host reference for dz * mish'(a), evaluated in long double.
        long double a = in_dis(gen);
        input_activation_h[i] = a;
        long double dy = 1.0;
        grad_h[i] = dy;
        const long double sp = std::log1p(std::exp(a));
        const long double grad_sp = -std::expm1(-sp);
        const long double tsp = std::tanh(sp);
        const long double grad_tsp = (1 - tsp * tsp) * grad_sp;
        const long double grad = a * grad_tsp + tsp;
        output_ref[i] = dy * grad;
    }

    // L-infinity norm of (y - x).
    auto lInorm = [] (const float* x, const float* y, int n) {
        float max = 0;
        for (int i = 0; i < n; i++)
            max = std::max(max, std::abs(y[i] - x[i]));
        return max;
    };
    // L2 norm of (y - x), accumulated in double.
    auto l2norm = [] (const float* x, const float* y, int n) {
        std::vector<double> diff(n);
        for (int i = 0; i < n; i++)
            diff[i] = y[i] - x[i];
        auto sqr_sum = std::accumulate(std::begin(diff), std::end(diff), 0.0, [](auto lhs, auto rhs) { return lhs + rhs * rhs; });
        return std::sqrt(sqr_sum);
    };

    float* input_activation = nullptr;
    float* grad = nullptr;
    check_cuda(cudaMalloc(&input_activation, bytes), "cudaMalloc(input_activation)");
    check_cuda(cudaMalloc(&grad, bytes), "cudaMalloc(grad)");
    auto grad4 = reinterpret_cast<float4*>(grad);
    auto input_activation4 = reinterpret_cast<float4*>(input_activation);

    // No kernel writes the activations, so they are uploaded once; only dz is
    // reset before every run.
    check_cuda(cudaMemcpy(input_activation, input_activation_h.data(), bytes, cudaMemcpyHostToDevice), "upload input_activation");

    // Resets dz, runs one kernel launch and waits for it, checking for errors.
    auto run = [&] (const char* label, auto launch) {
        check_cuda(cudaMemcpy(grad, grad_h.data(), bytes, cudaMemcpyHostToDevice), label);
        launch();
        check_cuda(cudaGetLastError(), label);
        check_cuda(cudaDeviceSynchronize(), label);
    };
    // As run(), then downloads dz and prints "<label>: <L2 error> <Linf error>".
    auto run_and_report = [&] (const char* label, auto launch) {
        run(label, launch);
        check_cuda(cudaMemcpy(output_h.data(), grad, bytes, cudaMemcpyDeviceToHost), label);
        std::cout << label << ": " << l2norm(output_ref.data(), output_h.data(), N)
                  << ' ' << lInorm(output_ref.data(), output_h.data(), N) << '\n';
    };

    // Baselines whose results are not checked (presumably timed externally with a profiler).
    run("[vec1] relu_grad", [&] { grad_vec1<relu_grad><<<10, 1024>>>(grad, input_activation, N); });
    run("[vec4] relu_grad", [&] { grad_vec4<relu_grad><<<10, 1024>>>(grad4, input_activation4, N / 4); });
    run("[vec1] limit_2L1S", [&] { limit_2L1S_v1<<<10, 1024>>>(grad, input_activation, N); });
    run("[vec4] limit_2L1S", [&] { limit_2L1S_v4<<<10, 1024>>>(grad4, input_activation4, N / 4); });

    // Mish variants, checked against the host reference. The vec4 runs were
    // previously mislabelled "[vec1]".
    run_and_report("[vec1] mish_grad_dn", [&] { grad_vec1<mish_grad_dn><<<10, 1024>>>(grad, input_activation, N); });
    run_and_report("[vec4] mish_grad_dn", [&] { grad_vec4<mish_grad_dn><<<10, 1024>>>(grad4, input_activation4, N / 4); });
    run_and_report("[vec1] mish_grad_tb", [&] { grad_vec1<mish_grad_tb><<<10, 1024>>>(grad, input_activation, N); });
    run_and_report("[vec4] mish_grad_tb", [&] { grad_vec4<mish_grad_tb><<<10, 1024>>>(grad4, input_activation4, N / 4); });
    run_and_report("[vec1] mish_grad_tb_expm1", [&] { grad_vec1<mish_grad_tb_expm1><<<10, 1024>>>(grad, input_activation, N); });
    run_and_report("[vec4] mish_grad_tb_expm1", [&] { grad_vec4<mish_grad_tb_expm1><<<10, 1024>>>(grad4, input_activation4, N / 4); });
    run_and_report("[vec1] mish_grad_fast", [&] { grad_vec1<mish_grad_fast><<<10, 1024>>>(grad, input_activation, N); });
    run_and_report("[vec4] mish_grad_fast", [&] { grad_vec4<mish_grad_fast><<<10, 1024>>>(grad4, input_activation4, N / 4); });
    run_and_report("[vec1] mish_grad_double", [&] { grad_vec1<mish_grad_double><<<10, 1024>>>(grad, input_activation, N); });
    run_and_report("[vec4] mish_grad_double", [&] { grad_vec4<mish_grad_double><<<10, 1024>>>(grad4, input_activation4, N / 4); });

    check_cuda(cudaFree(grad), "cudaFree(grad)");
    check_cuda(cudaFree(input_activation), "cudaFree(input_activation)");
    return 0;
}
import numpy as np
# generate_stats() summarises the samples in consecutive windows of this many points.
SMOOTHING_STEP_SIZE = 1000
# Only samples with LEFT_X_CUTOFF < x < RIGHT_X_CUTOFF are analysed.
LEFT_X_CUTOFF, RIGHT_X_CUTOFF = -100, 100
def ref_mish(x):
    """Reference Mish activation: x * tanh(softplus(x)), with softplus(x) = log(1 + e^x).

    Works elementwise on NumPy scalars/arrays and in whatever precision ``x`` has.
    """
    softplus = np.log1p(np.exp(x))
    return x * np.tanh(softplus)
def ref_grad(x):
    """Reference derivative of Mish, evaluated in the precision of ``x``.

    d/dx [x * tanh(sp(x))] = tanh(sp) + x * (1 - tanh(sp)^2) * sp'(x), where
    sp(x) = log(1 + e^x) and sp'(x) = 1 - e^{-sp} (computed via expm1).
    """
    softplus = np.log1p(np.exp(x))
    d_softplus = -np.expm1(-softplus)
    tanh_sp = np.tanh(softplus)
    d_tanh_sp = (1 - tanh_sp * tanh_sp) * d_softplus
    return x * d_tanh_sp + tanh_sp
def generate_stats(src):
    """Load "x grad(x)" samples from ``src`` and compute windowed error statistics.

    Each non-empty line of ``src`` must contain two whitespace-separated floats:
    a sample point x and the gradient produced by the implementation under test.
    Only samples with LEFT_X_CUTOFF < x < RIGHT_X_CUTOFF are kept. Each kept
    sample is compared against ``ref_grad`` evaluated in extended precision
    (``np.longdouble``).

    Returns ``(x_final, rel_error_log10_final, abs_diff_err_final)``. For each
    consecutive window of SMOOTHING_STEP_SIZE samples these hold the mean x,
    the maximum log10 relative error and the maximum absolute error.
    Samples that do not fill a complete trailing window are discarded.

    Raises ValueError if a non-empty line does not hold exactly two floats.
    """
    x_list = []
    y_list = []
    with open(src, "r") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            x, y = [float(field) for field in fields]
            if LEFT_X_CUTOFF < x < RIGHT_X_CUTOFF:
                x_list.append(x)
                y_list.append(y)

    rel_error_log10 = []
    abs_diff_err = []
    for x, y in zip(x_list, y_list):
        # np.longdouble is the portable spelling of extended precision.
        # np.float128 does not exist on every platform.
        x_ext = np.longdouble(x)
        y_ext = np.longdouble(y)
        ref = ref_grad(x_ext)
        diff = np.abs(y_ext - ref)
        # Exact matches get -inf, so they never raise a window's maximum.
        # Use np.inf: the np.Inf alias was removed in NumPy 2.0.
        rerr = -np.inf if diff == 0 else np.log10(np.abs(diff / ref))
        rel_error_log10.append(float(rerr))
        abs_diff_err.append(float(diff))

    # Smoothing: summarise every full window of SMOOTHING_STEP_SIZE samples.
    x_final = []
    rel_error_log10_final = []
    abs_diff_err_final = []
    for step in range(len(x_list) // SMOOTHING_STEP_SIZE):
        ibegin = step * SMOOTHING_STEP_SIZE
        iend = ibegin + SMOOTHING_STEP_SIZE
        x_final.append(np.mean(x_list[ibegin:iend]))
        rel_error_log10_final.append(np.max(rel_error_log10[ibegin:iend]))
        abs_diff_err_final.append(np.max(abs_diff_err[ibegin:iend]))
    return x_final, rel_error_log10_final, abs_diff_err_final
# Compare the accuracy of three device implementations of the Mish gradient.
# Each dump file is expected to hold the stdout of the CUDA program's dump<>
# kernel. Assumed mapping (TODO confirm): fast -> mish_grad_fast,
# dn -> mish_grad_dn, tb -> mish_grad_tb.
# Entries are (dump file, legend label, line colour), in plotting order.
SERIES = [
    ("dump_fast_grad", "fast grad", 'g'),
    ("dump_dn_grad", "darknet", 'r'),
    ("dump_tb_grad", "tb", 'b'),
]
stats = {label: generate_stats(src) for src, label, _ in SERIES}

import matplotlib.pyplot as plt
linewidth = 0.5
fig, ax = plt.subplots(1, 3)
for _, label, colour in SERIES:
    xs, rel_err_log10, abs_err = stats[label]
    ax[0].plot(xs, rel_err_log10, linewidth = linewidth, c = colour, label = label)
    ax[1].plot(xs, abs_err, linewidth = linewidth, c = colour, label = label)
    ax[2].plot(xs, [np.log10(a) for a in abs_err], linewidth = linewidth, c = colour, label = label)
for axis, title in zip(ax, ("relative error (log10)", "abs(diff)", "log10(abs(diff))")):
    axis.set_title(title)
    axis.legend()

# Print the worst-case errors, and the x where they occur, BEFORE plt.show().
# plt.show() blocks until the figure window is closed.
for label in ("darknet", "tb"):
    xs, rel_err_log10, abs_err = stats[label]
    print(np.max(rel_err_log10), np.max(abs_err))
    print(xs[np.argmax(rel_err_log10)], xs[np.argmax(abs_err)])
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment