Last active November 2, 2017 18:42
-
-
Save colesbury/bc6835aaef3f755dd063cb484e788701 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- ./torch/csrc/autograd/generated/VariableType.cpp 2017-11-02 11:42:00.179695310 -0700 | |
+++ VariableType.cpp 2017-11-02 11:41:41.932641407 -0700 | |
@@ -181,6 +181,18 @@ | |
return make_variable(std::move(tensor)); | |
} | |
+// Returns the Variable that should hold the output of a base-type op. | |
+// When `inplace` is true the op presumably wrote into `self`'s own data, so | |
+// `self` is returned unchanged and `data` is dropped; otherwise `data` is | |
+// wrapped as a new Variable via as_variable(). | |
+Variable VariableType::maybe_wrap(Tensor data, const Variable & self, bool inplace) const { | |
+ if (inplace) { | |
+ return self; | |
+ } | |
+ // `data` is this function's own by-value copy, so hand it off by move. | |
+ return as_variable(std::move(data)); | |
+} | |
+ | |
static Variable as_view(Variable base, Tensor tensor) { | |
if (base.is_view()) { | |
base = base.base(); | |
@@ -254,8 +261,12 @@ | |
if (grad_fn) { | |
grad_fn->num_inputs = 1; | |
} | |
+ if (inplace) { | |
+ var.rebase_history(flags, 0, std::move(grad_fn)); | |
+ } else { | |
var.set_history(flags, 0, std::move(grad_fn)); | |
} | |
+} | |
static void set_flags(std::vector<Variable> &vl, VarFlags flags, std::shared_ptr<Function> grad_fn) { | |
if (grad_fn) { | |
@@ -292,7 +303,7 @@ | |
auto& dst_ = unpack(dst, "dst", 1); | |
check_inplace(dst); | |
std::shared_ptr<CopyBackwards> grad_fn; | |
- auto flags = compute_flags({ src }); | |
+ auto flags = compute_flags({ dst, src }); | |
flags.requires_grad &= isFloatingPoint(dst.type().scalarType()); | |
if (flags.requires_grad) { | |
// TODO: handle type conversions | |
@@ -518,7 +529,7 @@ | |
grad_fn->dim0 = dim0; | |
grad_fn->dim1 = dim1; | |
} | |
- auto ret = as_view(static_cast<const Variable&>(self), baseType->m_transpose_(self_, dim0, dim1)); | |
+ baseType->m_transpose_(self_, dim0, dim1); | |
increment_version(self); | |
set_flags(static_cast<Variable&>(self), flags, grad_fn, false); | |
if (jit::tracer::isTracing({ self })) { | |
@@ -557,7 +568,7 @@ | |
grad_fn->is_executable = true; | |
grad_fn->next_functions = compute_next_functions({ self }); | |
} | |
- auto ret = as_view(static_cast<const Variable&>(self), baseType->m_t_(self_)); | |
+ baseType->m_t_(self_); | |
increment_version(self); | |
set_flags(static_cast<Variable&>(self), flags, grad_fn, false); | |
if (jit::tracer::isTracing({ self })) { | |
@@ -618,7 +629,7 @@ | |
grad_fn->self_argsize_dim = self.size(dim); | |
grad_fn->dim = dim; | |
} | |
- auto ret = as_view(static_cast<const Variable&>(self), baseType->m_squeeze_(self_, dim)); | |
+ baseType->m_squeeze_(self_, dim); | |
increment_version(self); | |
set_flags(static_cast<Variable&>(self), flags, grad_fn, false); | |
if (jit::tracer::isTracing({ self })) { | |
@@ -639,7 +650,7 @@ | |
grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
} | |
- auto ret = as_view(static_cast<const Variable&>(self), baseType->m_squeeze_(self_)); | |
+ baseType->m_squeeze_(self_); | |
increment_version(self); | |
set_flags(static_cast<Variable&>(self), flags, grad_fn, false); | |
if (jit::tracer::isTracing({ self })) { | |
@@ -679,7 +690,7 @@ | |
grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->dim = dim; | |
} | |
- auto ret = as_view(static_cast<const Variable&>(self), baseType->m_unsqueeze_(self_, dim)); | |
+ baseType->m_unsqueeze_(self_, dim); | |
increment_version(self); | |
set_flags(static_cast<Variable&>(self), flags, grad_fn, false); | |
if (jit::tracer::isTracing({ self })) { | |
@@ -5639,9 +5650,9 @@ | |
grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->alpha = alpha; | |
} | |
- auto ret = as_variable(baseType->elu_forward(input_, alpha, inplace)); | |
+ auto ret = maybe_wrap(baseType->elu_forward(input_, alpha, inplace), input, inplace); | |
if (inplace) increment_version(input); | |
- set_flags(ret, flags, grad_fn); | |
+ set_flags(ret, flags, grad_fn, inplace); | |
if (grad_fn) { | |
auto& output = ret; | |
grad_fn->output_ = SavedVariable(output, grad_fn.get()); | |
@@ -5670,9 +5681,9 @@ | |
grad_fn->inplace = inplace; | |
grad_fn->output_ = SavedVariable(output, nullptr); | |
} | |
- auto ret = as_variable(baseType->elu_backward(grad_output_, input_, alpha, inplace, output_)); | |
+ auto ret = maybe_wrap(baseType->elu_backward(grad_output_, input_, alpha, inplace, output_), input, inplace); | |
if (inplace) increment_version(input); | |
- set_flags(ret, flags, grad_fn); | |
+ set_flags(ret, flags, grad_fn, inplace); | |
if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
auto& grad_input = ret; | |
@@ -5781,9 +5792,9 @@ | |
grad_fn->min_val = min_val; | |
grad_fn->max_val = max_val; | |
} | |
- auto ret = as_variable(baseType->hardtanh_forward(input_, min_val, max_val, inplace)); | |
+ auto ret = maybe_wrap(baseType->hardtanh_forward(input_, min_val, max_val, inplace), input, inplace); | |
if (inplace) increment_version(input); | |
- set_flags(ret, flags, grad_fn); | |
+ set_flags(ret, flags, grad_fn, inplace); | |
if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
@@ -5809,9 +5820,9 @@ | |
grad_fn->min_val = min_val; | |
grad_fn->max_val = max_val; | |
} | |
- auto ret = as_variable(baseType->hardtanh_backward(grad_output_, input_, min_val, max_val, inplace)); | |
+ auto ret = maybe_wrap(baseType->hardtanh_backward(grad_output_, input_, min_val, max_val, inplace), input, inplace); | |
if (inplace) increment_version(input); | |
- set_flags(ret, flags, grad_fn); | |
+ set_flags(ret, flags, grad_fn, inplace); | |
if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
@@ -5835,9 +5846,9 @@ | |
grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->negative_slope = negative_slope; | |
} | |
- auto ret = as_variable(baseType->leaky_relu_forward(input_, negative_slope, inplace)); | |
+ auto ret = maybe_wrap(baseType->leaky_relu_forward(input_, negative_slope, inplace), input, inplace); | |
if (inplace) increment_version(input); | |
- set_flags(ret, flags, grad_fn); | |
+ set_flags(ret, flags, grad_fn, inplace); | |
if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
@@ -5861,9 +5872,9 @@ | |
grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->negative_slope = negative_slope; | |
} | |
- auto ret = as_variable(baseType->leaky_relu_backward(grad_output_, input_, negative_slope, inplace)); | |
+ auto ret = maybe_wrap(baseType->leaky_relu_backward(grad_output_, input_, negative_slope, inplace), input, inplace); | |
if (inplace) increment_version(input); | |
- set_flags(ret, flags, grad_fn); | |
+ set_flags(ret, flags, grad_fn, inplace); | |
if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
@@ -6028,9 +6039,9 @@ | |
grad_fn->training = training; | |
grad_fn->noise_ = SavedVariable(noise, nullptr); | |
} | |
- auto ret = as_variable(baseType->rrelu_forward(input_, lower, upper, training, inplace, generator, noise_)); | |
+ auto ret = maybe_wrap(baseType->rrelu_forward(input_, lower, upper, training, inplace, generator, noise_), input, inplace); | |
if (inplace) increment_version(input); | |
- set_flags(ret, flags, grad_fn); | |
+ set_flags(ret, flags, grad_fn, inplace); | |
if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
@@ -6054,9 +6065,9 @@ | |
grad_fn->training = training; | |
grad_fn->noise_ = SavedVariable(noise, nullptr); | |
} | |
- auto ret = as_variable(baseType->rrelu_backward(grad_output_, input_, lower, upper, training, inplace, noise_)); | |
+ auto ret = maybe_wrap(baseType->rrelu_backward(grad_output_, input_, lower, upper, training, inplace, noise_), input, inplace); | |
if (inplace) increment_version(input); | |
- set_flags(ret, flags, grad_fn); | |
+ set_flags(ret, flags, grad_fn, inplace); | |
if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
@@ -6223,9 +6234,9 @@ | |
grad_fn->threshold = threshold; | |
grad_fn->value = value; | |
} | |
- auto ret = as_variable(baseType->threshold_forward(input_, threshold, value, inplace)); | |
+ auto ret = maybe_wrap(baseType->threshold_forward(input_, threshold, value, inplace), input, inplace); | |
if (inplace) increment_version(input); | |
- set_flags(ret, flags, grad_fn); | |
+ set_flags(ret, flags, grad_fn, inplace); | |
if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
@@ -6251,9 +6262,9 @@ | |
grad_fn->threshold = threshold; | |
grad_fn->value = value; | |
} | |
- auto ret = as_variable(baseType->threshold_backward(grad_output_, input_, threshold, value, inplace)); | |
+ auto ret = maybe_wrap(baseType->threshold_backward(grad_output_, input_, threshold, value, inplace), input, inplace); | |
if (inplace) increment_version(input); | |
- set_flags(ret, flags, grad_fn); | |
+ set_flags(ret, flags, grad_fn, inplace); | |
if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.