pytorch/torch/nn/modules/batchnorm.py
# TODO: check contiguous in THNN
# TODO: use separate backend functions?
class _BatchNorm(Module):
_version = 2
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True):
super(_BatchNorm, self).__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
if self.track_running_stats:
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
else:
self.register_parameter('running_mean', None)
self.register_parameter('running_var', None)
self.register_parameter('num_batches_tracked', None)
self.reset_parameters()
def reset_running_stats(self):
if self.track_running_stats:
self.running_mean.zero_()
self.running_var.fill_(1)
self.num_batches_tracked.zero_()
def reset_parameters(self):
self.reset_running_stats()
if self.affine:
self.weight.data.uniform_()
self.bias.data.zero_()
def _check_input_dim(self, input):
raise NotImplementedError
def forward(self, input):
self._check_input_dim(input)
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / self.num_batches_tracked.item()
else: # use exponential moving average
exponential_average_factor = self.momentum
return F.batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training or not self.track_running_stats,
exponential_average_factor, self.eps)
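Either way, the factor feeds the same update rule for the running statistics (applied in place inside the THNN kernel shown at the end of this walkthrough):

$$\hat{x} \leftarrow (1 - \alpha)\,\hat{x} + \alpha\, x_{\text{batch}}$$

where $\alpha$ is exponential_average_factor. With momentum=None we get $\alpha = 1/t$ after $t$ batches, i.e. a cumulative moving average; with a fixed momentum we get an exponential moving average.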
NOTE: forward calls nn.functional.batch_norm, shown next.
def batch_norm(input, running_mean, running_var, weight=None, bias=None,
training=False, momentum=0.1, eps=1e-5):
r"""Applies Batch Normalization for each channel across a batch of data.
See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
:class:`~torch.nn.BatchNorm3d` for details.
"""
if training:
size = list(input.size())
if reduce(mul, size[2:], size[0]) == 1:
raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
return torch.batch_norm(
input, weight, bias, running_mean, running_var,
training, momentum, eps, torch.backends.cudnn.enabled
)
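For example, a training-mode input of size (1, C, 1, 1) trips this check: reduce(mul, size[2:], size[0]) computes size[0] times the product of the spatial dims, which is 1 here, i.e. a single value per channel, for which a batch statistic is meaningless.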
NOTE: F.batch_norm calls torch.batch_norm, the ATen entry point we pick up below.
pytorch/torch/csrc/api/include/torch/nn/module.h
protected:
/// Registers a parameter with this `Module`.
Tensor& register_parameter(
std::string name,
Tensor tensor,
bool requires_grad = true);
/// Registers a buffer with this `Module`.
Tensor& register_buffer(std::string name, Tensor tensor);
NOTE: requires_grad of register_parameter defaults to true.
private:
template <typename T>
using OrderedDict = torch::detail::OrderedDict<std::string, T>;
OrderedDict<Tensor> parameters_;
OrderedDict<Tensor> buffers_;
OrderedDict<std::shared_ptr<Module>> children_;
/// The module's name (e.g. "LSTM").
mutable at::optional<std::string> name_;
/// Whether the module is in training mode.
bool is_training_{true};
NOTE: is_training_ defaults to true.
cf. pytorch/torch/csrc/api/include/torch/detail/ordered_dict.h
/// A simple ordered dictionary implementation, akin to Python's `OrderedDict`.
template <typename Key, typename Value>
class OrderedDict {
public:
As a result, Module::parameters_ and Module::buffers_ are torch::detail::OrderedDict<std::string, Tensor>, and Module::children_ is torch::detail::OrderedDict<std::string, std::shared_ptr<Module>>.
pytorch/torch/csrc/api/src/nn/module.cpp
ParameterCursor Module::parameters() {
return ParameterCursor(*this);
}
ConstParameterCursor Module::parameters() const {
return ConstParameterCursor(*this);
}
BufferCursor Module::buffers() {
return BufferCursor(*this);
}
ConstBufferCursor Module::buffers() const {
return ConstBufferCursor(*this);
}
void Module::train() {
for (auto& child : children_) {
child.value->train();
}
is_training_ = true;
}
void Module::eval() {
for (auto& child : children_) {
child.value->eval();
}
is_training_ = false;
}
bool Module::is_training() const noexcept {
return is_training_;
}
void Module::zero_grad() {
for (auto& child : children_) {
child.value->zero_grad();
}
for (auto& parameter : parameters_) {
auto& grad = parameter->grad();
if (grad.defined()) {
grad = grad.detach();
grad.zero_();
}
}
}
NOTE: Module::parameters() and Module::buffers() return cursors, i.e. iterables over the registered tensors. Module::train() and Module::eval() set is_training_ on the module itself and on all of its children, to true or false respectively.
Tensor& Module::register_parameter(
std::string name,
Tensor tensor,
bool requires_grad) {
tensor.set_requires_grad(requires_grad);
return parameters_.insert(std::move(name), std::move(tensor));
}
Tensor& Module::register_buffer(std::string name, Tensor tensor) {
return buffers_.insert(std::move(name), std::move(tensor));
}
NOTE: requires_grad of Module::register_parameter defaults to true in the header. Module::register_parameter inserts a Tensor with requires_grad_ = true into Module::parameters_; Module::register_buffer inserts a Tensor into Module::buffers_ without touching requires_grad, so it keeps its default of false.
#include <torch/csrc/utils/variadic.h>
#include <torch/tensor.h>
#include <memory>
#include <type_traits>
#include <utility>
#define TORCH_ARG(T, name) \
auto name(const T& new_##name)->decltype(*this) { \
this->name##_ = new_##name; \
return *this; \
} \
auto name(T&& new_##name)->decltype(*this) { \
this->name##_ = std::move(new_##name); \
return *this; \
} \
const T& name() const noexcept { \
return this->name##_; \
} \
T name##_
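To make the macro concrete, here is the mechanical expansion of TORCH_ARG(double, eps) = 1e-5 from BatchNormOptions below: a pair of chainable setters, a const getter, and the backing member eps_, with the trailing = 1e-5 initializing that member.

// Expansion of TORCH_ARG(double, eps) = 1e-5;
auto eps(const double& new_eps) -> decltype(*this) {
  this->eps_ = new_eps;
  return *this;   // returning *this makes setters chainable
}
auto eps(double&& new_eps) -> decltype(*this) {
  this->eps_ = std::move(new_eps);
  return *this;
}
const double& eps() const noexcept {
  return this->eps_;
}
double eps_ = 1e-5;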
/// Defines a class `Name` which inherits from `nn::ModuleHolder` to provide a
/// wrapper over a `std::shared_ptr<Impl>`.
#define TORCH_MODULE_IMPL(Name, Impl) \
class Name : public torch::nn::ModuleHolder<Impl> { \
public: \
using torch::nn::ModuleHolder<Impl>::ModuleHolder; \
Name(const Name&) = default; \
Name(Name&&) = default; \
Name(Name& other) : Name(static_cast<const Name&>(other)) {} \
Name& operator=(const Name&) = default; \
Name& operator=(Name&&) = default; \
}
/// Like `TORCH_MODULE_IMPL`, but defaults the `Impl` name to `<Name>Impl`.
#define TORCH_MODULE(Name) TORCH_MODULE_IMPL(Name, Name##Impl)
NOTE: it is TORCH_MODULE that wraps <Name>Impl in a ModuleHolder and provides the user-facing interface.
#include <torch/nn/cloneable.h>
#include <torch/nn/pimpl.h>
#include <torch/tensor.h>
#include <cstdint>
namespace torch {
namespace nn {
struct BatchNormOptions {
/* implicit */ BatchNormOptions(int64_t features);
TORCH_ARG(int64_t, features);
TORCH_ARG(bool, affine) = true;
TORCH_ARG(bool, stateful) = false;
TORCH_ARG(double, eps) = 1e-5;
TORCH_ARG(double, momentum) = 0.1;
};
class BatchNormImpl : public torch::nn::Cloneable<BatchNormImpl> {
public:
template <typename... Ts>
explicit BatchNormImpl(Ts&&... ts)
: BatchNormImpl(BatchNormOptions(std::forward<Ts>(ts)...)) {}
explicit BatchNormImpl(BatchNormOptions options);
void reset() override;
Tensor forward(Tensor input);
Tensor pure_forward(Tensor input, Tensor mean, Tensor variance);
BatchNormOptions options;
Tensor weight;
Tensor bias;
Tensor running_mean;
Tensor running_variance;
};
TORCH_MODULE(BatchNorm);
} // namespace nn
} // namespace torch
Note: BatchNormImpl has:
- data members: options, weight, bias, running_mean and running_variance
- methods: reset, forward and pure_forward

TORCH_MODULE(BatchNorm); defines class BatchNorm, which inherits from torch::nn::ModuleHolder<BatchNormImpl>.
Let's expand the macro TORCH_MODULE:
class BatchNorm : public torch::nn::ModuleHolder<BatchNormImpl> {
public:
using torch::nn::ModuleHolder<BatchNormImpl>::ModuleHolder;
BatchNorm(const BatchNorm&) = default;
BatchNorm(BatchNorm&&) = default;
BatchNorm(BatchNorm& other) : BatchNorm(static_cast<const BatchNorm&>(other)) {}
BatchNorm& operator=(const BatchNorm&) = default;
BatchNorm& operator=(BatchNorm&&) = default;
};
#include <torch/nn/modules/batchnorm.h>
#include <torch/cuda.h>
#include <torch/tensor.h>
#include <ATen/Error.h>
#include <cstddef>
#include <utility>
#include <vector>
BatchNormOptions::BatchNormOptions(int64_t features) : features_(features) {}
BatchNormImpl::BatchNormImpl(BatchNormOptions options)
: options(std::move(options)) {
reset();
}
NOTE: the options are initialized and reset() is invoked when the module is constructed with a BatchNormOptions.
void BatchNormImpl::reset() {
if (options.affine_) {
weight = register_parameter(
"weight", torch::empty({options.features_}).uniform_());
bias = register_parameter("bias", torch::zeros({options.features_}));
}
if (options.stateful_) {
running_mean =
register_buffer("running_mean", torch::zeros({options.features_}));
running_variance =
register_buffer("running_variance", torch::ones({options.features_}));
}
}
NOTE: weight and bias are assigned only when options.affine_ is true, and running_mean and running_variance only when options.stateful_ is true; otherwise each Tensor is left default-constructed (undefined). I think affine_ is the equivalent of affine in Python, and likewise stateful_ of track_running_stats.
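As a quick orientation, constructing the module might look like this (a minimal sketch; it assumes ModuleHolder forwards its constructor arguments to BatchNormImpl, as the variadic constructor above suggests):

// A stateful, affine BatchNorm over 64 channels; the TORCH_ARG setters chain.
BatchNorm bn(BatchNormOptions(64).affine(true).stateful(true).momentum(0.1));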
Tensor BatchNormImpl::forward(Tensor input) {
return pure_forward(input, Tensor(), Tensor());
}
Tensor BatchNormImpl::pure_forward(Tensor input, Tensor mean, Tensor variance) {
auto& running_mean = options.stateful_ ? this->running_mean : mean;
auto& running_variance =
options.stateful_ ? this->running_variance : variance;
if (is_training()) {
const auto num_channels = input.dim() > 1 ? input.size(1) : 1;
AT_CHECK(
input.numel() / num_channels > 1,
"BatchNorm expected more than 1 value per channel when training!");
}
return torch::batch_norm(
input,
weight,
bias,
running_mean,
running_variance,
is_training(),
options.momentum_,
options.eps_,
torch::cuda::cudnn_is_available());
}
NOTE: forward calls pure_forward, passing default-constructed (undefined) Tensor() placeholders for mean and variance. running_mean and running_variance keep persistent state because they are bound by reference rather than copied by value. When options.stateful_ is false, the caller-supplied tensors (or the undefined Tensor() defaults) are used instead. is_training() is the equivalent of training in Python. Finally, there has to be more than one value per channel when training.
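The stateless path can be exercised directly (a hypothetical sketch; the input shape and option values are illustrative):

// With stateful_ == false the caller owns the statistics.
BatchNorm bn(BatchNormOptions(64).stateful(false));
auto input = torch::rand({8, 64, 32, 32});   // N x C x H x W
auto mean = torch::zeros({64});
auto variance = torch::ones({64});
auto output = bn->pure_forward(input, mean, variance);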
Tensor batch_norm(
const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */,
const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
bool training, double momentum, double eps, bool cudnn_enabled) {
NOTE: input, weight, bias, running_mean and running_var are all passed by const reference.
- It might look like batch_norm cannot update running_mean and running_var, so why pass momentum?
- The answer is in the THNN kernel below: a const Tensor& only pins which TensorImpl is referenced, not the underlying storage, and THNN_(BatchNormalization_updateOutput) writes the new running statistics in place during the forward pass (not in backward).
auto num_features = input.sizes()[1];
if (running_mean.defined()) {
check_dims_match_num_input_features("running_mean", num_features, running_mean.numel());
} else if (!training) {
throw std::runtime_error("running_mean must be defined in evaluation mode");
}
if (running_var.defined()) {
check_dims_match_num_input_features("running_var", num_features, running_var.numel());
} else if (!training) {
throw std::runtime_error("running_var must be defined in evaluation mode");
}
if (weight.defined()) {
check_dims_match_num_input_features("weight", num_features, weight.numel());
}
if (bias.defined()) {
check_dims_match_num_input_features("bias", num_features, bias.numel());
}
NOTE: the number of elements in running_mean, running_var, weight and bias must match the number of channels in input, i.e. its 2nd dim.
return at::thnn_batch_norm(
input.contiguous(), weight, bias,
running_mean, running_var, training, momentum, eps);
}
NOTE: from here on we only follow the CPU implementation.
static inline Tensor thnn_batch_norm( const Tensor & self,
const Tensor & weight,
const Tensor & bias,
const Tensor & running_mean,
const Tensor & running_var,
bool training,
double momentum,
double eps) {
return infer_type(self).thnn_batch_norm(self,
weight,
bias,
running_mean,
running_var,
training,
momentum, eps);
}
static inline Type & infer_type(const Tensor & t) {
AT_ASSERT(t.defined(), "undefined Tensor");
return t.type();
}
struct Tensor : public detail::TensorBase {
Tensor() : TensorBase() {}
Tensor(TensorImpl * self, bool retain) : TensorBase(self, retain) {}
Tensor(const TensorBase & rhs) : TensorBase(rhs) {}
Tensor(const Tensor & rhs) = default;
Tensor(Tensor && rhs) noexcept = default;
Type & type() const {
return pImpl->type();
}
std::unique_ptr<Storage> storage() const {
return pImpl->storage();
}
template<typename T>
T * data() const;
NOTE: the parent class TensorBase is actually TensorBaseImpl<true> (see the aliases below).
// Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which
// has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr.
//
// For example:
//
// void func(Tensor a) {
//   Tensor b = a;
//   ...
// }
//
// In this example, when we say Tensor b = a, we are creating a new object that points to the
// same underlying TensorImpl, and bumps its reference count. When b goes out of scope, the
// destructor decrements the reference count by calling release() on the TensorImpl it points to.
// The existing constructors, operator overloads, etc. take care to implement the correct semantics.
//
// Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and
// special care must be taken to handle this.
template<bool is_strong>
struct TensorBaseImpl {
TensorBaseImpl(): TensorBaseImpl(UndefinedTensor::singleton(), false) {}
TensorBaseImpl(TensorImpl * self, bool should_retain)
: pImpl(self) {
if (pImpl == nullptr) {
throw std::runtime_error("TensorBaseImpl with nullptr not supported");
}
if(should_retain && pImpl != UndefinedTensor::singleton()) {
retain();
}
}
TensorBaseImpl(const TensorBaseImpl & rhs)
: pImpl(rhs.pImpl) {
if (pImpl != UndefinedTensor::singleton()) {
retain();
}
}
TensorBaseImpl(TensorBaseImpl && rhs) noexcept
: pImpl(rhs.pImpl) {
rhs.pImpl = UndefinedTensor::singleton();
}
~TensorBaseImpl() {
if (pImpl != UndefinedTensor::singleton()) {
release();
}
}
NOTE: TensorBaseImpl is the base class for Tensor; it handles the reference counting.
TensorImpl * detach() {
TensorImpl * ret = pImpl;
pImpl = UndefinedTensor::singleton();
return ret;
}
bool defined() const {
return pImpl != UndefinedTensor::singleton();
}
friend struct Type;
//TODO(zach): sort out friend structes
public:
TensorImpl * pImpl;
NOTE: Tensor inherits pImpl, which points to the underlying TensorImpl.
using TensorBase = TensorBaseImpl<true>;
using WeakTensorBase = TensorBaseImpl<false>;
struct TensorImpl : public Retainable {
explicit TensorImpl(Type * type)
: is_scalar(false), type_(type) {}
Type & type() const {
return *type_;
}
protected:
bool is_scalar;
Type * type_;
};
NOTE: here is where type_ comes from.
struct AT_API Type {
explicit Type(Context * context)
: context(context) {}
virtual ~Type() {}
virtual Tensor thnn_batch_norm(const Tensor & self, const Tensor & weight, const Tensor & bias, const Tensor & running_mean, const Tensor & running_var, bool training, double momentum, double eps) const;
#define TH_CONCAT_STRING_2(x,y) TH_CONCAT_STRING_2_EXPAND(x,y)
#define TH_CONCAT_STRING_2_EXPAND(x,y) #x #y
#define TH_CONCAT_STRING_3(x,y,z) TH_CONCAT_STRING_3_EXPAND(x,y,z)
#define TH_CONCAT_STRING_3_EXPAND(x,y,z) #x #y #z
#define TH_CONCAT_STRING_4(x,y,z,w) TH_CONCAT_STRING_4_EXPAND(x,y,z,w)
#define TH_CONCAT_STRING_4_EXPAND(x,y,z,w) #x #y #z #w
#define TH_CONCAT_2(x,y) TH_CONCAT_2_EXPAND(x,y)
#define TH_CONCAT_2_EXPAND(x,y) x ## y
#define TH_CONCAT_3(x,y,z) TH_CONCAT_3_EXPAND(x,y,z)
#define TH_CONCAT_3_EXPAND(x,y,z) x ## y ## z
#define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w
#define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w)
#define THMin(X, Y) ((X) < (Y) ? (X) : (Y))
#define THMax(X, Y) ((X) > (Y) ? (X) : (Y))
#define THTensor_(NAME) TH_CONCAT_4(TH,Real,Tensor_,NAME)
#define THNN_(NAME) TH_CONCAT_3(THNN_, Real, NAME)
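For orientation: with Real defined as Float, THNN_(BatchNormalization_updateOutput) expands mechanically to THNN_FloatBatchNormalization_updateOutput, and THTensor_(resizeAs) to THFloatTensor_resizeAs. That is the symbol the concrete Type subclass selected by infer_type(self) (e.g. the CPU float type) ends up calling from its thnn_batch_norm override.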
- forward
void THNN_(BatchNormalization_updateOutput)(
THNNState *state, THTensor *input, THTensor *output,
THTensor *weight, THTensor *bias,
THTensor *running_mean, THTensor *running_var,
THTensor *save_mean, THTensor *save_std,
bool train, double momentum, double eps)
{
THTensor_(resizeAs)(output, input);
int64_t nInput = THTensor_(size)(input, 1);
int64_t f;
ptrdiff_t n = THTensor_(nElement)(input) / nInput;
if (train) {
THTensor_(resize1d)(save_mean, nInput);
THTensor_(resize1d)(save_std, nInput);
}
// for each feature map f
#pragma omp parallel for
for (f = 0; f < nInput; ++f) {
THTensor *in = THTensor_(newSelect)(input, 1, f);
THTensor *out = THTensor_(newSelect)(output, 1, f);
real mean, invstd;
if (train) {
// compute mean per input
accreal sum = 0;
TH_TENSOR_APPLY(real, in, sum += *in_data;);
mean = (real) sum / n;
THTensor_(set1d)(save_mean, f, (real) mean);
// compute variance per input
sum = 0;
TH_TENSOR_APPLY(real, in,
sum += (*in_data - mean) * (*in_data - mean););
if (sum == 0 && eps == 0.0) {
invstd = 0;
} else {
invstd = (real) (1 / sqrt(sum/n + eps));
}
THTensor_(set1d)(save_std, f, (real) invstd);
// update running averages
if (running_mean) {
THTensor_(set1d)(running_mean, f,
(real) (momentum * mean + (1 - momentum) * THTensor_(get1d)(running_mean, f)));
}
if (running_var) {
accreal unbiased_var = sum / (n - 1);
THTensor_(set1d)(running_var, f,
(real) (momentum * unbiased_var + (1 - momentum) * THTensor_(get1d)(running_var, f)));
}
} else {
mean = THTensor_(get1d)(running_mean, f);
invstd = 1 / sqrt(THTensor_(get1d)(running_var, f) + eps);
}
// compute output
real w = weight ? THTensor_(get1d)(weight, f) : 1;
real b = bias ? THTensor_(get1d)(bias, f) : 0;
TH_TENSOR_APPLY2(real, in, real, out,
*out_data = (real) (((*in_data - mean) * invstd) * w + b););
THTensor_(free)(out);
THTensor_(free)(in);
}
}
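In equations, for channel f with n values per channel, the training path computes

$$\mu_f = \frac{1}{n}\sum_i x_i,\qquad \mathrm{invstd}_f = \frac{1}{\sqrt{\frac{1}{n}\sum_i (x_i-\mu_f)^2 + \epsilon}},\qquad y_i = (x_i - \mu_f)\,\mathrm{invstd}_f\, w_f + b_f$$

and updates the running statistics in place:

$$\mathrm{running\_mean}_f \leftarrow m\,\mu_f + (1-m)\,\mathrm{running\_mean}_f,\qquad \mathrm{running\_var}_f \leftarrow m\,\frac{\sum_i (x_i-\mu_f)^2}{n-1} + (1-m)\,\mathrm{running\_var}_f$$

Note the asymmetry: normalization uses the biased variance (divide by n), while the running variance is updated with the unbiased estimate (divide by n − 1). This also answers the question raised earlier: the running averages are updated here, in updateOutput, during the forward pass, not in backward.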
void THNN_(BatchNormalization_backward)(
THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput,
THTensor *gradWeight, THTensor *gradBias, THTensor *weight,
THTensor *running_mean, THTensor *running_var,
THTensor *save_mean, THTensor *save_std,
bool train, double scale, double eps)
{
THNN_CHECK_SHAPE(input, gradOutput);
int64_t nInput = THTensor_(size)(input, 1);
int64_t f;
ptrdiff_t n = THTensor_(nElement)(input) / nInput;
if (gradInput) {
THTensor_(resizeAs)(gradInput, input);
}
#pragma omp parallel for
for (f = 0; f < nInput; ++f) {
THTensor *in = THTensor_(newSelect)(input, 1, f);
THTensor *gradOut = THTensor_(newSelect)(gradOutput, 1, f);
real w = weight ? THTensor_(get1d)(weight, f) : 1;
real mean, invstd;
if (train) {
mean = THTensor_(get1d)(save_mean, f);
invstd = THTensor_(get1d)(save_std, f);
} else {
mean = THTensor_(get1d)(running_mean, f);
invstd = 1 / sqrt(THTensor_(get1d)(running_var, f) + eps);
}
// sum over all gradOutput in feature plane
accreal sum = 0;
TH_TENSOR_APPLY(real, gradOut, sum += *gradOut_data;);
// dot product of the Q(X) and gradOuput
accreal dotp = 0;
TH_TENSOR_APPLY2(real, in, real, gradOut,
dotp += (*in_data - mean) * (*gradOut_data););
if (gradInput) {
THTensor *gradIn = THTensor_(newSelect)(gradInput, 1, f);
if (train) {
// when in training mode
// Q(X) = X - E[x] ; i.e. input centered to zero mean
// Y = Q(X) / σ ; i.e. BN output before weight and bias
// dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / σ * w
// projection of gradOutput on to output scaled by std
real k = (real) dotp * invstd * invstd / n;
TH_TENSOR_APPLY2(real, gradIn, real, in,
*gradIn_data = (*in_data - mean) * k;);
accreal gradMean = sum / n;
TH_TENSOR_APPLY2(real, gradIn, real, gradOut,
*gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * invstd * w;);
} else {
// when in evaluation mode
// Q(X) = X - running_mean ; i.e. input centered to zero mean
// Y = Q(X) / running_std ; i.e. BN output before weight and bias
// dL/dX = w / running_std
TH_TENSOR_APPLY2(real, gradIn, real, gradOut,
*gradIn_data = *gradOut_data * invstd * w;);
}
THTensor_(free)(gradIn);
}
if (gradWeight) {
real val = THTensor_(get1d)(gradWeight, f);
THTensor_(set1d)(gradWeight, f, val + scale * dotp * invstd);
}
if (gradBias) {
real val = THTensor_(get1d)(gradBias, f);
THTensor_(set1d)(gradBias, f, val + scale * sum);
}
THTensor_(free)(gradOut);
THTensor_(free)(in);
}
}
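Reading the training branch back into equations, with $g_i = \partial L/\partial y_i$, $\bar g = \frac{1}{n}\sum_i g_i$ (gradMean in the code), and $\mathrm{dotp} = \sum_i (x_i - \mu)\,g_i$:

$$\frac{\partial L}{\partial x_i} = \Big(g_i - \bar g - (x_i-\mu)\,\frac{\mathrm{dotp}\cdot \mathrm{invstd}^2}{n}\Big)\,\mathrm{invstd}\cdot w$$

In evaluation mode the statistics are constants, so the centering terms vanish: $\partial L/\partial x_i = g_i \cdot \mathrm{invstd} \cdot w$. The parameter gradients are accumulated as

$$\frac{\partial L}{\partial w} \mathrel{+}= \mathrm{scale}\cdot \mathrm{dotp}\cdot \mathrm{invstd},\qquad \frac{\partial L}{\partial b} \mathrel{+}= \mathrm{scale}\cdot \sum_i g_i$$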
- As @moskomule pointed out, you have to specify how many feature channels your input will have (because that's the number of BatchNorm parameters). Batch and spatial dimensions don't matter.
- BatchNorm will only update the running averages in train mode, so if you want the model to keep updating them at test time, you will have to keep the BatchNorm modules in training mode. See the C implementation for details (it should be readable).
- About ReLU and MaxPool: if you think about it for a moment, ReLU + MaxPool and MaxPool + ReLU are equivalent operations, with the second option being 37.5% more efficient (numel + numel element-wise operations in the first case vs. numel + numel/4 in the second, where numel is the number of elements in the tensor). That's why the example has a different order.
- https://discuss.pytorch.org/t/example-on-how-to-use-batch-norm/216/
  Q: At test time, I would like to freeze both the weights (lambda and beta), as well as freeze the running averages that it has computed. (Ostensibly because it has a good estimate for those from training already.) So I basically expect that I would want all 4 of those values frozen.
  A: Yeah, in that case, if you keep the BatchNorm modules in evaluation mode, and you won't pass their parameters to the optimizer (best to set their requires_grad to False), they will be completely frozen.
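That recipe translates to the C++ front end we just walked through roughly as follows (a hypothetical sketch: it assumes the ParameterCursor items expose the tensor as .value, analogous to the children_ items in Module::train() above):

// Completely freeze a BatchNorm module for inference.
BatchNorm bn(BatchNormOptions(64));
bn->eval();                          // is_training_ = false: use, but don't update, running stats
for (auto& p : bn->parameters()) {   // ParameterCursor from module.cpp
  p.value.set_requires_grad(false);  // drop weight and bias out of autograd
}
// ...and do not hand bn's parameters to the optimizer.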