# Highly accurate batchnorm backward reductions (@t-vi, October 2018).
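# Background sketch (illustration only, not part of the extension below): the CUDA
# kernels carry an explicit lower-order term next to every running sum, i.e. a
# branchy TwoSum / Kahan-style compensated accumulation. The same idea in plain
# Python floats, with hypothetical helper names:
def two_sum(a, b):
    # s is the rounded sum, err the exact rounding error of a + b
    s = a + b
    err = (a - s) + b if abs(a) >= abs(b) else (b - s) + a
    return s, err

def compensated_sum(xs):
    total, lower = 0.0, 0.0
    for x in xs:
        total, err = two_sum(total, x)
        lower += err
    return total + lower

# e.g. compensated_sum([1e16, 1.0, -1e16]) recovers 1.0, where sum() returns 0.0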
csrc = """
#include <torch/extension.h>
#include <THC/THCDeviceUtils.cuh>
#include <THC/THCGeneral.h>
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
using namespace at;
#if defined(__HIP_PLATFORM_HCC__)
constexpr int WARP_SIZE = 64;
#else
constexpr int WARP_SIZE = 32;
#endif
// The maximum number of threads in a block
#if defined(__HIP_PLATFORM_HCC__)
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(__HIP_PLATFORM_HCC__)
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
for (int i = 0; i != 5; ++i) {
if (nElem <= threadSizes[i]) {
return threadSizes[i];
}
}
return MAX_BLOCK_SIZE;
}
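// add_with_lower: branchy TwoSum. Returns the rounded sum a + b and adds the
// exact rounding error of that addition to rlower, so a running (value, lower)
// pair keeps roughly twice the working precision (Kahan-style compensation).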
template <typename scalar_t>
__device__ scalar_t add_with_lower(scalar_t& rlower, const scalar_t& a, const scalar_t& b) {
scalar_t rupper = a + b;
if (fabs(a) >= fabs(b)) {
rlower += (a - rupper) + b;
} else {
rlower += (b - rupper) + a;
}
return rupper;
}
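// reduce_block: block-wide compensated reduction. x must point to 2*blockDim.x
// shared accumulators per row: x[0..blockSize) holds the rounded sums,
// x[blockSize..2*blockSize) the accumulated lower-order terms. On return x[0]
// holds the reduced sum and x[1] its lower-order correction (hence the *2 in
// the smem_size computations further down).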
template<typename T>
__device__ __forceinline__ void reduce_block(T *x, T val, T lower)
{
int tid = threadIdx.x;
int blockSize = blockDim.x; // blockSize must be a power of two (at least 32).
if(blockSize >= 64) {
x[tid] = val;
x[tid + blockSize] = lower;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1) {
if(tid < i) {
x[tid] = add_with_lower(x[tid + blockSize], x[tid], x[tid+i]);
// no sync needed here: within this iteration no other active thread touches
// these slots, and a __syncthreads inside this divergent branch would be undefined behaviour
x[tid + blockSize] += x[tid + blockSize + i];
}
__syncthreads();
}
if(tid < 32) {
T final;
T final_lower;
if(blockSize >= 64) {
final_lower = x[tid + blockSize] + x[tid + blockSize + 32];
// no __syncthreads here: it would be divergent (only tid < 32 reach this point),
// and the last block-wide sync already ordered these writes
final = add_with_lower(final_lower, x[tid], x[tid+32]);
} else {
final = val;
final_lower = lower;
}
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= 1; i >>= 1) {
final_lower += WARP_SHFL_DOWN(final_lower, i, 32);
final = add_with_lower(final_lower, final, WARP_SHFL_DOWN(final, i, 32));
}
if(tid == 0) {
x[0] = final; // EpilogueOp
x[1] = final_lower;
}
}
// Make sure the smem result is visible to all warps.
__syncthreads();
}
template <typename scalar_t, typename accscalar_t>
__global__ void sum_kernel_kahan_parallel(
PackedTensorAccessor<accscalar_t, 1, at::RestrictPtrTraits> sum,
const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> input) {
extern __shared__ char buf[]; // aliasing to not have type complaints from nvcc
accscalar_t* s = (accscalar_t*)buf;
const int plane = threadIdx.y + blockDim.y * blockIdx.y;
const int stride = blockDim.x;
const int offset = threadIdx.x;
accscalar_t sum_ = 0;
accscalar_t sum_lower = 0;
if (plane < input.size(1)) {
for (int64_t b = 0; b < input.size(0); b++) {
for (int64_t f = offset; f < input.size(2); f += stride) {
accscalar_t inp = input[b][plane][f];
sum_ = add_with_lower(sum_lower, inp, sum_);
}
}
}
// every thread of the block must reach the __syncthreads inside reduce_block,
// so the reduction stays outside the plane check; out-of-range rows just reduce zeros
reduce_block(s+threadIdx.y * blockDim.x * 2, sum_, sum_lower);
if (offset == 0 && plane < input.size(1)) {
sum[plane] = s[threadIdx.y * blockDim.x * 2] + s[threadIdx.y * blockDim.x * 2 + 1];
}
}
template <typename scalar_t, typename accscalar_t, bool train>
__global__ void grad_sum_and_dot_kernel_parallel(
PackedTensorAccessor<accscalar_t, 1, at::RestrictPtrTraits> sum,
PackedTensorAccessor<accscalar_t, 1, at::RestrictPtrTraits> scalar_prod,
const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> grad_out,
const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> input,
const PackedTensorAccessor<typename std::conditional<train, accscalar_t, scalar_t>::type, 1, at::RestrictPtrTraits> mean_inp) {
extern __shared__ char buf[]; // aliasing to not have type complaints from nvcc
accscalar_t* s = (accscalar_t*)buf;
const int plane = threadIdx.y + blockDim.y * blockIdx.y;
const int stride = blockDim.x;
const int offset = threadIdx.x;
accscalar_t sum_ = 0;
accscalar_t sum_lower = 0;
accscalar_t scalar_prod_ = 0;
accscalar_t scalar_prod_lower = 0;
if (plane < input.size(1)) {
accscalar_t mi = mean_inp[plane];
for (int64_t b = 0; b < input.size(0); b++) {
for (int64_t f = offset; f < input.size(2); f += stride) {
accscalar_t go = grad_out[b][plane][f];
sum_ = add_with_lower(sum_lower, go, sum_);
accscalar_t demeaned_inp_lower = 0;
accscalar_t demeaned_inp = add_with_lower(demeaned_inp_lower, static_cast<accscalar_t>(input[b][plane][f]), -mi);
accscalar_t g_dmil = go * demeaned_inp_lower;
// we skip computing the lower bits of go * demeaned_inp_lower
accscalar_t prod = go * demeaned_inp;
accscalar_t prodl = fma(go, demeaned_inp, -prod);
scalar_prod_ = add_with_lower(scalar_prod_lower, prod, scalar_prod_);
scalar_prod_ = add_with_lower(scalar_prod_lower, g_dmil, scalar_prod_);
scalar_prod_ = add_with_lower(scalar_prod_lower, prodl, scalar_prod_);
}
}
}
// as in sum_kernel_kahan_parallel, the block-wide reductions stay outside the plane
// check so that every thread reaches the __syncthreads inside reduce_block
reduce_block(s+threadIdx.y * blockDim.x * 2, sum_, sum_lower);
if (offset == 0 && plane < input.size(1)) {
sum[plane] = s[threadIdx.y * blockDim.x * 2] + s[threadIdx.y * blockDim.x * 2 + 1];
}
// make sure the store above has happened before the shared buffer is reused
__syncthreads();
reduce_block(s+threadIdx.y * blockDim.x * 2, scalar_prod_, scalar_prod_lower);
if (offset == 0 && plane < input.size(1)) {
scalar_prod[plane] = s[threadIdx.y * blockDim.x * 2] + s[threadIdx.y * blockDim.x * 2 + 1];
}
}
template <typename scalar_t>
__global__ void sum_kernel_kahan(
PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> sum,
const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> input) {
int64_t plane = threadIdx.x + blockDim.x * blockIdx.x;
scalar_t sum_ = 0;
scalar_t lower_order = 0;
if (plane < input.size(1)) {
for (int64_t b = 0; b < input.size(0); b++) {
for (int64_t f = 0; f < input.size(2); f++) {
sum_ = add_with_lower(lower_order, input[b][plane][f], sum_);
}
}
sum[plane] = sum_ + lower_order;
}
}
template <typename scalar_t>
__global__ void sum_kernel_kahan2(
PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> sum,
const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> input) {
int64_t plane = threadIdx.x + blockDim.x * blockIdx.x;
scalar_t sum_ = 0;
scalar_t lower_order = 0;
int64_t BLOCKS = 512;
if (plane < input.size(1)) {
for (int64_t b = 0; b < input.size(0); b++) {
for (int64_t block = 0; block < BLOCKS; block++) {
scalar_t sum_local = 0;
scalar_t lower_order_local = 0;
for (int64_t f = block; f < input.size(2); f+=BLOCKS) {
sum_local = add_with_lower(lower_order_local, input[b][plane][f], sum_local);
}
sum_ = add_with_lower(lower_order, sum_local, sum_);
lower_order += lower_order_local;
}
}
sum[plane] = sum_ + lower_order;
}
}
template <typename scalar_t>
__global__ void scalar_product_kernel(
PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> sum,
const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> grad_out,
const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> inp,
const PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> mean_inp
) {
int64_t plane = threadIdx.x + blockDim.x * blockIdx.x;
scalar_t res = 0;
scalar_t res_lower = 0;
if (plane < grad_out.size(1)) {
scalar_t m = mean_inp[plane];
for (int64_t b = 0; b < grad_out.size(0); b++) {
for (int64_t f = 0; f < grad_out.size(2); f++) {
scalar_t rm_lower = 0;
scalar_t rm = add_with_lower(rm_lower, inp[b][plane][f], -m);
scalar_t l = grad_out[b][plane][f];
scalar_t lrml = l * rm_lower;
// we skip computing the lower bits of l * rm_lower
scalar_t prod = l * rm;
scalar_t prodl = fma(l, rm, -prod);
res = add_with_lower(res_lower, prod, res);
res = add_with_lower(res_lower, lrml, res);
res = add_with_lower(res_lower, prodl, res);
}
}
sum[plane] = res + res_lower;
}
}
template<typename scalar_t>
Tensor sum_template(const Tensor& input) {
Tensor sum = empty({input.size(1)}, input.options());
using accscalar_t = acc_type<scalar_t, true>;
constexpr int MAX_THREADS = 512;
int feature_threads = 32; // reduce_block assumes blockDim.x is a power of two, capped at MAX_THREADS
while (feature_threads < input.size(2) && feature_threads < MAX_THREADS) feature_threads *= 2;
int plane_threads = std::max<int>(1, MAX_THREADS/feature_threads);
int smem_size = sizeof(accscalar_t)*plane_threads*feature_threads*2;
dim3 threads(feature_threads, plane_threads);
dim3 blocks(1, (input.size(1)+plane_threads-1)/plane_threads);
sum_kernel_kahan_parallel<scalar_t, accscalar_t><<<blocks, threads, smem_size>>>(sum.packed_accessor<accscalar_t, 1, at::RestrictPtrTraits>(),
input.packed_accessor<scalar_t, 3, at::RestrictPtrTraits>());
return sum;
}
template<typename scalar_t>
std::tuple<Tensor,Tensor> sum_and_scalar_prod_template(const Tensor& grad_out, const Tensor& input, const Tensor& mean_inp) {
Tensor sum = at::zeros({input.size(1)}, input.options());
Tensor scalar_prod = at::empty({input.size(1)}, input.options());
using accscalar_t = acc_type<scalar_t, true>;
constexpr int MAX_THREADS = 512;
int feature_threads = 32; // reduce_block assumes blockDim.x is a power of two, capped at MAX_THREADS
while (feature_threads < input.size(2) && feature_threads < MAX_THREADS) feature_threads *= 2;
int plane_threads = std::max<int>(1, MAX_THREADS/feature_threads);
int smem_size = sizeof(accscalar_t)*plane_threads*feature_threads*2;
dim3 threads(feature_threads, plane_threads);
dim3 blocks(1, (input.size(1)+plane_threads-1)/plane_threads);
grad_sum_and_dot_kernel_parallel<scalar_t, accscalar_t,true><<<blocks, threads, smem_size>>>(
sum.packed_accessor<accscalar_t, 1, at::RestrictPtrTraits>(),
scalar_prod.packed_accessor<accscalar_t, 1, at::RestrictPtrTraits>(),
grad_out.packed_accessor<scalar_t, 3, at::RestrictPtrTraits>(),
input.packed_accessor<scalar_t, 3, at::RestrictPtrTraits>(),
mean_inp.packed_accessor<accscalar_t, 1, at::RestrictPtrTraits>()
);
return std::make_tuple(sum, scalar_prod);
}
template<typename scalar_t>
Tensor scalar_product_template(const Tensor& l, const Tensor& r, const Tensor& mean) {
Tensor res = empty({l.size(1)}, l.options());
dim3 threads(std::min<int>(l.size(1), 512));
dim3 blocks(std::max<int>(1, (l.size(1)+511)/ 512));
scalar_product_kernel<scalar_t><<<blocks, threads>>>(res.packed_accessor<scalar_t, 1, at::RestrictPtrTraits>(),
l.packed_accessor<scalar_t, 3, at::RestrictPtrTraits>(),
r.packed_accessor<scalar_t, 3, at::RestrictPtrTraits>(),
mean.packed_accessor<scalar_t, 1, at::RestrictPtrTraits>()
);
return res;
}
// TensorAccessor in which the last dimensions are collapsed or expanded as needed
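// e.g. a contiguous (N, C, H, W) input is presented as (N, C, H*W); a 1-d tensor is
// returned as-is; a 2-d (N, C) tensor first becomes (N, C, 1) and is then swapped
// below to (1, C, N) so that the long dimension ends up last.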
template <typename scalar_t, int64_t dim>
static PackedTensorAccessor<scalar_t, dim, at::RestrictPtrTraits> reshaped_packed_accessor(const Tensor& t) {
// undefined...
if (! t.defined()) {
const std::vector<int64_t> zeros(dim);
return PackedTensorAccessor<scalar_t, dim, at::RestrictPtrTraits>(nullptr, zeros.data(), zeros.data());
}
int64_t in_dim = t.dim();
if (in_dim == dim) {
return t.packed_accessor<scalar_t, dim, at::RestrictPtrTraits>();
}
AT_CHECK(in_dim < dim || t.is_contiguous(), "need contiguous or <= 3d tensor");
std::vector<int64_t> sizes(dim);
std::vector<int64_t> strides(dim);
for (int i = 0; i < in_dim || i < dim; ++i) {
if (i < dim && i < in_dim) {
sizes[i] = t.size(i);
strides[i] = t.stride(i);
} else if (i < dim) {
sizes[i] = 1;
strides[i] = 0;
} else {
sizes[dim - 1] *= t.size(i);
strides[dim -1] = 1;
}
}
// evil trick to get adjusted 2d tensors to have large dimension last
if (dim == 3 && sizes[0] > sizes[2]) {
std::swap(sizes[0], sizes[2]);
std::swap(strides[0], strides[2]);
}
return PackedTensorAccessor<scalar_t, dim, at::RestrictPtrTraits>(t.data<scalar_t>(), sizes.data(), strides.data());
}
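// Given the per-plane reductions sum_dy = sum(grad_out) and
// dot_p = sum(grad_out * (input - mean)), this kernel applies the usual
// batch-norm backward formulas:
//   training: grad_input = (grad_out - dot_p * (input - mean) * invstd^2 / N - sum_dy / N) * invstd * weight
//   eval:     grad_input = grad_out * invstd * weight
//   grad_weight = dot_p * invstd,  grad_bias = sum_dy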
template <typename scalar_t, typename accscalar_t, bool train>
__global__ void batch_norm_backward_gradient_kernel(
const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> input,
const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> grad_output,
PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> grad_input,
PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> grad_weight,
PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> grad_bias,
const PackedTensorAccessor<accscalar_t, 1, at::RestrictPtrTraits> grad_out_sum,
const PackedTensorAccessor<accscalar_t, 1, at::RestrictPtrTraits> grad_out_dot_demeaned_input,
const PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> weight,
const PackedTensorAccessor<typename std::conditional<train, accscalar_t, scalar_t>::type, 1, at::RestrictPtrTraits> mean_,
const PackedTensorAccessor<typename std::conditional<train, accscalar_t, scalar_t>::type, 1, at::RestrictPtrTraits> var_or_invstd,
accscalar_t epsilon) {
int plane = blockIdx.y * blockDim.y + threadIdx.y;
if (plane >= input.size(1)) {
return;
}
int N = grad_output.size(0) * grad_output.size(2);
accscalar_t mean = static_cast<accscalar_t>(mean_[plane]);
accscalar_t invstd;
if (train) {
invstd = var_or_invstd[plane];
} else {
invstd = static_cast<accscalar_t>(1) / std::sqrt(static_cast<accscalar_t>(var_or_invstd[plane]) + epsilon);
}
accscalar_t weight_val = weight.size(0) > 0 ? static_cast<accscalar_t>(weight[plane]) : accscalar_t(1);
accscalar_t norm = accscalar_t(1) / N;
accscalar_t grad_output_sum = grad_out_sum[plane];
accscalar_t dot_p = grad_out_dot_demeaned_input[plane];
accscalar_t grad_mean = grad_output_sum * norm;
accscalar_t proj_scale = dot_p * norm * invstd * invstd;
accscalar_t grad_scale = invstd * weight_val;
if (grad_input.data() != NULL) {
for (int64_t batch = blockIdx.x; batch < input.size(0); batch += gridDim.x) {
for (int64_t feature = blockIdx.z; feature < input.size(2); feature += gridDim.z) {
scalar_t go = grad_output[batch][plane][feature];
if (train) {
scalar_t inp = input[batch][plane][feature];
accscalar_t proj = (inp - mean) * proj_scale;
grad_input[batch][plane][feature] = static_cast<scalar_t>((go - proj - grad_mean) * grad_scale);
} else {
grad_input[batch][plane][feature] = static_cast<scalar_t>(go * grad_scale);
}
}
}
}
if (grad_weight.size(0) > 0) {
if (threadIdx.x == 0) {
grad_weight[plane] = static_cast<scalar_t>(dot_p * invstd);
}
}
if (grad_bias.size(0) > 0) {
if (threadIdx.x == 0) {
grad_bias[plane] = static_cast<scalar_t>(grad_output_sum);
}
}
}
template<typename scalar_t>
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(
const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_,
const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_,
bool train, double epsilon, std::array<bool,3> grad_input_mask) {
using accscalar_t = at::acc_type<scalar_t, true>;
Tensor grad_input_;
Tensor grad_weight_;
Tensor grad_bias_;
auto input_options = input_.options();
if (grad_input_mask[0]) {
grad_input_ = at::empty_like(input_);
}
if (grad_input_mask[1]) {
grad_weight_ = at::empty(input_.size(1), input_options);
}
if (grad_input_mask[2]) {
grad_bias_ = at::empty(input_.size(1), input_options);
}
if (input_options.dtype() == ScalarType::Half) {
input_options = input_options.dtype(ScalarType::Float);
}
Tensor grad_out_sum_ = at::empty(input_.size(1), input_options);
Tensor grad_out_dot_demeaned_input_ = at::empty(input_.size(1), input_options);
auto grad_output = reshaped_packed_accessor<scalar_t, 3>(grad_out_);
auto input = reshaped_packed_accessor<scalar_t, 3>(input_);
auto grad_input = reshaped_packed_accessor<scalar_t, 3>(grad_input_);
auto weight = reshaped_packed_accessor<scalar_t, 1>(weight_);
auto grad_weight = reshaped_packed_accessor<scalar_t, 1>(grad_weight_);
auto grad_bias = reshaped_packed_accessor<scalar_t, 1>(grad_bias_);
auto running_mean = reshaped_packed_accessor<scalar_t, 1>(running_mean_);
auto running_var = reshaped_packed_accessor<scalar_t, 1>( running_var_);
auto save_mean = reshaped_packed_accessor<accscalar_t, 1>(save_mean_);
auto save_invstd = reshaped_packed_accessor<accscalar_t, 1>(save_invstd_);
auto grad_out_sum = reshaped_packed_accessor<accscalar_t, 1>(grad_out_sum_);
auto grad_out_dot_demeaned_input = reshaped_packed_accessor<accscalar_t, 1>(grad_out_dot_demeaned_input_);
auto stream = at::cuda::getCurrentCUDAStream();
constexpr int MAX_THREADS = 512;
int feature_threads = 32; // reduce_block assumes blockDim.x is a power of two, capped at MAX_THREADS
while (feature_threads < input.size(2) && feature_threads < MAX_THREADS) feature_threads *= 2;
int plane_threads = std::max<int>(1, MAX_THREADS/feature_threads);
int smem_size = sizeof(accscalar_t)*plane_threads*feature_threads*2;
dim3 threads_red(feature_threads, plane_threads);
dim3 blocks_red(1, (input.size(1)+plane_threads-1)/plane_threads);
if (train) {
grad_sum_and_dot_kernel_parallel<scalar_t, accscalar_t, true><<<blocks_red, threads_red, smem_size, stream>>>(
grad_out_sum, grad_out_dot_demeaned_input,
grad_output, input, save_mean);
} else {
grad_sum_and_dot_kernel_parallel<scalar_t, accscalar_t, false><<<blocks_red, threads_red, smem_size, stream>>>(
grad_out_sum, grad_out_dot_demeaned_input,
grad_output, input, running_mean);
}
{
constexpr int max_blocks_per_input = 60000;
int feature_blocks = std::min<int>(input.size(2), max_blocks_per_input);
int batch_blocks = std::min<int>(input.size(0), max_blocks_per_input / feature_blocks);
dim3 blocks(batch_blocks, (input.size(1)+127)/128, feature_blocks);
dim3 threads(1, 128);
if (train) {
batch_norm_backward_gradient_kernel<scalar_t, accscalar_t, true> <<<blocks, threads, 0, stream>>>
(input, grad_output, grad_input, grad_weight, grad_bias, grad_out_sum, grad_out_dot_demeaned_input,
weight, save_mean, save_invstd, epsilon);
} else {
batch_norm_backward_gradient_kernel<scalar_t, accscalar_t, false> <<<blocks, threads, 0, stream>>>
(input, grad_output, grad_input, grad_weight, grad_bias, grad_out_sum, grad_out_dot_demeaned_input,
weight, running_mean, running_var, epsilon);
}
}
THCudaCheck(cudaGetLastError());
return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
}
Tensor sum_cuda(const Tensor& input) {
return AT_DISPATCH_FLOATING_TYPES(input.type(), "sum_cuda", [&] {
return sum_template<scalar_t>(input);
});
}
std::tuple<Tensor,Tensor> sum_and_scalar_prod_cuda(const Tensor& grad_out, const Tensor& input, const Tensor& mean_inp) {
return AT_DISPATCH_FLOATING_TYPES(input.type(), "sum_and_scalar_prod", [&] {
return sum_and_scalar_prod_template<scalar_t>(grad_out, input, mean_inp);
});
}
Tensor scalar_product_cuda(const Tensor& l, const Tensor& r, const Tensor& mean) {
return AT_DISPATCH_FLOATING_TYPES(l.type(), "scalar_product_cuda", [&] {
return scalar_product_template<scalar_t>(l, r, mean);
});
}
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& self, const Tensor& weight, const Tensor& running_mean, const Tensor& running_var,
const Tensor& save_mean, const Tensor& save_invstd, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward", [&] {
return batch_norm_backward_cuda_template<scalar_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask);
});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sum_cuda", &sum_cuda, "blubb");
m.def("scalar_product", &scalar_product_cuda, "blubb");
m.def("sum_and_scalar_prod", &sum_and_scalar_prod_cuda, "blubb");
m.def("batch_norm_backward", &batch_norm_backward_cuda, "blubb");
}
"""
import torch
import torch.utils.cpp_extension
name = "test"
sum_ext = torch.utils.cpp_extension.load_inline(name, [], cuda_sources=[csrc], verbose=True)
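# The tensors used below (grad_o, inp, weight, the running stats and the saved batch
# statistics sm3/sis3) were defined elsewhere in the original session. This is a
# minimal stand-in setup with arbitrary shapes, assuming sm3/sis3 are the per-channel
# batch mean and inverse std used in training mode:
torch.manual_seed(0)
inp = torch.randn(8, 32, 16, 16, device="cuda")
grad_o = torch.randn_like(inp)
weight = torch.randn(32, device="cuda")
running_mean = torch.zeros(32, device="cuda")
running_var = torch.ones(32, device="cuda")
flat = inp.double().transpose(0, 1).reshape(32, -1)
sm3 = flat.mean(1).float()
sis3 = (flat.var(1, unbiased=False) + 1e-5).rsqrt().float()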
grads3 = sum_ext.batch_norm_backward(grad_o, inp, weight, running_mean, running_var, sm3, sis3, True, 1e-5, [True, True, True])
grads4 = sum_ext.batch_norm_backward(grad_o.double(), inp.double(), weight.double(), running_mean.double(), running_var.double(), sm3.double(), sis3.double(), True, 1e-5, [True, True, True])
for g1, g2 in zip(grads3, grads4):
    print((g1 - g2.float()).abs().max().item())
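# Optional cross-check against PyTorch's own batch-norm backward via autograd
# (a rough sketch; the zero bias only exists so that autograd also returns a
# bias gradient to compare against grad_bias):
inp_ref = inp.double().detach().requires_grad_()
weight_ref = weight.double().detach().requires_grad_()
bias_ref = torch.zeros_like(weight_ref, requires_grad=True)
out = torch.nn.functional.batch_norm(
    inp_ref, running_mean.double(), running_var.double(),
    weight_ref, bias_ref, training=True, eps=1e-5)
ref_grads = torch.autograd.grad(out, (inp_ref, weight_ref, bias_ref), grad_o.double())
for g_ext, g_ref in zip(grads4, ref_grads):
    print((g_ext - g_ref).abs().max().item())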