GradScaler implementation for Libtorch C++ API
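A minimal usage sketch ('model', 'optimizer', 'loss_fn' and the batch loop are placeholders; a CUDA-capable build is assumed), mirroring the Python torch.cuda.amp recipe of scale -> backward -> step -> update:

    torch::amp::GradScaler scaler;
    for (auto& batch : batches)
    {
        optimizer.zero_grad();
        auto loss = loss_fn(model->forward(batch.input), batch.target);
        scaler.scale(loss).backward(); // backprop through the scaled loss
        scaler.step(optimizer);        // unscale grads, skip the step on inf/NaN
        scaler.update();               // grow or back off the scale factor
    }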
#pragma once
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iterator>
#include <map>
#include <string>
#include <type_traits>
#include <vector>
#include <torch/torch.h>
#include <torch/serialize/archive.h>
namespace torch {
namespace amp {
template<typename T, typename = void>
struct is_container : std::false_type {};
template<typename T>
struct is_container<T, std::void_t<decltype(std::declval<T>().begin()),
decltype(std::declval<T>().end()), typename T::value_type>> : std::true_type {};
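// Mirrors the constructor arguments of Python's torch.cuda.amp.GradScaler:
// the initial scale, the multiplicative growth and backoff factors, and the
// number of consecutive unskipped steps between growth attempts.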
struct GradScalerOptions
{
GradScalerOptions() = default;
TORCH_ARG(double, init_scale) = pow(2.0, 16);
TORCH_ARG(double, growth_factor) = 2.0;
TORCH_ARG(double, backoff_factor) = 0.5;
TORCH_ARG(int64_t, growth_interval) = 2000;
TORCH_ARG(bool, enabled) = true;
};
class GradScaler
{
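// Per-optimizer bookkeeping, keyed by the optimizer's address: "stage"
// cycles READY -> UNSCALED -> STEPPED within one iteration, and
// "found_inf_per_device" collects the per-device inf/NaN flags produced
// by unscale_(). update() clears the map for the next iteration.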
enum class OptState : int { READY, UNSCALED, STEPPED };
using PerDeviceTensors = std::map<c10::DeviceType, torch::Tensor>;
using States = c10::variant<OptState, PerDeviceTensors>;
using OptimizerStates = std::map<std::string, States>;
using PerOptimizerStates = std::map<std::uintptr_t, OptimizerStates>;
public:
GradScaler(const GradScaler& grad_scaler) = delete;
GradScaler(GradScaler&& grad_scaler) = default;
explicit GradScaler(GradScalerOptions const& options = {})
: _init_scale(options.init_scale())
, _growth_factor(options.growth_factor())
, _backoff_factor(options.backoff_factor())
, _growth_interval(options.growth_interval())
, _enabled(options.enabled())
{
if (_enabled && !(torch::cuda::is_available() || torch::hasXLA()))
{
TORCH_WARN("GradScaler is enabled, but CUDA is not available. Disabling.");
_enabled = false;
}
if (_enabled)
{
TORCH_CHECK(_growth_factor > 1.0, "The growth factor must be > 1.0.");
TORCH_CHECK(_backoff_factor < 1.0, "The backoff factor must be < 1.0.");
}
}
auto _check_scale_growth_tracker(std::string const& funcname) -> void
{
static auto fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration.";
TORCH_CHECK(_scale.defined(), "Attempted " + funcname + " but _scale is None. " + fix);
TORCH_CHECK(_growth_tracker.defined(), "Attempted " + funcname + " but _growth_tracker is None. " + fix);
}
auto _lazy_init_scale_growth_tracker(torch::Device const& dev) -> void
{
TORCH_CHECK(!_growth_tracker.defined(), "_growth_tracker initialized before _scale");
_scale = torch::full({ 1 }, _init_scale, c10::TensorOptions().dtype(torch::kFloat32).device(dev));
_growth_tracker = torch::full({ 1 }, _init_growth_tracker, c10::TensorOptions().dtype(torch::kInt32).device(dev));
}
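// Multiplies 'values' (a single tensor or an arbitrarily nested container
// of tensors) by the current scale factor, lazily creating the scale and
// growth-tracker tensors on the device of the first tensor encountered.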
template <typename T>
auto scale(T const& values)
{
if constexpr (std::is_same<T, torch::Tensor>())
{
if (!_enabled)
return values;
assert(values.is_cuda() || values.device().type() == torch::kXLA);
if (!_scale.defined())
_lazy_init_scale_growth_tracker(values.device());
assert(_scale.defined());
return values * _scale.to(values.device(), true);
}
if constexpr (is_container<T>::value)
{
if (!_enabled)
return values;
std::vector<_MultiDeviceReplicator> stash;
const auto apply_scale = [&](auto&& value, auto&& apply_scale)
{
using Result = std::decay_t<decltype(value)>;
if constexpr (std::is_same<Result, torch::Tensor>())
{
assert(value.is_cuda() || value.device().type() == torch::kXLA);
if (stash.empty())
{
if (!_scale.defined())
_lazy_init_scale_growth_tracker(value.device());
assert(_scale.defined());
stash.push_back(_MultiDeviceReplicator(_scale));
}
return value * stash.front().get(value.device().type());
}
if constexpr (is_container<Result>::value)
{
Result result;
result.reserve(value.size());
std::transform(std::begin(value),
std::end(value),
std::back_inserter(result),
[=](auto&& item)
{
return apply_scale(item, apply_scale);
});
return result;
}
};
return apply_scale(values, apply_scale);
}
}
auto _unscale_grads_(torch::optim::Optimizer& optimizer, torch::Tensor& inv_scale, torch::Tensor& found_inf, bool allow_fp16) -> PerDeviceTensors
{
auto per_device_inv_scale = _MultiDeviceReplicator(inv_scale);
auto per_device_found_inf = _MultiDeviceReplicator(found_inf);
std::map<c10::DeviceType, std::map<c10::ScalarType, std::vector<torch::Tensor>>> per_device_and_dtype_grads;
torch::NoGradGuard nograd;
for (auto& group : optimizer.param_groups())
{
for (auto& param : group.params())
{
assert(instanceof<torch::Tensor>(&param));
if (!param.grad().defined())
continue;
if ((!allow_fp16) && (param.grad().dtype() == torch::kFloat16))
throw std::invalid_argument("Attempting to unscale FP16 gradients.");
torch::Tensor to_unscale;
if (param.grad().is_sparse())
{
if (param.grad().dtype() == torch::kFloat16)
param.mutable_grad() = param.grad().coalesce();
to_unscale = param.grad()._values();
}
else
to_unscale = param.grad();
per_device_and_dtype_grads[to_unscale.device().type()][to_unscale.dtype().toScalarType()].push_back(to_unscale);
}
}
// Run the non-finite check and unscale once, after gradients from all
// param groups have been collected (as in the Python implementation);
// doing this inside the group loop would unscale earlier groups repeatedly.
for (auto&& [device, per_dtype_grads] : per_device_and_dtype_grads)
{
for (auto&& [_, grads] : per_dtype_grads)
torch::_amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
per_device_inv_scale.get(device));
}
return per_device_found_inf._per_device_tensors;
}
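// Divides the optimizer's gradients by the current scale factor in place
// and records, per device, whether any gradient contains inf/NaN. May only
// be called once per optimizer since the last update(), and not after step().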
auto unscale_(torch::optim::Optimizer& optimizer) -> void
{
if (!_enabled)
return;
_check_scale_growth_tracker("unscale_");
auto& optimizer_state = get_per_optimizer_states(id(optimizer));
if (optimizer_state["stage"] == OptState::UNSCALED)
throw std::runtime_error("unscale_() has already been called on this optimizer since the last update().");
else
if (optimizer_state["stage"] == OptState::STEPPED)
throw std::runtime_error("unscale_() is being called after step().");
assert(_scale.defined());
auto inv_scale = _scale.to(torch::kDouble).reciprocal().to(at::kFloat);
auto found_inf = torch::full({ 1 }, 0.0, at::TensorOptions().dtype(at::kFloat).device(_scale.device()));
optimizer_state["found_inf_per_device"] = _unscale_grads_(optimizer, inv_scale, found_inf, false);
optimizer_state["stage"] = OptState::UNSCALED;
}
auto _maybe_opt_step(torch::optim::Optimizer& optimizer, OptimizerStates& optimizer_state, torch::optim::Optimizer::LossClosure args) -> c10::optional<c10::Scalar>
{
if (optimizer_state.contains("found_inf_per_device"))
{
auto& found_inf_per_device = c10::get<1>(optimizer_state["found_inf_per_device"]);
if (!sum(found_inf_per_device))
{
auto tensor = optimizer.step(args);
if (tensor.defined())
return tensor.item();
}
}
return c10::nullopt;
}
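// Unscales the gradients (if unscale_ was not already called) and invokes
// optimizer.step() only when no inf/NaN gradient was found; otherwise the
// step is skipped so corrupted gradients never reach the parameters.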
auto step(torch::optim::Optimizer& optimizer, torch::optim::Optimizer::LossClosure optimizer_args = nullptr) -> c10::optional<c10::Scalar>
{
if (!_enabled)
{
auto res = optimizer.step(optimizer_args);
if (res.defined())
return res.item();
else
return c10::nullopt;
}
if (optimizer_args != nullptr)
throw std::runtime_error("Closure use is not currently supported if GradScaler is enabled.");
_check_scale_growth_tracker("step");
auto& optimizer_state = get_per_optimizer_states(id(optimizer));
if (optimizer_state["stage"] == OptState::STEPPED)
throw std::runtime_error("step() has already been called since the last update().");
c10::optional<c10::Scalar> retval;
// if getattr(optimizer, "_step_supports_amp_scaling", False):
{
// ------------------------------------------------------------------------
// TODO: Future Feature
// ------------------------------------------------------------------------
// The boolean '_step_supports_amp_scaling' forces the use of dynamic
// inspection of the class signature (thanks to Python reflection) to call
// step() with extra parameters (two tensors: 'grad_scale' and 'found_inf').
// ------------------------------------------------------------------------
// Because the step() method of the torch::optim::Optimizer class is pure
// virtual and fixed to a single parameter (a LossClosure functor), it is
// currently impossible to mimic the Python behavior with extra parameters.
// ------------------------------------------------------------------------
}
if (optimizer_state["stage"] == OptState::READY)
unscale_(optimizer);
auto&& found_inf_per_device = c10::get<1>(optimizer_state["found_inf_per_device"]);
TORCH_CHECK(found_inf_per_device.size() > 0, "No inf checks were recorded for this optimizer.");
retval = _maybe_opt_step(optimizer, optimizer_state, optimizer_args);
optimizer_state["stage"] = OptState::STEPPED;
return retval;
}
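// Consumes the per-device found_inf flags recorded since the last update:
// if any inf/NaN was found, the scale is multiplied by the backoff factor;
// after 'growth_interval' consecutive clean iterations it is multiplied by
// the growth factor. Alternatively, 'new_scale' sets the scale directly.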
void update(c10::optional<c10::variant<double, torch::Tensor>> const& new_scale = c10::nullopt)
{
if (!_enabled)
return;
_check_scale_growth_tracker("update");
if (new_scale.has_value())
{
assert(_scale.defined());
c10::visit([=](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, double>)
{
_scale.fill_(arg);
}
else
if constexpr (std::is_same_v<T, torch::Tensor>)
{
static auto reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False.";
TORCH_CHECK(arg.dtype() == torch::kFloat32, reason);
TORCH_CHECK(arg.numel() == 1, reason);
TORCH_CHECK(arg.requires_grad() == false, reason);
_scale.copy_(arg);
}
},
new_scale.value());
}
else
{
std::vector<torch::Tensor> found_infs;
for (auto&& [_, state] : _per_optimizer_states)
{
auto& iterator = c10::get<1>(state["found_inf_per_device"]);
for (auto& [_, found_inf] : iterator)
found_infs.push_back(found_inf.to(_scale.device(), true));
}
TORCH_CHECK(found_infs.size() > 0, "No inf checks were recorded prior to update.");
auto& found_inf_combined = found_infs.front();
if (found_infs.size() > 1)
for (const auto i : c10::irange(1, found_infs.size()))
found_inf_combined += found_infs[i];
torch::_amp_update_scale_(_scale,
_growth_tracker,
found_inf_combined,
_growth_factor,
_backoff_factor,
_growth_interval);
}
_per_optimizer_states.clear();
}
torch::Tensor _get_scale_async() const
{
return _scale;
}
double get_scale() const
{
if (_enabled)
{
if (_scale.defined())
return _scale.item<double>();
else
return _init_scale;
}
return 1.0;
}
double get_growth_factor() const
{
return _growth_factor;
}
void set_growth_factor(double new_factor)
{
_growth_factor = new_factor;
}
double get_backoff_factor() const
{
return _backoff_factor;
}
void set_backoff_factor(double new_factor)
{
_backoff_factor = new_factor;
}
int64_t get_growth_interval() const
{
return _growth_interval;
}
void set_growth_interval(int64_t new_interval)
{
_growth_interval = new_interval;
}
int64_t get_init_growth_tracker() const
{
return _init_growth_tracker;
}
void set_init_growth_tracker(int64_t new_value)
{
_init_growth_tracker = new_value;
}
int64_t _get_growth_tracker() const
{
if (_enabled)
{
if (_growth_tracker.defined())
return _growth_tracker.item<int64_t>();
else
return _init_growth_tracker;
}
return 0;
}
bool is_enabled() const
{
return _enabled;
}
private:
template <typename Type>
inline Type read(torch::serialize::InputArchive& archive, std::string const& name)
{
c10::IValue ivalue;
bool exists = archive.try_read(name, ivalue);
if (exists)
return ivalue.to<Type>();
else
return Type();
}
public:
void save(torch::serialize::OutputArchive& archive) const
{
if (_enabled)
{
TORCH_CHECK(_per_optimizer_states.empty(),
"A GradScaler instance may only be saved at the beginning " \
"of an iteration, or at the end after scaler.update().");
serialize::OutputArchive state(archive.compilation_unit());
{
state.write("scale", get_scale());
state.write("growth_factor", _growth_factor);
state.write("backoff_factor", _backoff_factor);
state.write("growth_interval", _growth_interval);
state.write("_growth_tracker", _get_growth_tracker());
}
archive.write("gradscaler", state);
}
}
void load(torch::serialize::InputArchive& archive)
{
if (!_enabled)
return;
if (archive.keys().empty())
throw std::runtime_error("The source state dict is empty, possibly because it was saved " \
"from a disabled instance of GradScaler.");
serialize::InputArchive state;
if (archive.try_read("gradscaler", state))
{
_init_scale = read<double>(state, "scale");
if (_scale.defined())
_scale.fill_(_init_scale);
_growth_factor = read<double>(state, "growth_factor");
_backoff_factor = read<double>(state, "backoff_factor");
_growth_interval = read<int64_t>(state, "growth_interval");
_init_growth_tracker = read<int64_t>(state, "_growth_tracker");
if (_growth_tracker.defined())
_growth_tracker.fill_(_init_growth_tracker);
}
}
private:
double _init_scale;
double _backoff_factor;
double _growth_factor;
int64_t _growth_interval;
private:
at::Tensor _scale;
at::Tensor _growth_tracker;
int64_t _init_growth_tracker{ 0 };
protected:
bool _enabled;
private:
template <typename Type>
inline std::uintptr_t id(Type const& type)
{
return reinterpret_cast<std::uintptr_t>(std::addressof(type));
}
template <typename Base, typename Type>
inline bool instanceof(const Type* ptr)
{
return dynamic_cast<const Base*>(ptr) != nullptr;
}
private:
PerOptimizerStates _per_optimizer_states;
inline auto _refresh_per_optimizer_state() -> OptimizerStates
{
return
{
{ "stage", OptState::READY },
{ "found_inf_per_device", {} }
};
}
inline OptimizerStates& get_per_optimizer_states(std::uintptr_t optimizer_id)
{
if (_per_optimizer_states.contains(optimizer_id) == false)
_per_optimizer_states[optimizer_id] = _refresh_per_optimizer_state();
return _per_optimizer_states.at(optimizer_id);
}
friend bool operator==(States const& lhs, OptState const& rhs)
{
return (lhs.index() > 0) ? false : c10::get<0>(lhs) == rhs;
}
private:
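// Lazily serves per-device copies of a master tensor (the scale or the
// found_inf tensor), caching one copy per device type so each transfer
// happens at most once.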
class _MultiDeviceReplicator
{
public:
_MultiDeviceReplicator(torch::Tensor& master_tensor)
: master(master_tensor)
{
assert(master_tensor.is_cuda() || master_tensor.device().type() == torch::DeviceType::XLA);
}
inline torch::Tensor& get(c10::DeviceType device)
{
if (!_per_device_tensors.contains(device))
_per_device_tensors[device] = master.to(device, true, true);
return _per_device_tensors[device];
}
torch::Tensor& master;
PerDeviceTensors _per_device_tensors;
};
template <typename Type = double>
inline auto sum(PerDeviceTensors const& per_device)
{
Type sum = Type(0);
for (auto&& [_, v] : per_device)
sum += v.item<Type>();
return sum;
}
};
inline serialize::OutputArchive& operator<<(serialize::OutputArchive& archive, const GradScaler& scaler)
{
scaler.save(archive); return archive;
}
inline serialize::InputArchive& operator>>(serialize::InputArchive& archive, GradScaler& scaler)
{
scaler.load(archive); return archive;
}
} // namespace amp
} // namespace torch
// ============================================================================
// Test file (second file of the gist)
// ============================================================================
#pragma once
/*
Testing 'GradScaler' using adapted tests provided in the python api
> pytorch/test/test_cuda.py
The dependencies for running the tests are:
> https://github.com/catchorg/Catch2
> https://github.com/ryanhaining/cppitertools [chain, enumerate, range, zip]
*/
// Catch2
#define CATCH_CONFIG_MAIN // note: with the Catch2 v3 amalgamated build, main() comes from compiling catch_amalgamated.cpp
#include <catch_amalgamated.hpp>
// CPPIterTools
#include <chain.hpp>
#include <enumerate.hpp>
#include <range.hpp>
#include <zip.hpp>
// GradScaler (include Torch)
#include <gradscaler.hpp>
// Standard library
#include <cmath>
#include <limits>
#include <sstream>
static constexpr auto inf = std::numeric_limits<float>::infinity();
static constexpr auto NaN = std::numeric_limits<float>::quiet_NaN();
inline bool operator==(torch::Tensor lhs, double rhs)
{
return lhs.item<double>() == rhs;
}
inline bool operator!=(torch::Tensor lhs, double rhs)
{
return !(lhs == rhs);
}
inline auto FloatTensor(torch::detail::TensorDataContainer const& content)
{
return torch::tensor(content, c10::TensorOptions().dtype(torch::kFloat).device(c10::kCUDA));
}
inline auto LongTensor(torch::detail::TensorDataContainer const& content)
{
return torch::tensor(content, c10::TensorOptions().dtype(torch::kLong).device(c10::kCUDA));
}
TEST_CASE("test_cuda")
{
c10::DeviceType device = c10::kCUDA;
c10::ScalarType dtype = torch::kFloat;
using TestDataset = std::vector<std::pair<torch::Tensor, torch::Tensor>>;
auto create_scaling_models_optimizers = [=](c10::DeviceType device = at::kCUDA)
{
auto mod_control = torch::nn::Sequential(torch::nn::Linear(torch::nn::LinearOptions(8, 8)), torch::nn::Linear(torch::nn::LinearOptions(8, 8))); mod_control->to(device);
auto mod_scaling = torch::nn::Sequential(torch::nn::Linear(torch::nn::LinearOptions(8, 8)), torch::nn::Linear(torch::nn::LinearOptions(8, 8))); mod_scaling->to(device);
torch::NoGradGuard no_grad;
for (auto&& [c, s] : iter::zip(mod_control->parameters(),
mod_scaling->parameters()))
s.copy_(c);
torch::optim::SGD opt_control (mod_control->parameters(), torch::optim::SGDOptions(1.0));
torch::optim::SGD opt_scaling (mod_scaling->parameters(), torch::optim::SGDOptions(1.0));
return std::make_tuple(std::move(mod_control), std::move(mod_scaling), std::move(opt_control), std::move(opt_scaling));
};
auto create_scaling_case = [=](c10::DeviceType device = at::kCUDA, c10::ScalarType dtype = torch::kFloat)
{
auto data = TestDataset
{
{ torch::randn({ 8, 8 }, torch::TensorOptions().dtype(dtype).device(device)), torch::randn({ 8, 8 }, torch::TensorOptions().dtype(dtype).device(device)) },
{ torch::randn({ 8, 8 }, torch::TensorOptions().dtype(dtype).device(device)), torch::randn({ 8, 8 }, torch::TensorOptions().dtype(dtype).device(device)) },
{ torch::randn({ 8, 8 }, torch::TensorOptions().dtype(dtype).device(device)), torch::randn({ 8, 8 }, torch::TensorOptions().dtype(dtype).device(device)) },
{ torch::randn({ 8, 8 }, torch::TensorOptions().dtype(dtype).device(device)), torch::randn({ 8, 8 }, torch::TensorOptions().dtype(dtype).device(device)) }
};
auto loss_fn = torch::nn::MSELoss();
loss_fn->to(at::kCUDA);
auto skip_iter = 2;
return std::tuple_cat(create_scaling_models_optimizers(device), std::make_tuple(data, loss_fn, skip_iter));
};
auto run_scaling_case = [=](auto&& run, int unskipped, int skipped, double atol = 1e-7)
{
const auto rtol = 1e-5;
for (auto enabled : { true, false })
{
auto&& [mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, skip_iter] = create_scaling_case();
auto scaler = torch::amp::GradScaler(torch::amp::GradScalerOptions().init_scale(128.).growth_factor(2.0).growth_interval(1).enabled(enabled));
run(data, mod_control, opt_control, scaler, loss_fn, skip_iter, false);
run(data, mod_scaling, opt_scaling, scaler, loss_fn, skip_iter, true);
if (enabled)
{
auto net_growth = unskipped > 0 ? pow(scaler.get_growth_factor(), unskipped) : 1.0;
auto net_backoff = skipped > 0 ? pow(scaler.get_backoff_factor(), skipped) : 1.0;
REQUIRE(scaler.get_scale() == (128. * net_growth * net_backoff));
}
else
REQUIRE(scaler.get_scale() == 1.0);
for (auto&& [c, s] : iter::zip(mod_control->parameters(), mod_scaling->parameters()))
{
auto& c_grad = c.grad(); // special case: grad() can be undefined, and unlike the Python code,
auto& s_grad = s.grad(); // 'torch::allclose' throws an exception on undefined tensors
if (c_grad.defined() && s_grad.defined())
REQUIRE(torch::allclose(c.grad(), s.grad(), rtol, atol));
auto& c_state = opt_control.param_groups();
auto& s_state = opt_scaling.param_groups();
for (auto&& [c_state_k, s_state_k] : iter::zip(c_state, s_state))
for (auto&& [c_state_p, s_state_p] : iter::zip(c_state_k.params(), s_state_k.params()))
REQUIRE(torch::allclose(c_state_p, s_state_p, rtol, atol));
REQUIRE(torch::allclose(c, s, rtol, atol));
}
}
};
SECTION("test_grad_scaling_unscale_sparse")
{
torch::amp::GradScaler scaler; // use default options
auto inv_scale = torch::full({ 1 }, 0.25, c10::TensorOptions().dtype(dtype).device(device));
auto found_inf = torch::empty({ 1 }, c10::TensorOptions().dtype(dtype).device(device));
auto cur = found_inf.device().type();
auto i = torch::tensor({ { 0, 1, 1 }, { 2, 0, 2 } }, c10::TensorOptions().dtype(torch::kInt64).device(c10::kCUDA));
auto v = torch::tensor({ 16., 32., 64. }, c10::TensorOptions().dtype(torch::kFloat).device(c10::kCUDA));
auto s = torch::sparse_coo_tensor(i, v, { 2, 3 }, c10::TensorOptions().dtype(dtype).device(c10::kCUDA));
auto p = s.clone();
REQUIRE(p.is_sparse());
auto optA = torch::optim::SGD({ p }, torch::optim::SGDOptions(1.0));
p.mutable_grad() = s.clone();
found_inf.zero_();
found_inf = scaler._unscale_grads_(optA, inv_scale, found_inf, false)[cur];
REQUIRE(found_inf == 0.0);
REQUIRE(torch::equal(p.grad().to_dense(), (s / 4).to_dense()));
v = FloatTensor({ 16.f, 32.f, inf });
p.mutable_grad() = torch::sparse_coo_tensor(i, v, { 2, 3 }, c10::TensorOptions().dtype(dtype).device(c10::kCUDA));
found_inf.zero_();
found_inf = scaler._unscale_grads_(optA, inv_scale, found_inf, false)[cur];
REQUIRE(found_inf == 1.0);
v = FloatTensor({ 16.f, 32.f, NaN });
p.mutable_grad() = torch::sparse_coo_tensor(i, v, { 2, 3 }, c10::TensorOptions().dtype(dtype).device(c10::kCUDA));
found_inf.zero_();
found_inf = scaler._unscale_grads_(optA, inv_scale, found_inf, false)[cur];
REQUIRE(found_inf == 1.0);
p = s.clone().to(torch::kHalf);
REQUIRE(p.is_sparse());
auto optB = torch::optim::SGD({ p }, torch::optim::SGDOptions(1.0));
p.mutable_grad() = s.clone().to(torch::kHalf);
found_inf.zero_();
found_inf = scaler._unscale_grads_(optB, inv_scale, found_inf, true)[cur];
REQUIRE(found_inf == 0.0);
REQUIRE(torch::equal(p.grad().to_dense(), (s.to(torch::kHalf) / 4).to_dense()));
i = LongTensor({ { 0, 1, 0 }, { 2, 0, 2 } });
v = FloatTensor({ 64000., 32., 64000. });
p.mutable_grad() = torch::sparse_coo_tensor(i, v, { 2, 3 }, c10::TensorOptions().dtype(dtype).device(c10::kCUDA));
found_inf.zero_();
found_inf = scaler._unscale_grads_(optB, inv_scale, found_inf, true)[cur];
REQUIRE(found_inf == 0.0);
}
SECTION("test_grad_scaling_state_dict")
{
for (auto&& lazy_init_scale : { true, false })
{
auto s0 = torch::amp::GradScaler(torch::amp::GradScalerOptions().init_scale(3.).growth_factor(4.).backoff_factor(.5).growth_interval(2));
auto s1 = torch::amp::GradScaler(torch::amp::GradScalerOptions().init_scale(6.).growth_factor(7.).backoff_factor(.8).growth_interval(1));
s1.set_init_growth_tracker(7);
if (lazy_init_scale)
{
s1.scale(torch::full({ 1 }, 4.0, c10::TensorOptions().dtype(torch::kFloat32).device("cuda:0")));
REQUIRE(s1._get_scale_async().dtype() == torch::kFloat);
}
std::stringstream stream;
torch::save(s0, stream);
torch::load(s1, stream);
REQUIRE(s1.get_scale() == 3.);
REQUIRE(s1.get_growth_factor() == 4.);
REQUIRE(s1.get_backoff_factor() == .5);
REQUIRE(s1.get_growth_interval() == 2);
REQUIRE(s1.get_init_growth_tracker() == 0);
}
}
SECTION("test_grad_scale_will_not_overflow")
{
auto model = torch::nn::Linear(torch::nn::LinearOptions(5, 1)); model->to(at::kCUDA);
auto optimizer = torch::optim::Adam(model->parameters());
auto scaler = torch::amp::GradScaler(torch::amp::GradScalerOptions().growth_interval(1).growth_factor(pow(2.0, 4)).init_scale(1e38));
optimizer.zero_grad();
auto x = torch::randn({ 1, 5 }).to(at::kCUDA);
auto y = 1e-30 * torch::randn({ 1, 1 }).to(at::kCUDA);
auto l = torch::pow((model->forward(x) - y), 2).mean();
scaler.scale(l).backward();
scaler.step(optimizer);
scaler.update();
REQUIRE((scaler._get_scale_async() != inf
&& scaler._get_scale_async() != NaN));
}
SECTION("test_grad_scaling_clipping")
{
auto run = [](auto& data, auto& model, auto& optimizer, auto& scaler, auto& loss_fn, auto skip_iter, auto try_scaling_api)
{
auto max_norm = 0.2;
for (auto&& [i, pair] : iter::enumerate(data))
{
auto&& [input, target] = pair;
optimizer.zero_grad();
auto output = model->forward(input);
auto loss = loss_fn->forward(output, target);
if (try_scaling_api)
{
scaler.scale(loss).backward();
torch::nn::utils::clip_grad_norm_(model->parameters(), max_norm * scaler.get_scale());
if (i == skip_iter && scaler.is_enabled())
{
auto& weight = model[1]->as<torch::nn::Linear>()->weight;
weight.grad().data().fill_(inf);
}
scaler.step(optimizer);
scaler.update();
}
else
{
loss.backward();
torch::nn::utils::clip_grad_norm_(model->parameters(), max_norm);
if (!scaler.is_enabled() || (i != skip_iter))
optimizer.step();
}
}
};
run_scaling_case(run, 3, 1, 1e-5);
}
SECTION("test_grad_scaling_clipping_separate_unscale")
{
auto run = [](auto& data, auto& model, auto& optimizer, auto& scaler, auto& loss_fn, auto skip_iter, auto try_scaling_api)
{
auto max_norm = 0.2;
for (auto&& [i, pair] : iter::enumerate(data))
{
auto&& [input, target] = pair;
optimizer.zero_grad();
auto output = model->forward(input);
auto loss = loss_fn->forward(output, target);
if (try_scaling_api)
{
scaler.scale(loss).backward();
if (i == skip_iter && scaler.is_enabled())
{
auto& weight = model[1]->as<torch::nn::Linear>()->weight;
weight.grad().data().fill_(inf);
}
scaler.unscale_(optimizer);
torch::nn::utils::clip_grad_norm_(model->parameters(), max_norm);
scaler.step(optimizer);
scaler.update();
}
else
{
loss.backward();
torch::nn::utils::clip_grad_norm_(model->parameters(), max_norm);
if (!scaler.is_enabled() || (i != skip_iter))
optimizer.step();
}
}
};
run_scaling_case(run, 3, 1);
}
SECTION("test_grad_scaling_penalty")
{
auto run = [](auto& data, auto& model, auto& optimizer, auto& scaler, auto& loss_fn, auto skip_iter, auto try_scaling_api)
{
for (auto&& [i, pair] : iter::enumerate(data))
{
auto&& [input, target] = pair;
optimizer.zero_grad();
auto output = model->forward(input);
auto loss = loss_fn->forward(output, target);
std::vector<torch::Tensor> grad_params;
if (try_scaling_api)
{
grad_params = torch::autograd::grad({ scaler.scale(loss) }, model->parameters(), {}, {}, true);
auto inv_scale = 1. / scaler.get_scale();
for (auto& p : grad_params)
p = p * inv_scale;
}
else
grad_params = torch::autograd::grad({ loss }, model->parameters(), {}, {}, true);
auto grad_norm = torch::zeros({ 1 }).to(input.device());
for (auto& grad : grad_params)
grad_norm += grad.pow(2).sum();
grad_norm = grad_norm.sqrt();
loss = loss + grad_norm;
if (try_scaling_api)
{
scaler.scale(loss).backward();
if (i == skip_iter && scaler.is_enabled())
{
auto& weight = model[1]->as<torch::nn::Linear>()->weight;
weight.grad().data().fill_(inf);
}
scaler.step(optimizer);
scaler.update();
}
else
{
loss.backward();
if (!scaler.is_enabled() || (i != skip_iter))
optimizer.step();
}
}
};
run_scaling_case(run, 3, 1);
}
SECTION("test_grad_scaling_accumulation")
{
auto run = [](auto& data, auto& model, auto& optimizer, auto& scaler, auto& loss_fn, auto skip_iter, auto try_scaling_api)
{
auto iters_to_accumulate = 2;
for (auto&& [i, pair] : iter::enumerate(data))
{
auto&& [input, target] = pair;
auto output = model->forward(input);
auto loss = loss_fn->forward(output, target);
loss = loss / iters_to_accumulate;
if (try_scaling_api)
scaler.scale(loss).backward();
else
loss.backward();
if ((i + 1) % iters_to_accumulate == 0)
{
if (try_scaling_api)
{
scaler.step(optimizer);
scaler.update();
optimizer.zero_grad();
}
else
{
optimizer.step();
optimizer.zero_grad();
}
}
}
};
run_scaling_case(run, 2, 0);
}
SECTION("test_grad_scaling_multiple")
{
for (auto&& enabled : { true, false })
{
auto&& [mod_control0, mod_scaling0, opt_control0, opt_scaling0, data, loss_fn, skip_iter] = create_scaling_case();
auto&& [mod_control1, mod_scaling1, opt_control1, opt_scaling1] = create_scaling_models_optimizers();
auto scaler = torch::amp::GradScaler(torch::amp::GradScalerOptions().init_scale(128.).growth_factor(2.0).growth_interval(1).enabled(enabled));
auto run = [&](auto& model0, auto& model1, auto& optimizer0, auto& optimizer1, bool try_scaling_api)
{
for (auto&& [i, pair] : iter::enumerate(data))
{
auto&& [input, target] = pair;
optimizer0.zero_grad();
optimizer1.zero_grad();
auto output0 = model0->forward(input);
auto output1 = model1->forward(input);
auto loss0 = loss_fn(0.3 * output0 + 0.7 * output1, target);
auto loss1 = loss_fn(0.6 * output0 - 0.4 * output1, target);
if (try_scaling_api)
{
scaler.scale(loss0).backward({}, true);
scaler.scale(loss1).backward();
if (i == skip_iter && scaler.is_enabled())
{
auto& weight = model1[1]->as<torch::nn::Linear>()->weight;
weight.grad().data().fill_(inf);
}
scaler.unscale_(optimizer0);
scaler.step(optimizer0);
scaler.step(optimizer1);
scaler.update();
}
else
{
loss0.backward({}, true);
loss1.backward();
optimizer0.step();
if (!scaler.is_enabled() || (i != skip_iter))
optimizer1.step();
}
}
};
run(mod_control0, mod_control1, opt_control0, opt_control1, false);
run(mod_scaling0, mod_scaling1, opt_scaling0, opt_scaling1, true);
REQUIRE(scaler.get_scale() == (enabled ? (128. * pow(scaler.get_growth_factor(), 3)
* pow(scaler.get_backoff_factor(), 1)) : 1.0));
for (auto&& [c, s] : iter::zip(iter::chain(mod_control0->parameters(), mod_control1->parameters()),
iter::chain(mod_scaling0->parameters(), mod_scaling1->parameters())))
REQUIRE(torch::allclose(c, s, 1e-5, 1e-7));
}
}
}
CATCH_TRANSLATE_EXCEPTION(std::exception const& e)
{
return c10::GetExceptionString(e);
}