@raytroop
Last active September 22, 2018 02:20
Explore PyTorch BatchNorm: the relationship among `track_running_stats`, `eval` and `train` mode
"""
explore the relationship among `track_running_stats`, `eval` and `train` mode
"""
import torch
from torch import nn
import numpy as np
torch.manual_seed(42)
torch.cuda.seed_all()
x = torch.randn(20, 1, 32, 32) * 2 + 3 # mu=3, std=2
x_split = torch.split(x, 2)
log_idx = 0
def testcase(bn, data, istraining=True, tracking=True):
    global log_idx
    print('///////{:<2}//////'.format(log_idx))
    log_idx += 1
    bn.train(istraining)
    bn.track_running_stats = tracking
    out = bn(data[np.random.randint(0, 10)])
    print('weight:', bn.weight)
    print('bias: ', bn.bias)
    print('running_mean: ', bn.running_mean)
    print('running_var: ', bn.running_var)
    print('num_batches_tracked: ', bn.num_batches_tracked)
    return out
nb_case = -1
if nb_case == 0:
    print('nb_case == 0')
    bn1 = nn.BatchNorm2d(1, affine=False, track_running_stats=False)
    testcase(bn1, x_split, istraining=True, tracking=True)
# torch/nn/modules/batchnorm.py", line 57, in forward
# self.num_batches_tracked += 1
# TypeError: unsupported operand type(s) for +=: 'NoneType' and 'int'
elif nb_case == 1:
    print('nb_case == 1')
    bn1 = nn.BatchNorm2d(1, affine=False, track_running_stats=False)
    out = testcase(bn1, x_split, istraining=True, tracking=False)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
    out = testcase(bn1, x_split, istraining=True, tracking=False)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
# PASS
# nb_case == 1
# ///////0 //////
# weight: None
# bias: None
# running_mean: None
# running_var: None
# num_batches_tracked: None
# 6.0535967e-09 0.9999952
# ///////1 //////
# weight: None
# bias: None
# running_mean: None
# running_var: None
# num_batches_tracked: None
# -2.7939677e-09 0.9999952
elif nb_case == 2:
    print('nb_case == 2')
    bn1 = nn.BatchNorm2d(1, affine=False, track_running_stats=False)
    testcase(bn1, x_split, istraining=False, tracking=True)
# torch/nn/functional.py", line 1254, in batch_norm
# training, momentum, eps, torch.backends.cudnn.enabled
# RuntimeError: running_mean must be defined in evaluation mode
elif nb_case == 3:
    print('nb_case == 3')
    bn1 = nn.BatchNorm2d(1, affine=False, track_running_stats=False)
    out = testcase(bn1, x_split, istraining=False, tracking=False)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
    out = testcase(bn1, x_split, istraining=False, tracking=False)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
# PASS
# nb_case == 3
# ///////0 //////
# weight: None
# bias: None
# running_mean: None
# running_var: None
# num_batches_tracked: None
# -1.1175871e-08 0.999995
# ///////1 //////
# weight: None
# bias: None
# running_mean: None
# running_var: None
# num_batches_tracked: None
# -9.313226e-09 0.9999949
# If you want to track running stats at all, whether in **train mode** or **eval mode**,
# you must instantiate the `BatchNorm` layer with `track_running_stats=True` (the default).
# Otherwise the running-stat buffers are never created, so they can never be tracked.
nb2_case = 5
if nb2_case == 0:
    print('nb2_case == 0')
    bn1 = nn.BatchNorm2d(1, affine=False, track_running_stats=True)
    out = testcase(bn1, x_split, istraining=True, tracking=True)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
    out = testcase(bn1, x_split, istraining=True, tracking=True)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
# PASS
# ///////0 //////
# weight: None
# bias: None
# running_mean: tensor([0.0043])
# running_var: tensor([0.9971])
# num_batches_tracked: tensor(1)
# -9.313226e-09 0.9999949
# ///////1 //////
# weight: None
# bias: None
# running_mean: tensor([0.0025])
# running_var: tensor([0.9984])
# num_batches_tracked: tensor(2)
# 3.7252903e-09 0.99999505
elif nb2_case == 1:
    print('nb2_case == 1')
    bn1 = nn.BatchNorm2d(1, affine=False, track_running_stats=True)
    out = testcase(bn1, x_split, istraining=True, tracking=False)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
    out = testcase(bn1, x_split, istraining=True, tracking=False)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
# PASS
# ///////0 //////
# weight: None
# bias: None
# running_mean: tensor([0.])
# running_var: tensor([1.])
# num_batches_tracked: tensor(0)
# 6.0535967e-09 0.9999952
# ///////1 //////
# weight: None
# bias: None
# running_mean: tensor([0.])
# running_var: tensor([1.])
# num_batches_tracked: tensor(0)
# -9.313226e-09 0.9999949
elif nb2_case == 2:
    print('nb2_case == 2')
    bn1 = nn.BatchNorm2d(1, affine=False, track_running_stats=True)
    out = testcase(bn1, x_split, istraining=False, tracking=True)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
    out = testcase(bn1, x_split, istraining=False, tracking=True)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
# PASS
# nb2_case == 2
# ///////0 //////
# weight: None
# bias: None
# running_mean: tensor([0.])
# running_var: tensor([1.])
# num_batches_tracked: tensor(0)
# 2.974208 2.0096714
# ///////1 //////
# weight: None
# bias: None
# running_mean: tensor([0.])
# running_var: tensor([1.])
# num_batches_tracked: tensor(0)
# 3.008463 2.0370362
# The default `running_mean` and `running_var` (i.e. 0 and 1) are employed, so the output equals the input.
# BatchNorm will NOT update the running averages in eval mode even though `track_running_stats` is True,
# because `exponential_average_factor` stays at 0.0 instead of `momentum`.
# In other words, BatchNorm updates the running stats only if both `training` and `track_running_stats` are True.
elif nb2_case == 3:
    print('nb2_case == 3')
    bn1 = nn.BatchNorm2d(1, affine=False, track_running_stats=True)
    out = testcase(bn1, x_split, istraining=False, tracking=False)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
    out = testcase(bn1, x_split, istraining=False, tracking=False)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
# PASS
# nb2_case == 3
# ///////0 //////
# weight: None
# bias: None
# running_mean: tensor([0.])
# running_var: tensor([1.])
# num_batches_tracked: tensor(0)
# 3.4458935e-08 0.9999988
# ///////1 //////
# weight: None
# bias: None
# running_mean: tensor([0.])
# running_var: tensor([1.])
# num_batches_tracked: tensor(0)
# -3.306195e-08 0.9999987
elif nb2_case == 4:
    print('nb2_case == 4')
    bn1 = nn.BatchNorm2d(1, affine=False, track_running_stats=True)
    out = testcase(bn1, x_split, istraining=True, tracking=True)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
    out = testcase(bn1, x_split, istraining=True, tracking=True)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
    out = testcase(bn1, x_split, istraining=False, tracking=False)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
    out = testcase(bn1, x_split, istraining=False, tracking=False)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
# nb2_case == 4
# ///////0 //////
# weight: None
# bias: None
# running_mean: tensor([0.3042])
# running_var: tensor([1.3195])
# num_batches_tracked: tensor(1)
# 3.4458935e-08 0.9999988
# ///////1 //////
# weight: None
# bias: None
# running_mean: tensor([0.5779])
# running_var: tensor([1.6071])
# num_batches_tracked: tensor(2)
# 3.4458935e-08 0.9999988
# ///////2 //////
# weight: None
# bias: None
# running_mean: tensor([0.5779])
# running_var: tensor([1.6071])
# num_batches_tracked: tensor(2)
# -3.306195e-08 0.9999987
# ///////3 //////
# weight: None
# bias: None
# running_mean: tensor([0.5779])
# running_var: tensor([1.6071])
# num_batches_tracked: tensor(2)
# -6.519258e-09 0.9999988
# PROBLEM: with `istraining=False, tracking=False`, the running stats are NOT employed; batch stats are used instead.
elif nb2_case == 5:
    print('nb2_case == 5')
    bn1 = nn.BatchNorm2d(1, momentum=0.2, affine=False, track_running_stats=True)
    out = testcase(bn1, x_split, istraining=True, tracking=True)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
    out = testcase(bn1, x_split, istraining=True, tracking=True)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
    out = testcase(bn1, x_split, istraining=False, tracking=True)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
    out = testcase(bn1, x_split, istraining=False, tracking=True)
    print(np.mean(out.data.numpy()), np.std(out.data.numpy()))
# PASS
# nb2_case == 5
# ///////0 //////
# weight: None
# bias: None
# running_mean: tensor([0.6083])
# running_var: tensor([1.6390])
# num_batches_tracked: tensor(1)
# 3.4458935e-08 0.9999988
# ///////1 //////
# weight: None
# bias: None
# running_mean: tensor([1.0950])
# running_var: tensor([2.1502])
# num_batches_tracked: tensor(2)
# 3.4458935e-08 0.9999988
# ///////2 //////
# weight: None
# bias: None
# running_mean: tensor([1.0950])
# running_var: tensor([2.1502])
# num_batches_tracked: tensor(2)
# 1.2913984 1.3449724
# ///////3 //////
# weight: None
# bias: None
# running_mean: tensor([1.0950])
# running_var: tensor([2.1502])
# num_batches_tracked: tensor(2)
# 1.3512965 1.3416901
###########################################################################################
# `track_running_stats` determines whether the running stats exist at all.
# `training` determines how the output is normalized and whether the running stats are updated.
# When the running stats are not available, batch stats are used instead.
###########################################################################################
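Distilling the experiments above and the _BatchNorm.forward logic quoted in the next section, the decision can be summarized in a small reference sketch of my own (not library code):

# Which statistics does a BatchNorm layer use, and does it update the running ones?
# (Summary sketch; mirrors `self.training or not self.track_running_stats` in forward.)
def batchnorm_mode(training, track_running_stats):
    use_batch_stats = training or not track_running_stats      # stats used for normalization
    update_running_stats = training and track_running_stats    # running buffers updated?
    return use_batch_stats, update_running_stats

for training in (True, False):
    for tracking in (True, False):
        print(training, tracking, batchnorm_mode(training, tracking))
# True  True  -> (True, True)   : normalize with batch stats, update running stats
# True  False -> (True, False)  : normalize with batch stats, running stats not updated
# False True  -> (False, False) : normalize with running stats, no update
# False False -> (True, False)  : normalize with batch stats (running stats undefined)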

Python API

  • pytorch/torch/nn/modules/batchnorm.py
# TODO: check contiguous in THNN
# TODO: use separate backend functions?
class _BatchNorm(Module):
    _version = 2

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def _check_input_dim(self, input):
        raise NotImplementedError

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

NOTE: _BatchNorm.forward calls nn.functional.batch_norm
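A side note on the `momentum is None` branch in forward above: constructing the layer with `momentum=None` switches the running stats to a cumulative moving average, so after k batches `running_mean` is the plain average of the k batch means. A quick check of my own (hypothetical example, not part of the gist):

import torch
from torch import nn

bn = nn.BatchNorm2d(1, affine=False, momentum=None)   # cumulative moving average
bn.train()

batch_means = []
for _ in range(3):
    xb = torch.randn(4, 1, 8, 8) + 5.0
    bn(xb)
    batch_means.append(xb.mean())

print(bn.num_batches_tracked)                                            # tensor(3)
print(torch.allclose(bn.running_mean, torch.stack(batch_means).mean()))  # True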

def batch_norm(input, running_mean, running_var, weight=None, bias=None,
               training=False, momentum=0.1, eps=1e-5):
    r"""Applies Batch Normalization for each channel across a batch of data.

    See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
    :class:`~torch.nn.BatchNorm3d` for details.
    """
    if training:
        size = list(input.size())
        if reduce(mul, size[2:], size[0]) == 1:
            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
    return torch.batch_norm(
        input, weight, bias, running_mean, running_var,
        training, momentum, eps, torch.backends.cudnn.enabled
    )

NOTE: nn.functional.batch_norm in turn calls torch.batch_norm
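To see the hand-off concretely, torch.nn.functional.batch_norm can also be called directly with explicit statistics; with training=False it simply normalizes by the supplied mean/var and does not touch them (a small sketch of my own):

import torch
import torch.nn.functional as F

x = torch.randn(4, 2, 8, 8) * 2 + 3
running_mean = torch.tensor([3.0, 3.0])
running_var = torch.tensor([4.0, 4.0])

# training=False: normalize with the supplied running stats, do not update them
out = F.batch_norm(x, running_mean, running_var, weight=None, bias=None,
                   training=False, momentum=0.1, eps=1e-5)
print(torch.allclose(out, (x - 3.0) / (4.0 + 1e-5) ** 0.5, atol=1e-5))  # True
print(running_mean, running_var)  # unchanged: tensor([3., 3.]) tensor([4., 4.])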

0.0 class Module

  • pytorch/torch/csrc/api/include/torch/nn/module.h
 protected:
  /// Registers a parameter with this `Module`.
  Tensor& register_parameter(
      std::string name,
      Tensor tensor,
      bool requires_grad = true);
  /// Registers a buffer with this `Module`.
  Tensor& register_buffer(std::string name, Tensor tensor);

NOTE: the requires_grad argument of register_parameter defaults to true

 private:
  template <typename T>
  using OrderedDict = torch::detail::OrderedDict<std::string, T>;

  OrderedDict<Tensor> parameters_;
  OrderedDict<Tensor> buffers_;
  OrderedDict<std::shared_ptr<Module>> children_;

  /// The module's name (e.g. "LSTM").
  mutable at::optional<std::string> name_;

  /// Whether the module is in training mode.
  bool is_training_{true};

NOTE: is_training_ defaults to true. Cf. pytorch/torch/csrc/api/include/torch/detail/ordered_dict.h:

/// A simple ordered dictionary implementation, akin to Python's `OrderedDict`.
template <typename Key, typename Value>
class OrderedDict {
 public:

As a result, Module::parameters_ and Module::buffers_ are actually torch::detail::OrderedDict<std::string, Tensor>, and Module::children_ is torch::detail::OrderedDict<std::string, std::shared_ptr<Module>>.

  • pytorch/torch/csrc/api/src/nn/module.cpp
ParameterCursor Module::parameters() {
  return ParameterCursor(*this);
}

ConstParameterCursor Module::parameters() const {
  return ConstParameterCursor(*this);
}

BufferCursor Module::buffers() {
  return BufferCursor(*this);
}

ConstBufferCursor Module::buffers() const {
  return ConstBufferCursor(*this);
}

void Module::train() {
  for (auto& child : children_) {
    child.value->train();
  }
  is_training_ = true;
}

void Module::eval() {
  for (auto& child : children_) {
    child.value->eval();
  }
  is_training_ = false;
}

bool Module::is_training() const noexcept {
  return is_training_;
}

void Module::zero_grad() {
  for (auto& child : children_) {
    child.value->zero_grad();
  }
  for (auto& parameter : parameters_) {
    auto& grad = parameter->grad();
    if (grad.defined()) {
      grad = grad.detach();
      grad.zero_();
    }
  }
}

NOTE: Module::parameters() and Module::buffers() return Cursors, i.e. iterables. Module::train() and Module::eval() set is_training_ to true or false on the module itself and on all of its children.
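The Python Module API behaves the same way: train()/eval() flip the training flag on the module and, recursively, on all children (my own quick check):

from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
print(model.training, model[1].training)   # True True (training is the default)

model.eval()
print(model.training, model[1].training)   # False False

model.train()
print(model.training, model[1].training)   # True True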

Tensor& Module::register_parameter(
    std::string name,
    Tensor tensor,
    bool requires_grad) {
  tensor.set_requires_grad(requires_grad);
  return parameters_.insert(std::move(name), std::move(tensor));
}

Tensor& Module::register_buffer(std::string name, Tensor tensor) {
  return buffers_.insert(std::move(name), std::move(tensor));
}

NOTE: as declared in the header above, the requires_grad argument of Module::register_parameter defaults to true.

  • Module::register_parameter inserts a Tensor with requires_grad_=true into Module::parameters_
  • Module::register_buffer inserts a Tensor with requires_grad_=false into Module::buffers_
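The Python side mirrors this split: registered parameters require grad and appear in parameters(), while buffers such as running_mean do not require grad and appear only alongside them in state_dict() (small sketch of my own):

from torch import nn

bn = nn.BatchNorm2d(4)   # affine=True, track_running_stats=True by default
print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']
print(list(bn.state_dict().keys()))                 # ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
print(bn.weight.requires_grad, bn.running_mean.requires_grad)  # True False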

0.1 pytorch/torch/csrc/api/include/torch/nn/pimpl.h

#include <torch/csrc/utils/variadic.h>
#include <torch/tensor.h>

#include <memory>
#include <type_traits>
#include <utility>
#define TORCH_ARG(T, name)                          \
  auto name(const T& new_##name)->decltype(*this) { \
    this->name##_ = new_##name;                     \
    return *this;                                   \
  }                                                 \
  auto name(T&& new_##name)->decltype(*this) {      \
    this->name##_ = std::move(new_##name);          \
    return *this;                                   \
  }                                                 \
  const T& name() const noexcept {                  \
    return this->name##_;                           \
  }                                                 \
  T name##_

/// Defines a class `Name` which inherits from `nn::ModuleHolder` to provide a
/// wrapper over a `std::shared_ptr<Impl>`.
#define TORCH_MODULE_IMPL(Name, Impl)                            \
  class Name : public torch::nn::ModuleHolder<Impl> {            \
   public:                                                       \
    using torch::nn::ModuleHolder<Impl>::ModuleHolder;           \
    Name(const Name&) = default;                                 \
    Name(Name&&) = default;                                      \
    Name(Name& other) : Name(static_cast<const Name&>(other)) {} \
    Name& operator=(const Name&) = default;                      \
    Name& operator=(Name&&) = default;                           \
  }

/// Like `TORCH_MODULE_IMPL`, but defaults the `Impl` name to `<Name>Impl`.
#define TORCH_MODULE(Name) TORCH_MODULE_IMPL(Name, Name##Impl)

NOTE: It is TORCH_MODULE that wraps <Name>Impl and provides the user-facing interface

1. pytorch/torch/csrc/api/include/torch/nn/modules/batchnorm.h

#include <torch/nn/cloneable.h>
#include <torch/nn/pimpl.h>
#include <torch/tensor.h>

#include <cstdint>
namespace torch {
namespace nn {
struct BatchNormOptions {
  /* implicit */ BatchNormOptions(int64_t features);
  TORCH_ARG(int64_t, features);
  TORCH_ARG(bool, affine) = true;
  TORCH_ARG(bool, stateful) = false;
  TORCH_ARG(double, eps) = 1e-5;
  TORCH_ARG(double, momentum) = 0.1;
};

class BatchNormImpl : public torch::nn::Cloneable<BatchNormImpl> {
 public:
  template <typename... Ts>
  explicit BatchNormImpl(Ts&&... ts)
      : BatchNormImpl(BatchNormOptions(std::forward<Ts>(ts)...)) {}
  explicit BatchNormImpl(BatchNormOptions options);

  void reset() override;

  Tensor forward(Tensor input);
  Tensor pure_forward(Tensor input, Tensor mean, Tensor variance);

  BatchNormOptions options;
  Tensor weight;
  Tensor bias;
  Tensor running_mean;
  Tensor running_variance;
};

TORCH_MODULE(BatchNorm);

} // namespace nn
} // namespace torch

Note: BatchNormImpl:

  • data members: options, weight, bias, running_mean and running_variance
  • methods: reset, forward and pure_forward
  • TORCH_MODULE(BatchNorm); defines class BatchNorm, which inherits from torch::nn::ModuleHolder<BatchNormImpl>

Let's expand the TORCH_MODULE macro:

class BatchNorm : public torch::nn::ModuleHolder<BatchNormImpl> {
   public:
    using torch::nn::ModuleHolder<BatchNormImpl>::ModuleHolder;
    BatchNorm(const BatchNorm&) = default;
    BatchNorm(BatchNorm&&) = default;
    BatchNorm(BatchNorm& other) : BatchNorm(static_cast<const BatchNorm&>(other)) {}
    BatchNorm& operator=(const BatchNorm&) = default;
    BatchNorm& operator=(BatchNorm&&) = default;
  }

2. pytorch/torch/csrc/api/src/nn/modules/batchnorm.cpp

#include <torch/nn/modules/batchnorm.h>

#include <torch/cuda.h>
#include <torch/tensor.h>

#include <ATen/Error.h>

#include <cstddef>
#include <utility>
#include <vector>
BatchNormOptions::BatchNormOptions(int64_t features) : features_(features) {}

BatchNormImpl::BatchNormImpl(BatchNormOptions options)
    : options(std::move(options)) {
  reset();
}

NOTE: the constructor initializes options and invokes reset() when constructed from BatchNormOptions

void BatchNormImpl::reset() {
  if (options.affine_) {
    weight = register_parameter(
        "weight", torch::empty({options.features_}).uniform_());
    bias = register_parameter("bias", torch::zeros({options.features_}));
  }

  if (options.stateful_) {
    running_mean =
        register_buffer("running_mean", torch::zeros({options.features_}));
    running_variance =
        register_buffer("running_variance", torch::ones({options.features_}));
  }
}

NOTE: weight and bias are assigned only when options.affine_ is true, and running_mean and running_variance only when options.stateful_ is true; otherwise each Tensor is left default-initialized. affine_ appears to be the C++ equivalent of affine in Python, and stateful_ of track_running_stats.

Tensor BatchNormImpl::forward(Tensor input) {
  return pure_forward(input, Tensor(), Tensor());
}
Tensor BatchNormImpl::pure_forward(Tensor input, Tensor mean, Tensor variance) {
  auto& running_mean = options.stateful_ ? this->running_mean : mean;
  auto& running_variance =
      options.stateful_ ? this->running_variance : variance;

  if (is_training()) {
    const auto num_channels = input.dim() > 1 ? input.size(1) : 1;
    AT_CHECK(
        input.numel() / num_channels > 1,
        "BatchNorm expected more than 1 value per channel when training!");
  }

  return torch::batch_norm(
    input,
    weight,
    bias,
    running_mean,
    running_variance,
    is_training(),
    options.momentum_,
    options.eps_,
    torch::cuda::cudnn_is_available());
}

NOTE:

  • forward calls pure_forward. running_mean and running_variance keep their persistent state because they are bound by reference rather than copied. When options.stateful_ is false, default-constructed Tensors (i.e. Tensor()) are used instead.
  • is_training() is the equivalent of training in Python. The batch has to provide more than one value per channel when training (see the sketch below).
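The same guard exists on the Python path (see the ValueError raised in F.batch_norm quoted earlier); a quick illustration of my own:

import torch
from torch import nn

bn = nn.BatchNorm2d(3)
bn.train()
try:
    bn(torch.randn(1, 3, 1, 1))   # only one value per channel -> batch stats are meaningless
except ValueError as err:
    print(err)   # Expected more than 1 value per channel when training, ...

bn.eval()
out = bn(torch.randn(1, 3, 1, 1))   # fine in eval mode: running stats are used instead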

3. pytorch/aten/src/ATen/native/Normalization.cpp

Tensor batch_norm(
    const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */,
    const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
    bool training, double momentum, double eps, bool cudnn_enabled) {

NOTE:

  • input, weight, bias, running_mean and running_var are all passed by const reference.
  • At first glance it seems batch_norm cannot update running_mean and running_var, so why pass momentum?
  • In fact the const reference only protects the Tensor handle, not the underlying data (Tensor just wraps a TensorImpl pointer, see section 5); the THNN kernel in section 6 updates the running averages in place during the forward pass, not in backward.
  auto num_features = input.sizes()[1];
  if (running_mean.defined()) {
    check_dims_match_num_input_features("running_mean", num_features, running_mean.numel());
  } else if (!training) {
    throw std::runtime_error("running_mean must be defined in evaluation mode");
  }
  if (running_var.defined()) {
    check_dims_match_num_input_features("running_var", num_features, running_var.numel());
  } else if (!training) {
    throw std::runtime_error("running_var must be defined in evaluation mode");
  }
  if (weight.defined()) {
    check_dims_match_num_input_features("weight", num_features, weight.numel());
  }
  if (bias.defined()) {
    check_dims_match_num_input_features("bias", num_features, bias.numel());
  }

NOTE: the number of elements in running_mean, running_var, weight and bias must match the number of channels in the input, i.e. the size of the 2nd dimension.

  return at::thnn_batch_norm(
            input.contiguous(), weight, bias,
            running_mean, running_var, training, momentum, eps);
}

NOTE: from here on we only follow the CPU implementation

4. pytorch/aten/doc/Functions.h

static inline Tensor thnn_batch_norm( const Tensor & self,
                                      const Tensor & weight,
                                      const Tensor & bias,
                                      const Tensor & running_mean,
                                      const Tensor & running_var,
                                      bool training,
                                      double momentum,
                                      double eps) {
    return infer_type(self).thnn_batch_norm(self,
                                            weight,
                                            bias,
                                            running_mean,
                                            running_var,
                                            training,
                                            momentum, eps);
}
static inline Type & infer_type(const Tensor & t) {
  AT_ASSERT(t.defined(), "undefined Tensor");
  return t.type();
}

5.0 pytorch/aten/doc/Tensor.h

struct Tensor : public detail::TensorBase {
  Tensor() : TensorBase() {}
  Tensor(TensorImpl * self, bool retain) : TensorBase(self, retain) {}
  Tensor(const TensorBase & rhs) : TensorBase(rhs) {}
  Tensor(const Tensor & rhs) = default;
  Tensor(Tensor && rhs) noexcept = default;

  Type & type() const {
      return pImpl->type();
    }

  std::unique_ptr<Storage> storage() const {
    return pImpl->storage();
  }

  template<typename T>
  T * data() const;

NOTE: the parent class TensorBase is actually TensorBaseImpl<true>

// Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which
// has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr.
//
// For example:
//
// void func(Tensor a) {
//   Tensor b = a;
//   ...
// }
//
// In this example, when we say Tensor b = a, we are creating a new object that points to the
// same underlying TensorImpl, and bumps its reference count. When b goes out of scope, the
// destructor decrements the reference count by calling release() on the TensorImpl it points to.
// The existing constructors, operator overloads, etc. take care to implement the correct semantics.
//
// Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and
// special care must be taken to handle this.

5.1 pytorch/aten/src/ATen/TensorBase.h

template<bool is_strong>
struct TensorBaseImpl {
  TensorBaseImpl(): TensorBaseImpl(UndefinedTensor::singleton(), false) {}
  TensorBaseImpl(TensorImpl * self, bool should_retain)
  : pImpl(self) {
    if (pImpl == nullptr) {
      throw std::runtime_error("TensorBaseImpl with nullptr not supported");
    }
    if(should_retain && pImpl != UndefinedTensor::singleton()) {
      retain();
    }
  }
  TensorBaseImpl(const TensorBaseImpl & rhs)
  : pImpl(rhs.pImpl) {
    if (pImpl != UndefinedTensor::singleton()) {
      retain();
    }
  }
  TensorBaseImpl(TensorBaseImpl && rhs) noexcept
  : pImpl(rhs.pImpl) {
    rhs.pImpl = UndefinedTensor::singleton();
  }
  ~TensorBaseImpl() {
    if (pImpl != UndefinedTensor::singleton()) {
      release();
    }
  }

NOTE: TensorBaseImpl is the base class for Tensor; it handles the reference counting.

  TensorImpl * detach() {
    TensorImpl * ret = pImpl;
    pImpl = UndefinedTensor::singleton();
    return ret;
  }

  bool defined() const {
    return pImpl != UndefinedTensor::singleton();
  }

  friend struct Type;

  //TODO(zach): sort out friend structes
public:
  TensorImpl * pImpl;

NOTE: Tensor inherits pImpl, which points to the underlying TensorImpl

using TensorBase = TensorBaseImpl<true>;
using WeakTensorBase = TensorBaseImpl<false>;

5.2 pytorch/aten/src/ATen/TensorImpl.h

struct TensorImpl : public Retainable {
  explicit TensorImpl(Type * type)
  : is_scalar(false), type_(type) {}

  Type & type() const {
    return *type_;
  }

  protected:
  bool is_scalar;
  Type * type_;
};

NOTE: this is where type_ comes from.

5.3 pytorch/aten/doc/Type.h

struct AT_API Type {
  explicit Type(Context * context)
  : context(context) {}
  virtual ~Type() {}
virtual Tensor thnn_batch_norm(const Tensor & self, const Tensor & weight, const Tensor & bias, const Tensor & running_mean, const Tensor & running_var, bool training, double momentum, double eps) const;

Macros

#define TH_CONCAT_STRING_2(x,y) TH_CONCAT_STRING_2_EXPAND(x,y)
#define TH_CONCAT_STRING_2_EXPAND(x,y) #x #y

#define TH_CONCAT_STRING_3(x,y,z) TH_CONCAT_STRING_3_EXPAND(x,y,z)
#define TH_CONCAT_STRING_3_EXPAND(x,y,z) #x #y #z

#define TH_CONCAT_STRING_4(x,y,z,w) TH_CONCAT_STRING_4_EXPAND(x,y,z,w)
#define TH_CONCAT_STRING_4_EXPAND(x,y,z,w) #x #y #z #w

#define TH_CONCAT_2(x,y) TH_CONCAT_2_EXPAND(x,y)
#define TH_CONCAT_2_EXPAND(x,y) x ## y

#define TH_CONCAT_3(x,y,z) TH_CONCAT_3_EXPAND(x,y,z)
#define TH_CONCAT_3_EXPAND(x,y,z) x ## y ## z

#define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w
#define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w)

#define THMin(X, Y)  ((X) < (Y) ? (X) : (Y))
#define THMax(X, Y)  ((X) > (Y) ? (X) : (Y))

#define THTensor_(NAME)   TH_CONCAT_4(TH,Real,Tensor_,NAME)

#define THNN_(NAME) TH_CONCAT_3(THNN_, Real, NAME)

6. pytorch/aten/src/THNN/generic/BatchNormalization.c

  • forward
void THNN_(BatchNormalization_updateOutput)(
  THNNState *state, THTensor *input, THTensor *output,
  THTensor *weight, THTensor *bias,
  THTensor *running_mean, THTensor *running_var,
  THTensor *save_mean, THTensor *save_std,
  bool train, double momentum, double eps)
{
THTensor_(resizeAs)(output, input);
  int64_t nInput = THTensor_(size)(input, 1);
  int64_t f;
  ptrdiff_t n = THTensor_(nElement)(input) / nInput;

  if (train) {
    THTensor_(resize1d)(save_mean, nInput);
    THTensor_(resize1d)(save_std, nInput);
  }

For each feature map f:

  #pragma omp parallel for
  for (f = 0; f < nInput; ++f) {
    THTensor *in = THTensor_(newSelect)(input, 1, f);
    THTensor *out = THTensor_(newSelect)(output, 1, f);

    real mean, invstd;

    if (train) {
      // compute mean per input
      accreal sum = 0;
      TH_TENSOR_APPLY(real, in, sum += *in_data;);

      mean = (real) sum / n;
      THTensor_(set1d)(save_mean, f, (real) mean);

      // compute variance per input
      sum = 0;
      TH_TENSOR_APPLY(real, in,
        sum += (*in_data - mean) * (*in_data - mean););

      if (sum == 0 && eps == 0.0) {
        invstd = 0;
      } else {
        invstd = (real) (1 / sqrt(sum/n + eps));
      }
      THTensor_(set1d)(save_std, f, (real) invstd);

      // update running averages
      if (running_mean) {
        THTensor_(set1d)(running_mean, f,
          (real) (momentum * mean + (1 - momentum) * THTensor_(get1d)(running_mean, f)));
      }
      if (running_var) {
        accreal unbiased_var = sum / (n - 1);
        THTensor_(set1d)(running_var, f,
          (real) (momentum * unbiased_var + (1 - momentum) * THTensor_(get1d)(running_var, f)));
      }
    } else {
      mean = THTensor_(get1d)(running_mean, f);
      invstd = 1 / sqrt(THTensor_(get1d)(running_var, f) + eps);
    }

    // compute output
    real w = weight ? THTensor_(get1d)(weight, f) : 1;
    real b = bias ? THTensor_(get1d)(bias, f) : 0;

    TH_TENSOR_APPLY2(real, in, real, out,
      *out_data = (real) (((*in_data - mean) * invstd) * w + b););

    THTensor_(free)(out);
    THTensor_(free)(in);
  }
}
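The update above, i.e. momentum * batch_stat + (1 - momentum) * old_running_stat with the *unbiased* batch variance (sum / (n - 1)), can be checked from Python. A small sketch of my own, using a single-channel input so the whole tensor belongs to one feature map:

import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(1, affine=False, momentum=0.1)   # running_mean starts at 0, running_var at 1
x = torch.randn(8, 1, 4, 4) * 2 + 3

bn.train()
bn(x)

batch_mean = x.mean()              # mean over every value of the single channel
batch_var = x.var(unbiased=True)   # matches THNN's sum / (n - 1)
print(torch.allclose(bn.running_mean, 0.1 * batch_mean + 0.9 * 0.0))  # True
print(torch.allclose(bn.running_var, 0.1 * batch_var + 0.9 * 1.0))    # True

  • backward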
void THNN_(BatchNormalization_backward)(
  THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput,
  THTensor *gradWeight, THTensor *gradBias, THTensor *weight,
  THTensor *running_mean, THTensor *running_var,
  THTensor *save_mean, THTensor *save_std,
  bool train, double scale, double eps)
{
  THNN_CHECK_SHAPE(input, gradOutput);
  int64_t nInput = THTensor_(size)(input, 1);
  int64_t f;
  ptrdiff_t n = THTensor_(nElement)(input) / nInput;

  if (gradInput) {
    THTensor_(resizeAs)(gradInput, input);
  }

  #pragma omp parallel for
  for (f = 0; f < nInput; ++f) {
    THTensor *in = THTensor_(newSelect)(input, 1, f);
    THTensor *gradOut = THTensor_(newSelect)(gradOutput, 1, f);
    real w = weight ? THTensor_(get1d)(weight, f) : 1;
    real mean, invstd;
    if (train) {
      mean = THTensor_(get1d)(save_mean, f);
      invstd = THTensor_(get1d)(save_std, f);
    } else {
      mean = THTensor_(get1d)(running_mean, f);
      invstd = 1 / sqrt(THTensor_(get1d)(running_var, f) + eps);
    }

    // sum over all gradOutput in feature plane
    accreal sum = 0;
    TH_TENSOR_APPLY(real, gradOut, sum += *gradOut_data;);

    // dot product of the Q(X) and gradOuput
    accreal dotp = 0;
    TH_TENSOR_APPLY2(real, in, real, gradOut,
      dotp += (*in_data - mean) * (*gradOut_data););

    if (gradInput) {
      THTensor *gradIn = THTensor_(newSelect)(gradInput, 1, f);

      if (train) {
        // when in training mode
        // Q(X) = X - E[x] ; i.e. input centered to zero mean
        // Y = Q(X) / σ    ; i.e. BN output before weight and bias
        // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / σ * w

        // projection of gradOutput on to output scaled by std
        real k = (real) dotp * invstd * invstd / n;
        TH_TENSOR_APPLY2(real, gradIn, real, in,
          *gradIn_data = (*in_data - mean) * k;);

        accreal gradMean = sum / n;
        TH_TENSOR_APPLY2(real, gradIn, real, gradOut,
          *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * invstd * w;);

      } else {
        // when in evaluation mode
        // Q(X) = X - running_mean  ; i.e. input centered to zero mean
        // Y = Q(X) / running_std    ; i.e. BN output before weight and bias
        // dL/dX = w / running_std
        TH_TENSOR_APPLY2(real, gradIn, real, gradOut,
          *gradIn_data = *gradOut_data * invstd * w;);
      }

      THTensor_(free)(gradIn);
    }

    if (gradWeight) {
      real val = THTensor_(get1d)(gradWeight, f);
      THTensor_(set1d)(gradWeight, f, val + scale * dotp * invstd);
    }

    if (gradBias) {
      real val = THTensor_(get1d)(gradBias, f);
      THTensor_(set1d)(gradBias, f, val + scale * sum);
    }

    THTensor_(free)(gradOut);
    THTensor_(free)(in);
  }
}
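The evaluation-mode branch of the backward pass (dL/dX = dL/dY * w / running_std) is easy to confirm with autograd; a small check of my own:

import torch
from torch import nn

bn = nn.BatchNorm2d(1, affine=True, track_running_stats=True)
bn.eval()   # fresh module: running_mean = 0, running_var = 1

x = torch.randn(4, 1, 3, 3, requires_grad=True)
bn(x).sum().backward()   # dL/dY = 1 everywhere

expected = (bn.weight / torch.sqrt(bn.running_var + bn.eps)).detach()  # w / running_std per channel
print(torch.allclose(x.grad, expected.expand_as(x.grad)))              # True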

PyTorch forum

  1. https://discuss.pytorch.org/t/example-on-how-to-use-batch-norm/216/
  • As @moskomule pointed out, you have to specify how many feature channels will your input have (because that’s the number of BatchNorm parameters). Batch and spatial dimensions don’t matter.
  • BatchNorm will only update the running averages in train mode, so if you want the model to keep updating them in test time, you will have to keep BatchNorm modules in the training mode. See the C implementation for details (it should be readable).
  • About ReLU and MaxPool - if you think about it for a moment both ReLU + MaxPool and MaxPool + ReLU are equivalent operations, with the second option being 37.5% more efficient ( numel + numel in first case numel + numel/4 in the second case, where numel is the number of elements in the tensor). That’s why the example has a different order.
  2. https://discuss.pytorch.org/t/example-on-how-to-use-batch-norm/216/
  • Q: At test time, I would like to freeze both the weights (lambda and beta), as well as freeze the running averages that it has computed. (Ostensibly because it has a good estimate for those from training already.) So I basically expect that I would want all 4 of those values frozen.
  • A: Yeah, in that case keep the BatchNorm modules in evaluation mode, and don't pass their parameters to the optimizer (best to set their requires_grad to False); they will be completely frozen.
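Following the advice in the second thread, freezing BatchNorm completely (affine parameters and running statistics) might look like the sketch below; freeze_batchnorm is a hypothetical helper name, not a PyTorch API:

from torch import nn

def freeze_batchnorm(module):
    """Hypothetical helper: put every BatchNorm2d child in eval mode and stop gradients
    on its affine parameters, so neither the running stats nor weight/bias change."""
    for m in module.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()   # keep using (and never updating) the running stats
            if m.affine:
                m.weight.requires_grad = False
                m.bias.requires_grad = False

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
freeze_batchnorm(model)
# Note: a later call to model.train() flips the BatchNorm modules back to train mode,
# so freeze_batchnorm has to be re-applied after it.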