@colesbury
Last active November 2, 2017 18:42
--- ./torch/csrc/autograd/generated/VariableType.cpp 2017-11-02 11:42:00.179695310 -0700
+++ VariableType.cpp 2017-11-02 11:41:41.932641407 -0700
@@ -181,6 +181,13 @@
   return make_variable(std::move(tensor));
 }

+Variable VariableType::maybe_wrap(Tensor data, const Variable & self, bool inplace) const {
+  if (inplace) {
+    return self;
+  }
+  return as_variable(data);
+}
+
 static Variable as_view(Variable base, Tensor tensor) {
   if (base.is_view()) {
     base = base.base();
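Context for readers skimming the patch: this gist modifies PyTorch's generated VariableType.cpp so that wrappers for ops with an `inplace` flag reuse the mutated input Variable instead of always allocating a fresh one. The new `maybe_wrap` helper above is the core of that. A minimal standalone sketch of the idea, with stand-in `Tensor`/`Variable` types rather than the real PyTorch internals:

```cpp
#include <iostream>
#include <memory>

// Stand-in types: real Variables are refcounted handles, so copying one
// aliases the same underlying state.
struct VariableState { int version = 0; };
struct Tensor { float value = 0; };
struct Variable {
  std::shared_ptr<VariableState> state = std::make_shared<VariableState>();
  Tensor data;
};

Variable as_variable(Tensor data) {
  Variable v;
  v.data = data;
  return v;
}

// The pattern added by the patch: reuse `self` when the op mutated it
// in place, otherwise wrap the result tensor in a fresh Variable.
Variable maybe_wrap(Tensor data, const Variable& self, bool inplace) {
  if (inplace) {
    return self;  // same handle, same autograd state
  }
  return as_variable(data);
}

int main() {
  Variable input;
  Tensor result{1.0f};
  Variable a = maybe_wrap(result, input, /*inplace=*/true);
  Variable b = maybe_wrap(result, input, /*inplace=*/false);
  std::cout << (a.state == input.state) << "\n";  // 1: aliases input
  std::cout << (b.state == input.state) << "\n";  // 0: fresh Variable
}
```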
@@ -254,8 +261,12 @@
   if (grad_fn) {
     grad_fn->num_inputs = 1;
   }
+  if (inplace) {
+    var.rebase_history(flags, 0, std::move(grad_fn));
+  } else {
   var.set_history(flags, 0, std::move(grad_fn));
 }
+}

 static void set_flags(std::vector<Variable> &vl, VarFlags flags, std::shared_ptr<Function> grad_fn) {
   if (grad_fn) {
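The `set_flags` change above is the heart of the fix: an out-of-place result is a brand-new Variable and can simply have history attached (`set_history`), while an in-place op mutates a Variable that may already participate in a graph, so its history must be rebased onto the new `grad_fn`. A hedged toy model of the distinction (the real `rebase_history` also handles views, flags, and output numbers, which this toy omits):

```cpp
#include <memory>
#include <stdexcept>

struct Function { int num_inputs = 0; };

struct Variable {
  std::shared_ptr<Function> grad_fn;  // null for graph leaves
  bool requires_grad = false;

  // Out-of-place: attach history to a freshly created output.
  void set_history(std::shared_ptr<Function> fn) {
    grad_fn = std::move(fn);
  }

  // In-place: re-point an existing Variable at the new grad_fn. A leaf
  // that requires grad has no grad_fn to splice into, so mutating it in
  // place would corrupt gradient accumulation; real autograd rejects
  // that case with an error as well.
  void rebase_history(std::shared_ptr<Function> fn) {
    if (requires_grad && !grad_fn) {
      throw std::runtime_error(
          "a leaf Variable that requires grad was used in an in-place op");
    }
    grad_fn = std::move(fn);
  }
};

int main() {
  Variable y;                 // a non-leaf result of some op
  y.grad_fn = std::make_shared<Function>();
  y.rebase_history(std::make_shared<Function>());  // fine: splices history

  Variable leaf;
  leaf.requires_grad = true;  // a leaf that requires grad
  try {
    leaf.rebase_history(std::make_shared<Function>());
  } catch (const std::runtime_error&) {
    // expected: in-place ops on such leaves are rejected
  }
}
```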
@@ -292,7 +303,7 @@
   auto& dst_ = unpack(dst, "dst", 1);
   check_inplace(dst);
   std::shared_ptr<CopyBackwards> grad_fn;
-  auto flags = compute_flags({ src });
+  auto flags = compute_flags({ dst, src });
   flags.requires_grad &= isFloatingPoint(dst.type().scalarType());
   if (flags.requires_grad) {
     // TODO: handle type conversions
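The small change above fixes which operands taint the copy's flags: `dst.copy_(src)` previously derived `requires_grad`/`is_volatile` only from `src`, so a destination that required grad could be overlooked. `compute_flags` ORs the flags across its arguments, roughly like this sketch (stand-in types mirroring the shapes in the diff, not the real API):

```cpp
#include <initializer_list>

struct VarFlags { bool requires_grad = false; bool is_volatile = false; };
struct Variable { bool requires_grad = false; bool is_volatile = false; };

static VarFlags compute_flags(std::initializer_list<Variable> vars) {
  VarFlags flags;
  for (const auto& v : vars) {
    // Any input that requires grad (or is volatile) taints the result.
    flags.requires_grad |= v.requires_grad;
    flags.is_volatile |= v.is_volatile;
  }
  return flags;
}

int main() {
  Variable dst{/*requires_grad=*/true, false};
  Variable src{false, false};
  // Old code: compute_flags({ src }) -> requires_grad == false (the bug).
  // Patched:  compute_flags({ dst, src }) -> requires_grad == true.
  return compute_flags({dst, src}).requires_grad ? 0 : 1;
}
```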
@@ -518,7 +529,7 @@
     grad_fn->dim0 = dim0;
     grad_fn->dim1 = dim1;
   }
-  auto ret = as_view(static_cast<const Variable&>(self), baseType->m_transpose_(self_, dim0, dim1));
+  baseType->m_transpose_(self_, dim0, dim1);
   increment_version(self);
   set_flags(static_cast<Variable&>(self), flags, grad_fn, false);
   if (jit::tracer::isTracing({ self })) {
@@ -557,7 +568,7 @@
     grad_fn->is_executable = true;
     grad_fn->next_functions = compute_next_functions({ self });
   }
-  auto ret = as_view(static_cast<const Variable&>(self), baseType->m_t_(self_));
+  baseType->m_t_(self_);
   increment_version(self);
   set_flags(static_cast<Variable&>(self), flags, grad_fn, false);
   if (jit::tracer::isTracing({ self })) {
@@ -618,7 +629,7 @@
     grad_fn->self_argsize_dim = self.size(dim);
     grad_fn->dim = dim;
   }
-  auto ret = as_view(static_cast<const Variable&>(self), baseType->m_squeeze_(self_, dim));
+  baseType->m_squeeze_(self_, dim);
   increment_version(self);
   set_flags(static_cast<Variable&>(self), flags, grad_fn, false);
   if (jit::tracer::isTracing({ self })) {
@@ -639,7 +650,7 @@
     grad_fn->next_functions = compute_next_functions({ self });
     grad_fn->self_sizes = self.sizes();
   }
-  auto ret = as_view(static_cast<const Variable&>(self), baseType->m_squeeze_(self_));
+  baseType->m_squeeze_(self_);
   increment_version(self);
   set_flags(static_cast<Variable&>(self), flags, grad_fn, false);
   if (jit::tracer::isTracing({ self })) {
@@ -679,7 +690,7 @@
     grad_fn->next_functions = compute_next_functions({ self });
     grad_fn->dim = dim;
   }
-  auto ret = as_view(static_cast<const Variable&>(self), baseType->m_unsqueeze_(self_, dim));
+  baseType->m_unsqueeze_(self_, dim);
   increment_version(self);
   set_flags(static_cast<Variable&>(self), flags, grad_fn, false);
   if (jit::tracer::isTracing({ self })) {
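The five hunks above (m_transpose_, m_t_, m_squeeze_ twice, m_unsqueeze_) all make the same change: these are in-place methods, so the base op already mutates `self`'s storage, and wrapping its return value in `as_view(self, ...)` just manufactured a second Variable aliasing the same data. After the patch the wrapper calls the base op for its side effect and keeps working with `self` directly. A small sketch of the aliasing point, with illustrative stand-in types:

```cpp
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

struct Tensor {
  std::shared_ptr<std::vector<float>> storage;  // shared: views alias data
  std::vector<int> sizes;
};

// Stand-in for an in-place base op: it permutes metadata on self and
// returns self by reference -- no new storage is created.
Tensor& m_transpose_(Tensor& self, int dim0, int dim1) {
  std::swap(self.sizes[dim0], self.sizes[dim1]);
  return self;
}

int main() {
  Tensor t{std::make_shared<std::vector<float>>(6, 0.f), {2, 3}};
  Tensor& r = m_transpose_(t, 0, 1);
  // r *is* t; building a fresh view Variable around r, as the old
  // generated code did, duplicated autograd bookkeeping for no benefit.
  std::cout << (&r == &t) << " " << t.sizes[0] << "x" << t.sizes[1] << "\n";
  // prints: 1 3x2
}
```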
@@ -5639,9 +5650,9 @@
     grad_fn->next_functions = compute_next_functions({ input });
     grad_fn->alpha = alpha;
   }
-  auto ret = as_variable(baseType->elu_forward(input_, alpha, inplace));
+  auto ret = maybe_wrap(baseType->elu_forward(input_, alpha, inplace), input, inplace);
   if (inplace) increment_version(input);
-  set_flags(ret, flags, grad_fn);
+  set_flags(ret, flags, grad_fn, inplace);
   if (grad_fn) {
     auto& output = ret;
     grad_fn->output_ = SavedVariable(output, grad_fn.get());
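Every remaining hunk (elu, hardtanh, leaky_relu, rrelu, threshold, forward and backward) applies the same recipe: wrap-or-reuse via `maybe_wrap`, bump the input's version counter when it was mutated, and pass `inplace` through to `set_flags` so history is rebased rather than set. A hedged, self-contained sketch of the patched wrapper shape (stand-in types and a stand-in base op, not the real generated code):

```cpp
#include <cmath>
#include <memory>
#include <utility>

struct Function { int num_inputs = 0; };
struct Tensor { float value = 0; };

// Refcounted handle, so returning `self` from maybe_wrap aliases state.
struct VariableImpl {
  Tensor data;
  int version = 0;
  std::shared_ptr<Function> grad_fn;
};
struct Variable {
  std::shared_ptr<VariableImpl> impl = std::make_shared<VariableImpl>();
};

Variable as_variable(Tensor t) {
  Variable v;
  v.impl->data = t;
  return v;
}

Variable maybe_wrap(Tensor data, const Variable& self, bool inplace) {
  return inplace ? self : as_variable(data);
}

void set_flags(Variable& var, std::shared_ptr<Function> grad_fn, bool inplace) {
  if (grad_fn) grad_fn->num_inputs = 1;
  // Real generated code calls var.rebase_history(...) when inplace and
  // var.set_history(...) otherwise; this toy stores grad_fn either way.
  (void)inplace;
  var.impl->grad_fn = std::move(grad_fn);
}

// Stand-in for baseType->elu_forward.
Tensor base_elu_forward(Tensor in, double alpha, bool /*inplace*/) {
  float v = in.value > 0 ? in.value
                         : static_cast<float>(alpha * (std::exp(in.value) - 1.0));
  return Tensor{v};
}

// Shape of a patched generated wrapper for an `inplace`-capable op.
Variable elu_forward(Variable& input, double alpha, bool inplace) {
  auto grad_fn = std::make_shared<Function>();
  auto ret = maybe_wrap(base_elu_forward(input.impl->data, alpha, inplace),
                        input, inplace);
  if (inplace) ++input.impl->version;  // invalidate stale saved copies
  set_flags(ret, std::move(grad_fn), inplace);
  return ret;  // aliases `input` when inplace == true
}

int main() {
  Variable x;
  x.impl->data.value = -1.f;
  Variable out = elu_forward(x, 1.0, /*inplace=*/true);
  // out aliases x: same impl, and x's version was bumped.
  return (out.impl == x.impl && x.impl->version == 1) ? 0 : 1;
}
```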
@@ -5670,9 +5681,9 @@
     grad_fn->inplace = inplace;
     grad_fn->output_ = SavedVariable(output, nullptr);
   }
-  auto ret = as_variable(baseType->elu_backward(grad_output_, input_, alpha, inplace, output_));
+  auto ret = maybe_wrap(baseType->elu_backward(grad_output_, input_, alpha, inplace, output_), input, inplace);
   if (inplace) increment_version(input);
-  set_flags(ret, flags, grad_fn);
+  set_flags(ret, flags, grad_fn, inplace);
   if (grad_fn) {
     grad_fn->input_ = SavedVariable(input, grad_fn.get());
     auto& grad_input = ret;
@@ -5781,9 +5792,9 @@
     grad_fn->min_val = min_val;
     grad_fn->max_val = max_val;
   }
-  auto ret = as_variable(baseType->hardtanh_forward(input_, min_val, max_val, inplace));
+  auto ret = maybe_wrap(baseType->hardtanh_forward(input_, min_val, max_val, inplace), input, inplace);
   if (inplace) increment_version(input);
-  set_flags(ret, flags, grad_fn);
+  set_flags(ret, flags, grad_fn, inplace);
   if (grad_fn) {
     grad_fn->input_ = SavedVariable(input, grad_fn.get());
   }
@@ -5809,9 +5820,9 @@
     grad_fn->min_val = min_val;
     grad_fn->max_val = max_val;
   }
-  auto ret = as_variable(baseType->hardtanh_backward(grad_output_, input_, min_val, max_val, inplace));
+  auto ret = maybe_wrap(baseType->hardtanh_backward(grad_output_, input_, min_val, max_val, inplace), input, inplace);
   if (inplace) increment_version(input);
-  set_flags(ret, flags, grad_fn);
+  set_flags(ret, flags, grad_fn, inplace);
   if (grad_fn) {
     grad_fn->input_ = SavedVariable(input, grad_fn.get());
   }
@@ -5835,9 +5846,9 @@
     grad_fn->next_functions = compute_next_functions({ input });
     grad_fn->negative_slope = negative_slope;
   }
-  auto ret = as_variable(baseType->leaky_relu_forward(input_, negative_slope, inplace));
+  auto ret = maybe_wrap(baseType->leaky_relu_forward(input_, negative_slope, inplace), input, inplace);
   if (inplace) increment_version(input);
-  set_flags(ret, flags, grad_fn);
+  set_flags(ret, flags, grad_fn, inplace);
   if (grad_fn) {
     grad_fn->input_ = SavedVariable(input, grad_fn.get());
   }
@@ -5861,9 +5872,9 @@
     grad_fn->next_functions = compute_next_functions({ grad_output, input });
     grad_fn->negative_slope = negative_slope;
   }
-  auto ret = as_variable(baseType->leaky_relu_backward(grad_output_, input_, negative_slope, inplace));
+  auto ret = maybe_wrap(baseType->leaky_relu_backward(grad_output_, input_, negative_slope, inplace), input, inplace);
   if (inplace) increment_version(input);
-  set_flags(ret, flags, grad_fn);
+  set_flags(ret, flags, grad_fn, inplace);
   if (grad_fn) {
     grad_fn->input_ = SavedVariable(input, grad_fn.get());
   }
@@ -6028,9 +6039,9 @@
     grad_fn->training = training;
     grad_fn->noise_ = SavedVariable(noise, nullptr);
   }
-  auto ret = as_variable(baseType->rrelu_forward(input_, lower, upper, training, inplace, generator, noise_));
+  auto ret = maybe_wrap(baseType->rrelu_forward(input_, lower, upper, training, inplace, generator, noise_), input, inplace);
   if (inplace) increment_version(input);
-  set_flags(ret, flags, grad_fn);
+  set_flags(ret, flags, grad_fn, inplace);
   if (grad_fn) {
     grad_fn->input_ = SavedVariable(input, grad_fn.get());
   }
@@ -6054,9 +6065,9 @@
     grad_fn->training = training;
     grad_fn->noise_ = SavedVariable(noise, nullptr);
   }
-  auto ret = as_variable(baseType->rrelu_backward(grad_output_, input_, lower, upper, training, inplace, noise_));
+  auto ret = maybe_wrap(baseType->rrelu_backward(grad_output_, input_, lower, upper, training, inplace, noise_), input, inplace);
   if (inplace) increment_version(input);
-  set_flags(ret, flags, grad_fn);
+  set_flags(ret, flags, grad_fn, inplace);
   if (grad_fn) {
     grad_fn->input_ = SavedVariable(input, grad_fn.get());
   }
@@ -6223,9 +6234,9 @@
     grad_fn->threshold = threshold;
     grad_fn->value = value;
   }
-  auto ret = as_variable(baseType->threshold_forward(input_, threshold, value, inplace));
+  auto ret = maybe_wrap(baseType->threshold_forward(input_, threshold, value, inplace), input, inplace);
   if (inplace) increment_version(input);
-  set_flags(ret, flags, grad_fn);
+  set_flags(ret, flags, grad_fn, inplace);
   if (grad_fn) {
     grad_fn->input_ = SavedVariable(input, grad_fn.get());
   }
@@ -6251,9 +6262,9 @@
     grad_fn->threshold = threshold;
     grad_fn->value = value;
   }
-  auto ret = as_variable(baseType->threshold_backward(grad_output_, input_, threshold, value, inplace));
+  auto ret = maybe_wrap(baseType->threshold_backward(grad_output_, input_, threshold, value, inplace), input, inplace);
   if (inplace) increment_version(input);
-  set_flags(ret, flags, grad_fn);
+  set_flags(ret, flags, grad_fn, inplace);
   if (grad_fn) {
     grad_fn->input_ = SavedVariable(input, grad_fn.get());
   }