Created
October 21, 2017 11:32
-
-
Save colesbury/1b9d4ed6a05c0e49fc0c95eab8f997f5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- VariableType.cpp 2017-10-21 04:32:34.984221040 -0700
+++ ./torch/csrc/autograd/generated/VariableType.cpp 2017-10-21 04:32:46.720255617 -0700
@@ -173,12 +173,61 @@
return make_variable(std::move(tensor)); | |
} | |
-static void check_inplace(const VariableImpl& pImpl) { | |
- if (pImpl.requires_grad && !pImpl.grad_fn) { | |
+struct VariableFlags { | |
+ bool requires_grad; | |
+ bool is_volatile; | |
+}; | |
+ | |
+template<typename T> | |
+static VariableFlags compute_flags_tmpl(T tensors) { | |
+ VariableFlags flags = {false, false}; | |
+ for (const Tensor& tensor : tensors) { | |
+ auto& var = static_cast<const Variable&>(tensor); | |
+ if (var.defined()) { | |
+ flags.requires_grad |= var.requires_grad(); | |
+ flags.is_volatile |= var.is_volatile(); | |
+ } | |
+ } | |
+ flags.requires_grad &= !flags.is_volatile; | |
+ return flags; | |
+} | |
+ | |
+using TensorRef = std::reference_wrapper<const Tensor>; | |
+using TensorRefList = std::initializer_list<TensorRef>; | |
+ | |
+static VariableFlags compute_flags(const TensorRefList& tensors) { | |
+ return compute_flags_tmpl(tensors); | |
+} | |
+ | |
+static VariableFlags compute_flags(TensorList tensors) { | |
+ return compute_flags_tmpl(tensors); | |
+} | |
+ | |
+static void check_no_requires_grad(const Tensor& tensor, const char* name) { | |
+ auto& var = static_cast<const Variable&>(tensor); | |
+ if (var.defined() && var.requires_grad()) { | |
+ std::string msg = "the derivative for '"; | |
+ msg += name; | |
+ msg += "' is not implemented"; | |
+ throw std::runtime_error(msg); | |
+ } | |
+} | |
+ | |
+static function_list compute_next_functions(const std::initializer_list<Tensor>& tensors) { | |
+ return Function::flags(tensors).next_functions; | |
+} | |
+ | |
+static function_list compute_next_functions(TensorList tensors) { | |
+ return Function::flags(tensors).next_functions; | |
+} | |
+ | |
+static void check_inplace(const Tensor& tensor) { | |
+ auto& var = static_cast<const Variable&>(tensor); | |
+ if (var.requires_grad() && !var.grad_fn()) { | |
at::runtime_error( | |
"a leaf Variable that requires grad has been used in an in-place operation."); | |
} | |
- auto live_refs = pImpl.version_counter.live_refs(); | |
+ auto live_refs = var.version_counter().live_refs(); | |
if (live_refs > 1) { | |
at::runtime_error( | |
"in-place operations can be only used on variables that don't share " | |
@@ -187,29 +236,13 @@
} | |
} | |
-static void wrap_output(VariableImpl& pImpl, FunctionFlags flags, std::shared_ptr<Function> grad_fn) { | |
- // Hooks up the grad_fn and sets the flags of the function output. This only | |
- // supports a single differentiable output. | |
- pImpl.requires_grad = flags.is_executable; | |
- pImpl.is_volatile = flags.is_volatile; | |
- if (!flags.is_volatile) { | |
- pImpl.output_nr = grad_fn->num_inputs++; | |
- grad_fn->set_flags(std::move(flags)); | |
- pImpl.grad_fn = std::move(grad_fn); | |
- } | |
-} | |
- | |
-static void wrap_output(Tensor& t, FunctionFlags flags, std::shared_ptr<Function> grad_fn) { | |
- auto pImpl = static_cast<VariableImpl*>(t.get()); | |
- wrap_output(*pImpl, std::move(flags), std::move(grad_fn)); | |
-} | |
- | |
-static void wrap_output(std::tuple<Variable, Variable>& t, FunctionFlags flags, std::shared_ptr<Function> grad_fn) { | |
- wrap_output(std::get<0>(t), std::move(flags), std::move(grad_fn)); | |
+static void set_flags(Variable& var, VariableFlags flags, std::shared_ptr<Function> grad_fn) { | |
+ var.requires_grad() = flags.requires_grad; | |
+ var.is_volatile() = flags.is_volatile; | |
+ if (grad_fn) { | |
+ var.output_nr() = grad_fn->num_inputs++; | |
+ var.grad_fn() = std::move(grad_fn); | |
} | |
- | |
-static void wrap_output(std::tuple<Variable, Variable, Variable>& t, FunctionFlags flags, std::shared_ptr<Function> grad_fn) { | |
- wrap_output(std::get<0>(t), std::move(flags), std::move(grad_fn)); | |
} | |
static void increment_version(const Tensor & t) { | |
@@ -234,30 +267,29 @@
// it automatically | |
auto& src_ = unpack_any(src, "src", 0); | |
auto& dst_ = unpack(dst, "dst", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*dst.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ src }); | |
- baseType->s_copy(src_, dst_); | |
- pImpl.version_counter.increment(); | |
- if (isFloatingPoint(dst.type().scalarType())) { | |
- if (isFloatingPoint(src.type().scalarType())) { | |
- // TODO: handle type conversions | |
- wrap_output(pImpl, std::move(flags), std::make_shared<Identity>()); | |
- } else { | |
+ check_inplace(dst); | |
+ std::shared_ptr<Identity> grad_fn; | |
+ auto flags = compute_flags({ src }); | |
+ flags.requires_grad &= isFloatingPoint(dst.type().scalarType()); | |
+ if (flags.requires_grad) { | |
// TODO: handle type conversions | |
- wrap_output(pImpl, std::move(flags), std::make_shared<Identity>()); | |
- } | |
+ grad_fn = std::make_shared<Identity>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ src }); | |
} | |
+ baseType->s_copy(src_, dst_); | |
+ increment_version(dst); | |
+ set_flags(static_cast<Variable&>(dst), flags, std::move(grad_fn)); | |
} | |
Tensor & VariableType::m_resize_(Tensor & self, IntList size) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- if (pImpl.grad_fn) { | |
+ check_inplace(self); | |
+ auto& self_var = static_cast<Variable&>(self); | |
+ if (self_var.grad_fn()) { | |
at::runtime_error("cannot resize non-leaf variables"); | |
} | |
- if (pImpl.requires_grad) { | |
+ if (self_var.requires_grad()) { | |
at::runtime_error("cannot resize variables which require grad"); | |
} | |
baseType->m_resize_(self_, size); | |
@@ -333,15 +365,20 @@
} | |
Tensor & VariableType::m_fill_(Tensor & self, Scalar value) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ }); | |
- auto grad_fn = std::make_shared<FillBackward>(); | |
+ check_inplace(self); | |
+ check_no_requires_grad(self, "self"); | |
+ std::shared_ptr<FillBackward> grad_fn; | |
+ auto flags = compute_flags({ }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<FillBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ }); | |
+ } | |
baseType->m_fill_(self_, value); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ })) { | |
- jit::Node *n = jit::tracer::recordTrace( "fill_", { }, self ); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "fill_", { self }, self ); | |
setattr(n, jit::stringToSymbol("value"), value); | |
} | |
return self; | |
@@ -363,18 +400,20 @@
Tensor & VariableType::s_m_masked_fill_(Tensor & self, const Tensor & mask, Scalar value) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& mask_ = unpack_byte(mask, "mask", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<MaskedFillBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<MaskedFillBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MaskedFillBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->mask_ = SavedVariable(mask, nullptr); | |
} | |
baseType->s_m_masked_fill_(self_, mask_, value); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self })) { | |
- jit::Node *n = jit::tracer::recordTrace( "masked_fill_", { self }, self ); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, mask })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "masked_fill_", { self, mask }, self ); | |
setattr(n, jit::stringToSymbol("value"), value); | |
} | |
return self; | |
@@ -383,19 +422,21 @@
auto& self_ = unpack(self, "self", 0); | |
auto& mask_ = unpack_byte(mask, "mask", 1); | |
auto& source_ = unpack(source, "source", 2); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, source }); | |
- auto grad_fn = std::make_shared<MaskedScatterBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<MaskedScatterBackward> grad_fn; | |
+ auto flags = compute_flags({ self, source }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MaskedScatterBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, source }); | |
grad_fn->mask_ = SavedVariable(mask, nullptr); | |
grad_fn->source_sizes = source.sizes(); | |
} | |
baseType->s_m_masked_scatter_(self_, mask_, source_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self, source })) { | |
- jit::Node *n = jit::tracer::recordTrace( "masked_scatter_", { self, source }, self ); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, mask, source })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "masked_scatter_", { self, mask, source }, self ); | |
(void)n; | |
} | |
return self; | |
@@ -403,31 +444,37 @@
Tensor VariableType::s_masked_select(const Tensor & self, const Tensor & mask) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& mask_ = unpack_byte(mask, "mask", 1); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<MaskedSelectBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MaskedSelectBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MaskedSelectBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_info = self; | |
grad_fn->mask_ = SavedVariable(mask, nullptr); | |
} | |
auto ret = as_variable(baseType->s_masked_select(self_, mask_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self })) { | |
- jit::Node *n = jit::tracer::recordTrace( "masked_select", { self }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, mask })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "masked_select", { self, mask }, ret ); | |
(void)n; | |
} | |
return Tensor(std::move(ret)); | |
} | |
Tensor VariableType::transpose(const Tensor & self, int64_t dim0, int64_t dim1) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TransposeBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<TransposeBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TransposeBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->dim0 = dim0; | |
grad_fn->dim1 = dim1; | |
} | |
auto ret = as_variable(baseType->transpose(self_, dim0, dim1)); | |
take_version_counter(ret, self); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "transpose", { self }, ret ); | |
setattr(n, jit::stringToSymbol("dim0"), dim0); | |
@@ -437,17 +484,19 @@
} | |
Tensor & VariableType::m_transpose_(Tensor & self, int64_t dim0, int64_t dim1) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TransposeBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<TransposeBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TransposeBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->dim0 = dim0; | |
grad_fn->dim1 = dim1; | |
} | |
baseType->m_transpose_(self_, dim0, dim1); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "transpose_", { self }, self ); | |
setattr(n, jit::stringToSymbol("dim0"), dim0); | |
@@ -457,11 +506,16 @@
} | |
Tensor VariableType::t(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TBackward>(); | |
+ std::shared_ptr<TBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->t(self_)); | |
take_version_counter(ret, self); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "t", { self }, ret ); | |
(void)n; | |
@@ -470,13 +524,17 @@
} | |
Tensor & VariableType::m_t_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TBackward>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<TBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_t_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "t_", { self }, self ); | |
(void)n; | |
@@ -485,15 +543,18 @@
} | |
Tensor VariableType::squeeze(const Tensor & self, int64_t dim) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SqueezeBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SqueezeBackward1> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SqueezeBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_argsize_dim = self.size(dim); | |
grad_fn->dim = dim; | |
} | |
auto ret = as_variable(baseType->squeeze(self_, dim)); | |
take_version_counter(ret, self); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "squeeze", { self }, ret ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
@@ -502,14 +563,17 @@
} | |
Tensor VariableType::squeeze(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SqueezeBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SqueezeBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SqueezeBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
} | |
auto ret = as_variable(baseType->squeeze(self_)); | |
take_version_counter(ret, self); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "squeeze", { self }, ret ); | |
(void)n; | |
@@ -518,17 +582,19 @@
} | |
Tensor & VariableType::m_squeeze_(Tensor & self, int64_t dim) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SqueezeBackward1>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<SqueezeBackward1> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SqueezeBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_argsize_dim = self.size(dim); | |
grad_fn->dim = dim; | |
} | |
baseType->m_squeeze_(self_, dim); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "squeeze_", { self }, self ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
@@ -537,16 +603,18 @@
} | |
Tensor & VariableType::m_squeeze_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SqueezeBackward0>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<SqueezeBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SqueezeBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
} | |
baseType->m_squeeze_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "squeeze_", { self }, self ); | |
(void)n; | |
@@ -555,14 +623,17 @@
} | |
Tensor VariableType::unsqueeze(const Tensor & self, int64_t dim) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<UnsqueezeBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<UnsqueezeBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<UnsqueezeBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->dim = dim; | |
} | |
auto ret = as_variable(baseType->unsqueeze(self_, dim)); | |
take_version_counter(ret, self); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "unsqueeze", { self }, ret ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
@@ -571,16 +642,18 @@
} | |
Tensor & VariableType::m_unsqueeze_(Tensor & self, int64_t dim) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<UnsqueezeBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<UnsqueezeBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<UnsqueezeBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->dim = dim; | |
} | |
baseType->m_unsqueeze_(self_, dim); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "unsqueeze_", { self }, self ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
@@ -589,10 +662,15 @@
} | |
Tensor VariableType::nonzero(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<NonzeroBackward>(); | |
+ std::shared_ptr<NonzeroBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<NonzeroBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->nonzero(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "nonzero", { self }, ret ); | |
(void)n; | |
@@ -601,10 +679,15 @@
} | |
Tensor VariableType::m_clone(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<CloneBackward>(); | |
+ std::shared_ptr<CloneBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<CloneBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->m_clone(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "clone", { self }, ret ); | |
(void)n; | |
@@ -613,14 +696,17 @@
} | |
Tensor VariableType::m_view(const Tensor & self, IntList size) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<ViewBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<ViewBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ViewBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
} | |
auto ret = as_variable(baseType->m_view(self_, size)); | |
take_version_counter(ret, self); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "view", { self }, ret ); | |
setattr(n, jit::stringToSymbol("size"), size); | |
@@ -629,14 +715,17 @@
} | |
Tensor VariableType::m_expand(const Tensor & self, IntList size) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<ExpandBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<ExpandBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ExpandBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
} | |
auto ret = as_variable(baseType->m_expand(self_, size)); | |
take_version_counter(ret, self); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "expand", { self }, ret ); | |
setattr(n, jit::stringToSymbol("size"), size); | |
@@ -646,18 +735,21 @@
Tensor VariableType::index_select(const Tensor & self, int64_t dim, const Tensor & index) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& index_ = unpack_long(index, "index", 2); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<IndexSelectBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<IndexSelectBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<IndexSelectBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
grad_fn->dim = dim; | |
grad_fn->index_ = SavedVariable(index, nullptr); | |
} | |
auto ret = as_variable(baseType->index_select(self_, dim, index_)); | |
take_version_counter(ret, self); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self })) { | |
- jit::Node *n = jit::tracer::recordTrace( "index_select", { self }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, index })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "index_select", { self, index }, ret ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
} | |
return Tensor(std::move(ret)); | |
@@ -666,19 +758,21 @@
auto& self_ = unpack(self, "self", 0); | |
auto& index_ = unpack_long(index, "index", 2); | |
auto& source_ = unpack(source, "source", 3); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, source }); | |
- auto grad_fn = std::make_shared<IndexCopyBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<IndexCopyBackward> grad_fn; | |
+ auto flags = compute_flags({ self, source }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<IndexCopyBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, source }); | |
grad_fn->dim = dim; | |
grad_fn->index_ = SavedVariable(index, nullptr); | |
} | |
baseType->m_index_copy_(self_, dim, index_, source_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self, source })) { | |
- jit::Node *n = jit::tracer::recordTrace( "index_copy_", { self, source }, self ); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, index, source })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "index_copy_", { self, index, source }, self ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
} | |
return self; | |
@@ -687,19 +781,21 @@
auto& self_ = unpack(self, "self", 0); | |
auto& index_ = unpack_long(index, "index", 2); | |
auto& source_ = unpack(source, "source", 3); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, source }); | |
- auto grad_fn = std::make_shared<IndexAddBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<IndexAddBackward> grad_fn; | |
+ auto flags = compute_flags({ self, source }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<IndexAddBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, source }); | |
grad_fn->dim = dim; | |
grad_fn->index_ = SavedVariable(index, nullptr); | |
} | |
baseType->m_index_add_(self_, dim, index_, source_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self, source })) { | |
- jit::Node *n = jit::tracer::recordTrace( "index_add_", { self, source }, self ); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, index, source })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "index_add_", { self, index, source }, self ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
} | |
return self; | |
@@ -707,19 +803,21 @@
Tensor & VariableType::m_index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar value) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& index_ = unpack_long(index, "index", 2); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<IndexFillBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<IndexFillBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<IndexFillBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->dim = dim; | |
grad_fn->index_ = SavedVariable(index, nullptr); | |
} | |
baseType->m_index_fill_(self_, dim, index_, value); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self })) { | |
- jit::Node *n = jit::tracer::recordTrace( "index_fill_", { self }, self ); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, index })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "index_fill_", { self, index }, self ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
setattr(n, jit::stringToSymbol("value"), value); | |
} | |
@@ -727,16 +825,19 @@
} | |
Tensor VariableType::m_narrow(const Tensor & self, int64_t dimension, int64_t start, int64_t length) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<NarrowBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<NarrowBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<NarrowBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_argsize_dimension = self.size(dimension); | |
grad_fn->dimension = dimension; | |
grad_fn->start = start; | |
} | |
auto ret = as_variable(baseType->m_narrow(self_, dimension, start, length)); | |
take_version_counter(ret, self); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "narrow", { self }, ret ); | |
setattr(n, jit::stringToSymbol("dimension"), dimension); | |
@@ -747,16 +848,19 @@ | |
} | |
Tensor VariableType::m_unfold(const Tensor & self, int64_t dimension, int64_t size, int64_t step) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<UnfoldBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<UnfoldBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<UnfoldBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
grad_fn->dimension = dimension; | |
grad_fn->size = size; | |
grad_fn->step = step; | |
} | |
auto ret = as_variable(baseType->m_unfold(self_, dimension, size, step)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "unfold", { self }, ret ); | |
setattr(n, jit::stringToSymbol("dimension"), dimension); | |
@@ -775,19 +879,21 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& index_ = unpack_long(index, "index", 2); | |
auto& src_ = unpack(src, "src", 3); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, src }); | |
- auto grad_fn = std::make_shared<ScatterBackward0>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<ScatterBackward0> grad_fn; | |
+ auto flags = compute_flags({ self, src }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ScatterBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, src }); | |
grad_fn->dim = dim; | |
grad_fn->index_ = SavedVariable(index, nullptr); | |
} | |
baseType->m_scatter_(self_, dim, index_, src_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self, src })) { | |
- jit::Node *n = jit::tracer::recordTrace( "scatter_", { self, src }, self ); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, index, src })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "scatter_", { self, index, src }, self ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
} | |
return self; | |
@@ -795,19 +901,21 @@ | |
Tensor & VariableType::m_scatter_(Tensor & self, int64_t dim, const Tensor & index, Scalar value) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& index_ = unpack_long(index, "index", 2); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<ScatterBackward1>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<ScatterBackward1> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ScatterBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->dim = dim; | |
grad_fn->index_ = SavedVariable(index, nullptr); | |
} | |
baseType->m_scatter_(self_, dim, index_, value); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self })) { | |
- jit::Node *n = jit::tracer::recordTrace( "scatter_", { self }, self ); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, index })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "scatter_", { self, index }, self ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
setattr(n, jit::stringToSymbol("value"), value); | |
} | |
@@ -817,19 +925,21 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& index_ = unpack_long(index, "index", 2); | |
auto& src_ = unpack(src, "src", 3); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, src }); | |
- auto grad_fn = std::make_shared<ScatterAddBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<ScatterAddBackward> grad_fn; | |
+ auto flags = compute_flags({ self, src }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ScatterAddBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, src }); | |
grad_fn->dim = dim; | |
grad_fn->index_ = SavedVariable(index, nullptr); | |
} | |
baseType->m_scatter_add_(self_, dim, index_, src_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self, src })) { | |
- jit::Node *n = jit::tracer::recordTrace( "scatter_add_", { self, src }, self ); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, index, src })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "scatter_add_", { self, index, src }, self ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
} | |
return self; | |
@@ -837,17 +947,20 @@ | |
Tensor VariableType::gather(const Tensor & self, int64_t dim, const Tensor & index) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& index_ = unpack_long(index, "index", 2); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<GatherBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<GatherBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GatherBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
grad_fn->dim = dim; | |
grad_fn->index_ = SavedVariable(index, nullptr); | |
} | |
auto ret = as_variable(baseType->gather(self_, dim, index_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self })) { | |
- jit::Node *n = jit::tracer::recordTrace( "gather", { self }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, index })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "gather", { self, index }, ret ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
} | |
return Tensor(std::move(ret)); | |
@@ -973,13 +1086,16 @@ | |
} | |
Tensor VariableType::lt(const Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<LtBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<LtBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LtBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_info = self; | |
} | |
auto ret = as_variable(baseType->lt(self_, other)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "lt", { self }, ret ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -989,14 +1105,17 @@ | |
Tensor VariableType::s_lt(const Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<LtBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<LtBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LtBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_info = self; | |
grad_fn->other_info = other; | |
} | |
auto ret = as_variable(baseType->s_lt(self_, other_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "lt", { self, other }, ret ); | |
(void)n; | |
@@ -1005,16 +1124,18 @@ | |
} | |
Tensor & VariableType::m_lt_(Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<LtBackward0>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<LtBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LtBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_info = self; | |
} | |
baseType->m_lt_(self_, other); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "lt_", { self }, self ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -1024,17 +1145,19 @@ | |
Tensor & VariableType::s_m_lt_(Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<LtBackward1>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<LtBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LtBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_info = self; | |
grad_fn->other_info = other; | |
} | |
baseType->s_m_lt_(self_, other_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "lt_", { self, other }, self ); | |
(void)n; | |
@@ -1043,13 +1166,16 @@ | |
} | |
Tensor VariableType::gt(const Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<GtBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<GtBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GtBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_info = self; | |
} | |
auto ret = as_variable(baseType->gt(self_, other)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "gt", { self }, ret ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -1059,14 +1185,17 @@ | |
Tensor VariableType::s_gt(const Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<GtBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<GtBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GtBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_info = self; | |
grad_fn->other_info = other; | |
} | |
auto ret = as_variable(baseType->s_gt(self_, other_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "gt", { self, other }, ret ); | |
(void)n; | |
@@ -1075,16 +1204,18 @@ | |
} | |
Tensor & VariableType::m_gt_(Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<GtBackward0>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<GtBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GtBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_info = self; | |
} | |
baseType->m_gt_(self_, other); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "gt_", { self }, self ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -1094,17 +1225,19 @@ | |
Tensor & VariableType::s_m_gt_(Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<GtBackward1>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<GtBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GtBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_info = self; | |
grad_fn->other_info = other; | |
} | |
baseType->s_m_gt_(self_, other_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "gt_", { self, other }, self ); | |
(void)n; | |
@@ -1113,13 +1246,16 @@ | |
} | |
Tensor VariableType::le(const Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<LeBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<LeBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LeBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_info = self; | |
} | |
auto ret = as_variable(baseType->le(self_, other)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "le", { self }, ret ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -1129,14 +1265,17 @@ | |
Tensor VariableType::s_le(const Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<LeBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<LeBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LeBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_info = self; | |
grad_fn->other_info = other; | |
} | |
auto ret = as_variable(baseType->s_le(self_, other_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "le", { self, other }, ret ); | |
(void)n; | |
@@ -1145,16 +1284,18 @@ | |
} | |
Tensor & VariableType::m_le_(Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<LeBackward0>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<LeBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LeBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_info = self; | |
} | |
baseType->m_le_(self_, other); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "le_", { self }, self ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -1164,17 +1305,19 @@ | |
Tensor & VariableType::s_m_le_(Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<LeBackward1>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<LeBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LeBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_info = self; | |
grad_fn->other_info = other; | |
} | |
baseType->s_m_le_(self_, other_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "le_", { self, other }, self ); | |
(void)n; | |
@@ -1183,13 +1326,16 @@ | |
} | |
Tensor VariableType::ge(const Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<GeBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<GeBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GeBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_info = self; | |
} | |
auto ret = as_variable(baseType->ge(self_, other)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "ge", { self }, ret ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -1199,14 +1345,17 @@ | |
Tensor VariableType::s_ge(const Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<GeBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<GeBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GeBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_info = self; | |
grad_fn->other_info = other; | |
} | |
auto ret = as_variable(baseType->s_ge(self_, other_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "ge", { self, other }, ret ); | |
(void)n; | |
@@ -1215,16 +1364,18 @@ | |
} | |
Tensor & VariableType::m_ge_(Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<GeBackward0>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<GeBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GeBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_info = self; | |
} | |
baseType->m_ge_(self_, other); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "ge_", { self }, self ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -1234,17 +1385,19 @@ | |
Tensor & VariableType::s_m_ge_(Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<GeBackward1>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<GeBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GeBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_info = self; | |
grad_fn->other_info = other; | |
} | |
baseType->s_m_ge_(self_, other_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "ge_", { self, other }, self ); | |
(void)n; | |
@@ -1253,13 +1406,16 @@ | |
} | |
Tensor VariableType::eq(const Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<EqBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<EqBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<EqBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_info = self; | |
} | |
auto ret = as_variable(baseType->eq(self_, other)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "eq", { self }, ret ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -1269,14 +1425,17 @@ | |
Tensor VariableType::s_eq(const Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<EqBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<EqBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<EqBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_info = self; | |
grad_fn->other_info = other; | |
} | |
auto ret = as_variable(baseType->s_eq(self_, other_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "eq", { self, other }, ret ); | |
(void)n; | |
@@ -1285,16 +1444,18 @@ | |
} | |
Tensor & VariableType::m_eq_(Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<EqBackward0>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<EqBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<EqBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_info = self; | |
} | |
baseType->m_eq_(self_, other); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "eq_", { self }, self ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -1304,17 +1465,19 @@ | |
Tensor & VariableType::s_m_eq_(Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<EqBackward1>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<EqBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<EqBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_info = self; | |
grad_fn->other_info = other; | |
} | |
baseType->s_m_eq_(self_, other_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "eq_", { self, other }, self ); | |
(void)n; | |
@@ -1323,13 +1486,16 @@ | |
} | |
Tensor VariableType::ne(const Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<NeBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<NeBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<NeBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_info = self; | |
} | |
auto ret = as_variable(baseType->ne(self_, other)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "ne", { self }, ret ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -1339,14 +1505,17 @@ | |
Tensor VariableType::s_ne(const Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<NeBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<NeBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<NeBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_info = self; | |
grad_fn->other_info = other; | |
} | |
auto ret = as_variable(baseType->s_ne(self_, other_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "ne", { self, other }, ret ); | |
(void)n; | |
@@ -1355,16 +1524,18 @@ | |
} | |
Tensor & VariableType::m_ne_(Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<NeBackward0>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<NeBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<NeBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_info = self; | |
} | |
baseType->m_ne_(self_, other); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "ne_", { self }, self ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -1374,17 +1545,19 @@ | |
Tensor & VariableType::s_m_ne_(Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<NeBackward1>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<NeBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<NeBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_info = self; | |
grad_fn->other_info = other; | |
} | |
baseType->s_m_ne_(self_, other_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "ne_", { self, other }, self ); | |
(void)n; | |
@@ -1393,16 +1566,19 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::min(const Tensor & self, int64_t dim, bool keepdim) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<MinBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MinBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MinBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
grad_fn->dim = dim; | |
grad_fn->keepdim = keepdim; | |
} | |
auto ret = as_variable(baseType->min(self_, dim, keepdim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
+ if (grad_fn) { | |
auto& min_indices = std::get<1>(ret); | |
grad_fn->min_indices_ = SavedVariable(min_indices, grad_fn.get()); | |
} | |
@@ -1416,14 +1592,17 @@ | |
Tensor VariableType::s_min(const Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<MinBackward2>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MinBackward2> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MinBackward2>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
grad_fn->other_ = SavedVariable(other, nullptr); | |
} | |
auto ret = as_variable(baseType->s_min(self_, other_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "min", { self, other }, ret ); | |
(void)n; | |
@@ -1432,14 +1611,17 @@ | |
} | |
Scalar VariableType::min(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<MinBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MinBackward1> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MinBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->min(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& result = ret; | |
grad_fn->result_ = SavedVariable(result, grad_fn.get()); | |
} | |
@@ -1451,16 +1633,19 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::max(const Tensor & self, int64_t dim, bool keepdim) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<MaxBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MaxBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MaxBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
grad_fn->dim = dim; | |
grad_fn->keepdim = keepdim; | |
} | |
auto ret = as_variable(baseType->max(self_, dim, keepdim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
+ if (grad_fn) { | |
auto& max_indices = std::get<1>(ret); | |
grad_fn->max_indices_ = SavedVariable(max_indices, grad_fn.get()); | |
} | |
@@ -1474,14 +1659,17 @@ | |
Tensor VariableType::s_max(const Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<MaxBackward2>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MaxBackward2> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MaxBackward2>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
grad_fn->other_ = SavedVariable(other, nullptr); | |
} | |
auto ret = as_variable(baseType->s_max(self_, other_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "max", { self, other }, ret ); | |
(void)n; | |
@@ -1490,14 +1678,17 @@ | |
} | |
Scalar VariableType::max(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<MaxBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MaxBackward1> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MaxBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->max(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& result = ret; | |
grad_fn->result_ = SavedVariable(result, grad_fn.get()); | |
} | |
@@ -1509,16 +1700,19 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::kthvalue(const Tensor & self, int64_t k, int64_t dim, bool keepdim) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<KthvalueBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<KthvalueBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<KthvalueBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
grad_fn->dim = dim; | |
grad_fn->keepdim = keepdim; | |
} | |
auto ret = as_variable(baseType->kthvalue(self_, k, dim, keepdim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
+ if (grad_fn) { | |
auto& indices = std::get<1>(ret); | |
grad_fn->indices_ = SavedVariable(indices, grad_fn.get()); | |
} | |
@@ -1532,16 +1726,19 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::mode(const Tensor & self, int64_t dim, bool keepdim) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<ModeBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<ModeBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ModeBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
grad_fn->dim = dim; | |
grad_fn->keepdim = keepdim; | |
} | |
auto ret = as_variable(baseType->mode(self_, dim, keepdim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
+ if (grad_fn) { | |
auto& indices = std::get<1>(ret); | |
grad_fn->indices_ = SavedVariable(indices, grad_fn.get()); | |
} | |
@@ -1554,16 +1751,19 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::median(const Tensor & self, int64_t dim, bool keepdim) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<MedianBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MedianBackward1> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MedianBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
grad_fn->dim = dim; | |
grad_fn->keepdim = keepdim; | |
} | |
auto ret = as_variable(baseType->median(self_, dim, keepdim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
+ if (grad_fn) { | |
auto& indices = std::get<1>(ret); | |
grad_fn->indices_ = SavedVariable(indices, grad_fn.get()); | |
} | |
@@ -1576,14 +1776,17 @@ | |
} | |
Scalar VariableType::median(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<MedianBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MedianBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MedianBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->median(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& result = ret; | |
grad_fn->result_ = SavedVariable(result, grad_fn.get()); | |
} | |
@@ -1595,15 +1798,18 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::sort(const Tensor & self, int64_t dim, bool descending) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SortBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SortBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SortBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
grad_fn->dim = dim; | |
} | |
auto ret = as_variable(baseType->sort(self_, dim, descending)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
+ if (grad_fn) { | |
auto& indices = std::get<1>(ret); | |
grad_fn->indices_ = SavedVariable(indices, grad_fn.get()); | |
} | |
@@ -1616,15 +1822,18 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::topk(const Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TopkBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<TopkBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TopkBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
grad_fn->dim = dim; | |
} | |
auto ret = as_variable(baseType->topk(self_, k, dim, largest, sorted)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
+ if (grad_fn) { | |
auto& indices = std::get<1>(ret); | |
grad_fn->indices_ = SavedVariable(indices, grad_fn.get()); | |
} | |
@@ -1651,13 +1860,16 @@ | |
} | |
Tensor VariableType::abs(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<AbsBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AbsBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AbsBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->abs(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "abs", { self }, ret ); | |
(void)n; | |
@@ -1666,16 +1878,18 @@ | |
} | |
Tensor & VariableType::m_abs_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<AbsBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<AbsBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AbsBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_abs_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "abs_", { self }, self ); | |
(void)n; | |
@@ -1684,14 +1898,18 @@ | |
} | |
Tensor & VariableType::m_sigmoid_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SigmoidBackward0>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<SigmoidBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SigmoidBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_sigmoid_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
+ if (grad_fn) { | |
grad_fn->result_ = SavedVariable(self, grad_fn.get()); | |
} | |
if (jit::tracer::isTracing({ self })) { | |
@@ -1702,11 +1920,16 @@ | |
} | |
Tensor VariableType::sigmoid(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SigmoidBackward0>(); | |
+ std::shared_ptr<SigmoidBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SigmoidBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->sigmoid(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& result = ret; | |
grad_fn->result_ = SavedVariable(result, grad_fn.get()); | |
} | |
@@ -1718,16 +1941,18 @@ | |
} | |
Tensor & VariableType::m_log_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<LogBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<LogBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LogBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_log_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "log_", { self }, self ); | |
(void)n; | |
@@ -1736,13 +1961,16 @@ | |
} | |
Tensor VariableType::log(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<LogBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<LogBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LogBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->log(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "log", { self }, ret ); | |
(void)n; | |
@@ -1751,16 +1979,18 @@ | |
} | |
Tensor & VariableType::m_log1p_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<Log1PBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<Log1PBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Log1PBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_log1p_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "log1p_", { self }, self ); | |
(void)n; | |
@@ -1769,13 +1999,16 @@ | |
} | |
Tensor VariableType::log1p(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<Log1PBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<Log1PBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Log1PBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->log1p(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "log1p", { self }, ret ); | |
(void)n; | |
@@ -1784,10 +2017,15 @@ | |
} | |
Tensor VariableType::lgamma(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<LgammaBackward>(); | |
+ std::shared_ptr<LgammaBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LgammaBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->lgamma(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "lgamma", { self }, ret ); | |
(void)n; | |
@@ -1796,13 +2034,17 @@ | |
} | |
Tensor & VariableType::m_lgamma_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<LgammaBackward>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<LgammaBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LgammaBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_lgamma_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "lgamma_", { self }, self ); | |
(void)n; | |
@@ -1811,14 +2053,18 @@ | |
} | |
Tensor & VariableType::m_exp_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<ExpBackward>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<ExpBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ExpBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_exp_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
+ if (grad_fn) { | |
grad_fn->result_ = SavedVariable(self, grad_fn.get()); | |
} | |
if (jit::tracer::isTracing({ self })) { | |
@@ -1829,11 +2075,16 @@ | |
} | |
Tensor VariableType::exp(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<ExpBackward>(); | |
+ std::shared_ptr<ExpBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ExpBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->exp(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& result = ret; | |
grad_fn->result_ = SavedVariable(result, grad_fn.get()); | |
} | |
@@ -1845,16 +2096,18 @@ | |
} | |
Tensor & VariableType::m_cos_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<CosBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<CosBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<CosBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_cos_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "cos_", { self }, self ); | |
(void)n; | |
@@ -1863,13 +2116,16 @@ | |
} | |
Tensor VariableType::cos(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<CosBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<CosBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<CosBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->cos(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "cos", { self }, ret ); | |
(void)n; | |
@@ -1878,16 +2134,18 @@ | |
} | |
Tensor & VariableType::m_acos_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<AcosBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<AcosBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AcosBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_acos_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "acos_", { self }, self ); | |
(void)n; | |
@@ -1896,13 +2154,16 @@ | |
} | |
Tensor VariableType::acos(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<AcosBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AcosBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AcosBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->acos(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "acos", { self }, ret ); | |
(void)n; | |
@@ -1911,16 +2172,18 @@ | |
} | |
Tensor & VariableType::m_cosh_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<CoshBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<CoshBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<CoshBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_cosh_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "cosh_", { self }, self ); | |
(void)n; | |
@@ -1929,13 +2192,16 @@ | |
} | |
Tensor VariableType::cosh(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<CoshBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<CoshBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<CoshBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->cosh(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "cosh", { self }, ret ); | |
(void)n; | |
@@ -1944,16 +2210,18 @@ | |
} | |
Tensor & VariableType::m_sin_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SinBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<SinBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SinBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_sin_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "sin_", { self }, self ); | |
(void)n; | |
@@ -1962,13 +2230,16 @@ | |
} | |
Tensor VariableType::sin(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SinBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SinBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SinBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->sin(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "sin", { self }, ret ); | |
(void)n; | |
@@ -1977,16 +2248,18 @@ | |
} | |
Tensor & VariableType::m_asin_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<AsinBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<AsinBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AsinBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_asin_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "asin_", { self }, self ); | |
(void)n; | |
@@ -1995,13 +2268,16 @@ | |
} | |
Tensor VariableType::asin(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<AsinBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AsinBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AsinBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->asin(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "asin", { self }, ret ); | |
(void)n; | |
@@ -2010,16 +2286,18 @@ | |
} | |
Tensor & VariableType::m_sinh_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SinhBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<SinhBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SinhBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_sinh_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "sinh_", { self }, self ); | |
(void)n; | |
@@ -2028,13 +2306,16 @@ | |
} | |
Tensor VariableType::sinh(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SinhBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SinhBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SinhBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->sinh(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "sinh", { self }, ret ); | |
(void)n; | |
@@ -2043,16 +2324,18 @@ | |
} | |
Tensor & VariableType::m_tan_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TanBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<TanBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TanBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_tan_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "tan_", { self }, self ); | |
(void)n; | |
@@ -2061,13 +2344,16 @@ | |
} | |
Tensor VariableType::tan(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TanBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<TanBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TanBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->tan(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "tan", { self }, ret ); | |
(void)n; | |
@@ -2076,16 +2362,18 @@ | |
} | |
Tensor & VariableType::m_atan_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<AtanBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<AtanBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AtanBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_atan_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "atan_", { self }, self ); | |
(void)n; | |
@@ -2094,13 +2382,16 @@ | |
} | |
Tensor VariableType::atan(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<AtanBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AtanBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AtanBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->atan(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "atan", { self }, ret ); | |
(void)n; | |
@@ -2109,14 +2400,18 @@ | |
} | |
Tensor & VariableType::m_tanh_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TanhBackward0>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<TanhBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TanhBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_tanh_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
+ if (grad_fn) { | |
grad_fn->result_ = SavedVariable(self, grad_fn.get()); | |
} | |
if (jit::tracer::isTracing({ self })) { | |
@@ -2127,11 +2422,16 @@ | |
} | |
Tensor VariableType::tanh(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TanhBackward0>(); | |
+ std::shared_ptr<TanhBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TanhBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->tanh(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& result = ret; | |
grad_fn->result_ = SavedVariable(result, grad_fn.get()); | |
} | |
@@ -2143,16 +2443,18 @@ | |
} | |
Tensor & VariableType::m_erf_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<ErfBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<ErfBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ErfBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_erf_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "erf_", { self }, self ); | |
(void)n; | |
@@ -2161,13 +2463,16 @@ | |
} | |
Tensor VariableType::erf(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<ErfBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<ErfBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ErfBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->erf(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "erf", { self }, ret ); | |
(void)n; | |
@@ -2176,16 +2481,18 @@ | |
} | |
Tensor & VariableType::m_erfinv_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<ErfinvBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<ErfinvBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ErfinvBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_erfinv_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "erfinv_", { self }, self ); | |
(void)n; | |
@@ -2194,13 +2501,16 @@ | |
} | |
Tensor VariableType::erfinv(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<ErfinvBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<ErfinvBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ErfinvBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->erfinv(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "erfinv", { self }, ret ); | |
(void)n; | |
@@ -2209,16 +2519,18 @@ | |
} | |
Tensor & VariableType::m_sqrt_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SqrtBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<SqrtBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SqrtBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_sqrt_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "sqrt_", { self }, self ); | |
(void)n; | |
@@ -2227,13 +2539,16 @@ | |
} | |
Tensor VariableType::sqrt(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SqrtBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SqrtBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SqrtBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->sqrt(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "sqrt", { self }, ret ); | |
(void)n; | |
@@ -2242,14 +2557,18 @@ | |
} | |
Tensor & VariableType::m_rsqrt_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<RsqrtBackward>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<RsqrtBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<RsqrtBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_rsqrt_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
+ if (grad_fn) { | |
grad_fn->result_ = SavedVariable(self, grad_fn.get()); | |
} | |
if (jit::tracer::isTracing({ self })) { | |
@@ -2260,11 +2579,16 @@ | |
} | |
Tensor VariableType::rsqrt(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<RsqrtBackward>(); | |
+ std::shared_ptr<RsqrtBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<RsqrtBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->rsqrt(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& result = ret; | |
grad_fn->result_ = SavedVariable(result, grad_fn.get()); | |
} | |
@@ -2276,13 +2600,17 @@ | |
} | |
Tensor & VariableType::m_ceil_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<CeilBackward>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<CeilBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<CeilBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_ceil_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "ceil_", { self }, self ); | |
(void)n; | |
@@ -2291,10 +2619,15 @@ | |
} | |
Tensor VariableType::ceil(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<CeilBackward>(); | |
+ std::shared_ptr<CeilBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<CeilBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->ceil(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "ceil", { self }, ret ); | |
(void)n; | |
@@ -2303,13 +2636,17 @@ | |
} | |
Tensor & VariableType::m_floor_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<FloorBackward>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<FloorBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<FloorBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_floor_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "floor_", { self }, self ); | |
(void)n; | |
@@ -2318,10 +2655,15 @@ | |
} | |
Tensor VariableType::floor(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<FloorBackward>(); | |
+ std::shared_ptr<FloorBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<FloorBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->floor(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "floor", { self }, ret ); | |
(void)n; | |
@@ -2330,13 +2672,17 @@ | |
} | |
Tensor & VariableType::m_round_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<RoundBackward>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<RoundBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<RoundBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_round_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "round_", { self }, self ); | |
(void)n; | |
@@ -2345,10 +2691,15 @@ | |
} | |
Tensor VariableType::round(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<RoundBackward>(); | |
+ std::shared_ptr<RoundBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<RoundBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->round(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "round", { self }, ret ); | |
(void)n; | |
@@ -2357,13 +2708,17 @@ | |
} | |
Tensor & VariableType::m_trunc_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TruncBackward>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<TruncBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TruncBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_trunc_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "trunc_", { self }, self ); | |
(void)n; | |
@@ -2372,10 +2727,15 @@ | |
} | |
Tensor VariableType::trunc(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TruncBackward>(); | |
+ std::shared_ptr<TruncBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TruncBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->trunc(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "trunc", { self }, ret ); | |
(void)n; | |
@@ -2384,13 +2744,17 @@ | |
} | |
Tensor & VariableType::m_frac_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<FracBackward>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<FracBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<FracBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_frac_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "frac_", { self }, self ); | |
(void)n; | |
@@ -2399,10 +2763,15 @@ | |
} | |
Tensor VariableType::frac(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<FracBackward>(); | |
+ std::shared_ptr<FracBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<FracBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->frac(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "frac", { self }, ret ); | |
(void)n; | |
@@ -2411,16 +2780,19 @@ | |
} | |
Tensor VariableType::mean(const Tensor & self, int64_t dim, bool keepdim) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<MeanBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MeanBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MeanBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
grad_fn->self_argsize_dim = self.size(dim); | |
grad_fn->dim = dim; | |
grad_fn->keepdim = keepdim; | |
} | |
auto ret = as_variable(baseType->mean(self_, dim, keepdim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "mean", { self }, ret ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
@@ -2430,14 +2802,17 @@ | |
} | |
Scalar VariableType::mean(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<MeanBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MeanBackward1> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MeanBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->mean(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "mean", { self }, ret ); | |
(void)n; | |
@@ -2458,16 +2833,19 @@ | |
} | |
Tensor VariableType::norm(const Tensor & self, Scalar p, int64_t dim, bool keepdim) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<NormBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<NormBackward1> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<NormBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
grad_fn->p = p; | |
grad_fn->dim = dim; | |
grad_fn->keepdim = keepdim; | |
} | |
auto ret = as_variable(baseType->norm(self_, p, dim, keepdim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "norm", { self }, ret ); | |
setattr(n, jit::stringToSymbol("p"), p); | |
@@ -2478,14 +2856,17 @@ | |
} | |
Scalar VariableType::norm(const Tensor & self, Scalar p) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<NormBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<NormBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<NormBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
grad_fn->p = p; | |
} | |
auto ret = as_variable(baseType->norm(self_, p)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "norm", { self }, ret ); | |
setattr(n, jit::stringToSymbol("p"), p); | |
@@ -2501,15 +2882,18 @@ | |
Scalar VariableType::s_dist(const Tensor & self, const Tensor & other, Scalar p) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<DistBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<DistBackward> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<DistBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
grad_fn->other_ = SavedVariable(other, nullptr); | |
grad_fn->p = p; | |
} | |
auto ret = as_variable(baseType->s_dist(self_, other_, p)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "dist", { self, other }, ret ); | |
setattr(n, jit::stringToSymbol("p"), p); | |
@@ -2518,13 +2902,16 @@ | |
} | |
Tensor VariableType::reciprocal(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<ReciprocalBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<ReciprocalBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ReciprocalBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->reciprocal(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "reciprocal", { self }, ret ); | |
(void)n; | |
@@ -2533,16 +2920,18 @@ | |
} | |
Tensor & VariableType::m_reciprocal_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<ReciprocalBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<ReciprocalBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ReciprocalBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->m_reciprocal_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "reciprocal_", { self }, self ); | |
(void)n; | |
@@ -2551,10 +2940,15 @@ | |
} | |
Tensor VariableType::neg(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<NegBackward>(); | |
+ std::shared_ptr<NegBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<NegBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->neg(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "neg", { self }, ret ); | |
(void)n; | |
@@ -2563,13 +2957,17 @@ | |
} | |
Tensor & VariableType::m_neg_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<NegBackward>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<NegBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<NegBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_neg_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "neg_", { self }, self ); | |
(void)n; | |
@@ -2579,14 +2977,17 @@ | |
Tensor VariableType::s_atan2(const Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<Atan2Backward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<Atan2Backward> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Atan2Backward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
grad_fn->other_ = SavedVariable(other, nullptr); | |
} | |
auto ret = as_variable(baseType->s_atan2(self_, other_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "atan2", { self, other }, ret ); | |
(void)n; | |
@@ -2596,17 +2997,19 @@ | |
Tensor & VariableType::s_m_atan2_(Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<Atan2Backward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<Atan2Backward> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Atan2Backward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
grad_fn->other_ = SavedVariable(other, nullptr); | |
} | |
baseType->s_m_atan2_(self_, other_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "atan2_", { self, other }, self ); | |
(void)n; | |
@@ -2615,14 +3018,17 @@ | |
} | |
Tensor VariableType::pow(const Tensor & self, Scalar exponent) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<PowBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<PowBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<PowBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
grad_fn->exponent = exponent; | |
} | |
auto ret = as_variable(baseType->pow(self_, exponent)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "pow", { self }, ret ); | |
setattr(n, jit::stringToSymbol("exponent"), exponent); | |
@@ -2632,14 +3038,17 @@ | |
Tensor VariableType::s_pow(const Tensor & self, const Tensor & exponent) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& exponent_ = unpack(exponent, "exponent", 1); | |
- auto flags = Function::flags({ self, exponent }); | |
- auto grad_fn = std::make_shared<PowBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<PowBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, exponent }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<PowBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, exponent }); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
grad_fn->exponent_ = SavedVariable(exponent, nullptr); | |
} | |
auto ret = as_variable(baseType->s_pow(self_, exponent_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, exponent })) { | |
jit::Node *n = jit::tracer::recordTrace( "pow", { self, exponent }, ret ); | |
(void)n; | |
@@ -2648,17 +3057,19 @@ | |
} | |
Tensor & VariableType::m_pow_(Tensor & self, Scalar exponent) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<PowBackward0>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<PowBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<PowBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
grad_fn->exponent = exponent; | |
} | |
baseType->m_pow_(self_, exponent); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "pow_", { self }, self ); | |
setattr(n, jit::stringToSymbol("exponent"), exponent); | |
@@ -2668,17 +3079,19 @@ | |
Tensor & VariableType::s_m_pow_(Tensor & self, const Tensor & exponent) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& exponent_ = unpack(exponent, "exponent", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, exponent }); | |
- auto grad_fn = std::make_shared<PowBackward1>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<PowBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, exponent }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<PowBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, exponent }); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
grad_fn->exponent_ = SavedVariable(exponent, nullptr); | |
} | |
baseType->s_m_pow_(self_, exponent_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, exponent })) { | |
jit::Node *n = jit::tracer::recordTrace( "pow_", { self, exponent }, self ); | |
(void)n; | |
@@ -2688,13 +3101,16 @@ | |
Tensor VariableType::s_lerp(const Tensor & self, const Tensor & end, Scalar weight) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& end_ = unpack(end, "end", 1); | |
- auto flags = Function::flags({ self, end }); | |
- auto grad_fn = std::make_shared<LerpBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<LerpBackward> grad_fn; | |
+ auto flags = compute_flags({ self, end }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LerpBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, end }); | |
grad_fn->weight = weight; | |
} | |
auto ret = as_variable(baseType->s_lerp(self_, end_, weight)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, end })) { | |
jit::Node *n = jit::tracer::recordTrace( "lerp", { self, end }, ret ); | |
setattr(n, jit::stringToSymbol("weight"), weight); | |
@@ -2704,16 +3120,18 @@ | |
Tensor & VariableType::s_m_lerp_(Tensor & self, const Tensor & end, Scalar weight) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& end_ = unpack(end, "end", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, end }); | |
- auto grad_fn = std::make_shared<LerpBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<LerpBackward> grad_fn; | |
+ auto flags = compute_flags({ self, end }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LerpBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, end }); | |
grad_fn->weight = weight; | |
} | |
baseType->s_m_lerp_(self_, end_, weight); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, end })) { | |
jit::Node *n = jit::tracer::recordTrace( "lerp_", { self, end }, self ); | |
setattr(n, jit::stringToSymbol("weight"), weight); | |
@@ -2728,10 +3146,15 @@ | |
} | |
Tensor VariableType::histc(const Tensor & self, int64_t bins, Scalar min, Scalar max) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<HistcBackward>(); | |
+ std::shared_ptr<HistcBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<HistcBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->histc(self_, bins, min, max)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "histc", { self }, ret ); | |
setattr(n, jit::stringToSymbol("bins"), bins); | |
@@ -2742,13 +3165,17 @@ | |
} | |
Tensor & VariableType::m_zero_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<ZeroBackward>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<ZeroBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ZeroBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_zero_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "zero_", { self }, self ); | |
(void)n; | |
@@ -2757,15 +3184,18 @@ | |
} | |
Tensor VariableType::sum(const Tensor & self, int64_t dim, bool keepdim) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SumBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SumBackward1> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SumBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
grad_fn->dim = dim; | |
grad_fn->keepdim = keepdim; | |
} | |
auto ret = as_variable(baseType->sum(self_, dim, keepdim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "sum", { self }, ret ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
@@ -2775,13 +3205,16 @@ | |
} | |
Scalar VariableType::sum(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SumBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SumBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SumBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
} | |
auto ret = as_variable(baseType->sum(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "sum", { self }, ret ); | |
(void)n; | |
@@ -2796,13 +3229,16 @@ | |
} | |
Tensor VariableType::cumsum(const Tensor & self, int64_t dim) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<CumsumBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<CumsumBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<CumsumBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->dim = dim; | |
} | |
auto ret = as_variable(baseType->cumsum(self_, dim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "cumsum", { self }, ret ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
@@ -2814,10 +3250,15 @@ | |
} | |
Tensor VariableType::sign(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SignBackward>(); | |
+ std::shared_ptr<SignBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SignBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->sign(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "sign", { self }, ret ); | |
(void)n; | |
@@ -2826,13 +3267,17 @@ | |
} | |
Tensor & VariableType::m_sign_(Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SignBackward>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<SignBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SignBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_sign_(self_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "sign_", { self }, self ); | |
(void)n; | |
@@ -2841,13 +3286,16 @@ | |
} | |
Scalar VariableType::trace(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TraceBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<TraceBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TraceBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_sizes = self.sizes(); | |
} | |
auto ret = as_variable(baseType->trace(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "trace", { self }, ret ); | |
(void)n; | |
@@ -2856,10 +3304,15 @@ | |
} | |
Tensor VariableType::add(const Tensor & self, Scalar other, Scalar alpha) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<AddBackward0>(); | |
+ std::shared_ptr<AddBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->add(self_, other, alpha)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "add", { self }, ret ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -2870,13 +3323,16 @@ | |
Tensor VariableType::s_add(const Tensor & self, const Tensor & other, Scalar alpha) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<AddBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AddBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->alpha = alpha; | |
} | |
auto ret = as_variable(baseType->s_add(self_, other_, alpha)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "add", { self, other }, ret ); | |
setattr(n, jit::stringToSymbol("alpha"), alpha); | |
@@ -2888,13 +3344,17 @@ | |
} | |
Tensor & VariableType::m_add_(Tensor & self, Scalar other, Scalar alpha) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<AddBackward0>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<AddBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_add_(self_, other, alpha); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "add_", { self }, self ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -2905,16 +3365,18 @@ | |
Tensor & VariableType::s_m_add_(Tensor & self, const Tensor & other, Scalar alpha) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<AddBackward1>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<AddBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->alpha = alpha; | |
} | |
baseType->s_m_add_(self_, other_, alpha); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "add_", { self, other }, self ); | |
setattr(n, jit::stringToSymbol("alpha"), alpha); | |
@@ -2926,10 +3388,15 @@ | |
} | |
Tensor VariableType::sub(const Tensor & self, Scalar other, Scalar alpha) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SubBackward0>(); | |
+ std::shared_ptr<SubBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SubBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->sub(self_, other, alpha)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "sub", { self }, ret ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -2940,13 +3407,16 @@ | |
Tensor VariableType::s_sub(const Tensor & self, const Tensor & other, Scalar alpha) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<SubBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SubBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SubBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->alpha = alpha; | |
} | |
auto ret = as_variable(baseType->s_sub(self_, other_, alpha)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "sub", { self, other }, ret ); | |
setattr(n, jit::stringToSymbol("alpha"), alpha); | |
@@ -2955,13 +3425,17 @@ | |
} | |
Tensor & VariableType::m_sub_(Tensor & self, Scalar other, Scalar alpha) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SubBackward0>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<SubBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SubBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_sub_(self_, other, alpha); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "sub_", { self }, self ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -2972,16 +3446,18 @@ | |
Tensor & VariableType::s_m_sub_(Tensor & self, const Tensor & other, Scalar alpha) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<SubBackward1>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<SubBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SubBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->alpha = alpha; | |
} | |
baseType->s_m_sub_(self_, other_, alpha); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "sub_", { self, other }, self ); | |
setattr(n, jit::stringToSymbol("alpha"), alpha); | |
@@ -2990,13 +3466,16 @@ | |
} | |
Tensor VariableType::mul(const Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<MulBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MulBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MulBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->other = other; | |
} | |
auto ret = as_variable(baseType->mul(self_, other)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "mul", { self }, ret ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -3006,14 +3485,17 @@ | |
Tensor VariableType::s_mul(const Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<MulBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MulBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MulBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->other_ = SavedVariable(other, nullptr); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->s_mul(self_, other_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "mul", { self, other }, ret ); | |
(void)n; | |
@@ -3022,16 +3504,18 @@ | |
} | |
Tensor & VariableType::m_mul_(Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<MulBackward0>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<MulBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MulBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->other = other; | |
} | |
baseType->m_mul_(self_, other); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "mul_", { self }, self ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -3041,17 +3525,19 @@ | |
Tensor & VariableType::s_m_mul_(Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<MulBackward1>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<MulBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MulBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->other_ = SavedVariable(other, nullptr); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->s_m_mul_(self_, other_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "mul_", { self, other }, self ); | |
(void)n; | |
@@ -3060,13 +3546,16 @@ | |
} | |
Tensor VariableType::div(const Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<DivBackward0>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<DivBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<DivBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->other = other; | |
} | |
auto ret = as_variable(baseType->div(self_, other)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "div", { self }, ret ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -3076,14 +3565,17 @@ | |
Tensor VariableType::s_div(const Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<DivBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<DivBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<DivBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->other_ = SavedVariable(other, nullptr); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->s_div(self_, other_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "div", { self, other }, ret ); | |
(void)n; | |
@@ -3092,16 +3584,18 @@ | |
} | |
Tensor & VariableType::m_div_(Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<DivBackward0>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<DivBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<DivBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->other = other; | |
} | |
baseType->m_div_(self_, other); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "div_", { self }, self ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -3111,17 +3605,19 @@ | |
Tensor & VariableType::s_m_div_(Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<DivBackward1>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<DivBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<DivBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->other_ = SavedVariable(other, nullptr); | |
grad_fn->self_ = SavedVariable(self.clone(), nullptr); | |
} | |
baseType->s_m_div_(self_, other_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "div_", { self, other }, self ); | |
(void)n; | |
@@ -3130,10 +3626,15 @@ | |
} | |
Tensor VariableType::fmod(const Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<FmodBackward0>(); | |
+ std::shared_ptr<FmodBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<FmodBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->fmod(self_, other)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "fmod", { self }, ret ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -3143,13 +3644,16 @@ | |
Tensor VariableType::s_fmod(const Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<FmodBackward1>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<FmodBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<FmodBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->other_ = SavedVariable(other, nullptr); | |
} | |
auto ret = as_variable(baseType->s_fmod(self_, other_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "fmod", { self, other }, ret ); | |
(void)n; | |
@@ -3158,13 +3662,17 @@ | |
} | |
Tensor & VariableType::m_fmod_(Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<FmodBackward0>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<FmodBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<FmodBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_fmod_(self_, other); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "fmod_", { self }, self ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -3174,16 +3682,18 @@ | |
Tensor & VariableType::s_m_fmod_(Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<FmodBackward1>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<FmodBackward1> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<FmodBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->other_ = SavedVariable(other, nullptr); | |
} | |
baseType->s_m_fmod_(self_, other_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "fmod_", { self, other }, self ); | |
(void)n; | |
@@ -3192,10 +3702,15 @@ | |
} | |
Tensor VariableType::remainder(const Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<RemainderBackward0>(); | |
+ std::shared_ptr<RemainderBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<RemainderBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->remainder(self_, other)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "remainder", { self }, ret ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -3205,25 +3720,35 @@ | |
Tensor VariableType::s_remainder(const Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<RemainderBackward1>(); | |
+ check_no_requires_grad(other, "other"); | |
+ std::shared_ptr<RemainderBackward1> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<RemainderBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->s_remainder(self_, other_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self })) { | |
- jit::Node *n = jit::tracer::recordTrace( "remainder", { self }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, other })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "remainder", { self, other }, ret ); | |
(void)n; | |
} | |
return Tensor(std::move(ret)); | |
} | |
Tensor & VariableType::m_remainder_(Tensor & self, Scalar other) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<RemainderBackward0>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<RemainderBackward0> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<RemainderBackward0>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_remainder_(self_, other); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "remainder_", { self }, self ); | |
setattr(n, jit::stringToSymbol("other"), other); | |
@@ -3233,15 +3758,20 @@ | |
Tensor & VariableType::s_m_remainder_(Tensor & self, const Tensor & other) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<RemainderBackward1>(); | |
+ check_inplace(self); | |
+ check_no_requires_grad(other, "other"); | |
+ std::shared_ptr<RemainderBackward1> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<RemainderBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->s_m_remainder_(self_, other_); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self })) { | |
- jit::Node *n = jit::tracer::recordTrace( "remainder_", { self }, self ); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, other })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "remainder_", { self, other }, self ); | |
(void)n; | |
} | |
return self; | |
@@ -3261,14 +3791,17 @@ | |
Scalar VariableType::dot(const Tensor & self, const Tensor & tensor) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& tensor_ = unpack(tensor, "tensor", 1); | |
- auto flags = Function::flags({ self, tensor }); | |
- auto grad_fn = std::make_shared<DotBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<DotBackward> grad_fn; | |
+ auto flags = compute_flags({ self, tensor }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<DotBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, tensor }); | |
grad_fn->tensor_ = SavedVariable(tensor, nullptr); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->dot(self_, tensor_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, tensor })) { | |
jit::Node *n = jit::tracer::recordTrace( "dot", { self, tensor }, ret ); | |
(void)n; | |
@@ -3277,13 +3810,16 @@ | |
} | |
Tensor VariableType::tril(const Tensor & self, int64_t diagonal) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TrilBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<TrilBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TrilBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->diagonal = diagonal; | |
} | |
auto ret = as_variable(baseType->tril(self_, diagonal)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "tril", { self }, ret ); | |
setattr(n, jit::stringToSymbol("diagonal"), diagonal); | |
@@ -3292,16 +3828,18 @@ | |
} | |
Tensor & VariableType::m_tril_(Tensor & self, int64_t diagonal) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TrilBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<TrilBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TrilBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->diagonal = diagonal; | |
} | |
baseType->m_tril_(self_, diagonal); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "tril_", { self }, self ); | |
setattr(n, jit::stringToSymbol("diagonal"), diagonal); | |
@@ -3310,13 +3848,16 @@ | |
} | |
Tensor VariableType::triu(const Tensor & self, int64_t diagonal) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TriuBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<TriuBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TriuBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->diagonal = diagonal; | |
} | |
auto ret = as_variable(baseType->triu(self_, diagonal)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "triu", { self }, ret ); | |
setattr(n, jit::stringToSymbol("diagonal"), diagonal); | |
@@ -3325,16 +3866,18 @@ | |
} | |
Tensor & VariableType::m_triu_(Tensor & self, int64_t diagonal) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TriuBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<TriuBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TriuBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->diagonal = diagonal; | |
} | |
baseType->m_triu_(self_, diagonal); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "triu_", { self }, self ); | |
setattr(n, jit::stringToSymbol("diagonal"), diagonal); | |
@@ -3344,15 +3887,18 @@ | |
Tensor VariableType::cross(const Tensor & self, const Tensor & other, int64_t dim) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& other_ = unpack(other, "other", 1); | |
- auto flags = Function::flags({ self, other }); | |
- auto grad_fn = std::make_shared<CrossBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<CrossBackward> grad_fn; | |
+ auto flags = compute_flags({ self, other }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<CrossBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, other }); | |
grad_fn->other_ = SavedVariable(other, nullptr); | |
grad_fn->dim = dim; | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->cross(self_, other_, dim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, other })) { | |
jit::Node *n = jit::tracer::recordTrace( "cross", { self, other }, ret ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
@@ -3364,13 +3910,16 @@ | |
} | |
Tensor VariableType::diag(const Tensor & self, int64_t diagonal) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<DiagBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<DiagBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<DiagBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->diagonal = diagonal; | |
} | |
auto ret = as_variable(baseType->diag(self_, diagonal)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "diag", { self }, ret ); | |
setattr(n, jit::stringToSymbol("diagonal"), diagonal); | |
@@ -3381,16 +3930,19 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& mat1_ = unpack(mat1, "mat1", 1); | |
auto& mat2_ = unpack(mat2, "mat2", 2); | |
- auto flags = Function::flags({ self, mat1, mat2 }); | |
- auto grad_fn = std::make_shared<AddmmBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AddmmBackward> grad_fn; | |
+ auto flags = compute_flags({ self, mat1, mat2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddmmBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, mat1, mat2 }); | |
grad_fn->beta = beta; | |
grad_fn->mat1_ = SavedVariable(mat1, nullptr); | |
grad_fn->mat2_ = SavedVariable(mat2, nullptr); | |
grad_fn->alpha = alpha; | |
} | |
auto ret = as_variable(baseType->s_addmm(self_, mat1_, mat2_, beta, alpha)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, mat1, mat2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "addmm", { self, mat1, mat2 }, ret ); | |
setattr(n, jit::stringToSymbol("beta"), beta); | |
@@ -3402,19 +3954,21 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& mat1_ = unpack(mat1, "mat1", 1); | |
auto& mat2_ = unpack(mat2, "mat2", 2); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, mat1, mat2 }); | |
- auto grad_fn = std::make_shared<AddmmBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<AddmmBackward> grad_fn; | |
+ auto flags = compute_flags({ self, mat1, mat2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddmmBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, mat1, mat2 }); | |
grad_fn->beta = beta; | |
grad_fn->mat1_ = SavedVariable(mat1, nullptr); | |
grad_fn->mat2_ = SavedVariable(mat2, nullptr); | |
grad_fn->alpha = alpha; | |
} | |
baseType->m_addmm_(self_, mat1_, mat2_, beta, alpha); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, mat1, mat2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "addmm_", { self, mat1, mat2 }, self ); | |
setattr(n, jit::stringToSymbol("beta"), beta); | |
@@ -3426,16 +3980,19 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& mat_ = unpack(mat, "mat", 1); | |
auto& vec_ = unpack(vec, "vec", 2); | |
- auto flags = Function::flags({ self, mat, vec }); | |
- auto grad_fn = std::make_shared<AddmvBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AddmvBackward> grad_fn; | |
+ auto flags = compute_flags({ self, mat, vec }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddmvBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, mat, vec }); | |
grad_fn->beta = beta; | |
grad_fn->vec_ = SavedVariable(vec, nullptr); | |
grad_fn->alpha = alpha; | |
grad_fn->mat_ = SavedVariable(mat, nullptr); | |
} | |
auto ret = as_variable(baseType->s_addmv(self_, mat_, vec_, beta, alpha)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, mat, vec })) { | |
jit::Node *n = jit::tracer::recordTrace( "addmv", { self, mat, vec }, ret ); | |
setattr(n, jit::stringToSymbol("beta"), beta); | |
@@ -3447,19 +4004,21 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& mat_ = unpack(mat, "mat", 1); | |
auto& vec_ = unpack(vec, "vec", 2); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, mat, vec }); | |
- auto grad_fn = std::make_shared<AddmvBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<AddmvBackward> grad_fn; | |
+ auto flags = compute_flags({ self, mat, vec }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddmvBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, mat, vec }); | |
grad_fn->beta = beta; | |
grad_fn->vec_ = SavedVariable(vec, nullptr); | |
grad_fn->alpha = alpha; | |
grad_fn->mat_ = SavedVariable(mat, nullptr); | |
} | |
baseType->m_addmv_(self_, mat_, vec_, beta, alpha); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, mat, vec })) { | |
jit::Node *n = jit::tracer::recordTrace( "addmv_", { self, mat, vec }, self ); | |
setattr(n, jit::stringToSymbol("beta"), beta); | |
@@ -3471,16 +4030,19 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& vec1_ = unpack(vec1, "vec1", 1); | |
auto& vec2_ = unpack(vec2, "vec2", 2); | |
- auto flags = Function::flags({ self, vec1, vec2 }); | |
- auto grad_fn = std::make_shared<AddrBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AddrBackward> grad_fn; | |
+ auto flags = compute_flags({ self, vec1, vec2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddrBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, vec1, vec2 }); | |
grad_fn->beta = beta; | |
grad_fn->vec2_ = SavedVariable(vec2, nullptr); | |
grad_fn->alpha = alpha; | |
grad_fn->vec1_ = SavedVariable(vec1, nullptr); | |
} | |
auto ret = as_variable(baseType->s_addr(self_, vec1_, vec2_, beta, alpha)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, vec1, vec2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "addr", { self, vec1, vec2 }, ret ); | |
setattr(n, jit::stringToSymbol("beta"), beta); | |
@@ -3492,19 +4054,21 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& vec1_ = unpack(vec1, "vec1", 1); | |
auto& vec2_ = unpack(vec2, "vec2", 2); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, vec1, vec2 }); | |
- auto grad_fn = std::make_shared<AddrBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<AddrBackward> grad_fn; | |
+ auto flags = compute_flags({ self, vec1, vec2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddrBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, vec1, vec2 }); | |
grad_fn->beta = beta; | |
grad_fn->vec2_ = SavedVariable(vec2, nullptr); | |
grad_fn->alpha = alpha; | |
grad_fn->vec1_ = SavedVariable(vec1, nullptr); | |
} | |
baseType->m_addr_(self_, vec1_, vec2_, beta, alpha); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, vec1, vec2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "addr_", { self, vec1, vec2 }, self ); | |
setattr(n, jit::stringToSymbol("beta"), beta); | |
@@ -3515,14 +4079,17 @@ | |
Tensor VariableType::ger(const Tensor & self, const Tensor & vec2) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& vec2_ = unpack(vec2, "vec2", 1); | |
- auto flags = Function::flags({ self, vec2 }); | |
- auto grad_fn = std::make_shared<GerBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<GerBackward> grad_fn; | |
+ auto flags = compute_flags({ self, vec2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GerBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, vec2 }); | |
grad_fn->vec2_ = SavedVariable(vec2, nullptr); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->ger(self_, vec2_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, vec2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "ger", { self, vec2 }, ret ); | |
(void)n; | |
@@ -3532,14 +4099,17 @@ | |
Tensor VariableType::mv(const Tensor & self, const Tensor & vec) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& vec_ = unpack(vec, "vec", 1); | |
- auto flags = Function::flags({ self, vec }); | |
- auto grad_fn = std::make_shared<MvBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MvBackward> grad_fn; | |
+ auto flags = compute_flags({ self, vec }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MvBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, vec }); | |
grad_fn->vec_ = SavedVariable(vec, nullptr); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->mv(self_, vec_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, vec })) { | |
jit::Node *n = jit::tracer::recordTrace( "mv", { self, vec }, ret ); | |
(void)n; | |
@@ -3549,14 +4119,17 @@ | |
Tensor VariableType::mm(const Tensor & self, const Tensor & mat2) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& mat2_ = unpack(mat2, "mat2", 1); | |
- auto flags = Function::flags({ self, mat2 }); | |
- auto grad_fn = std::make_shared<MmBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MmBackward> grad_fn; | |
+ auto flags = compute_flags({ self, mat2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MmBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, mat2 }); | |
grad_fn->mat2_ = SavedVariable(mat2, nullptr); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->mm(self_, mat2_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, mat2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "mm", { self, mat2 }, ret ); | |
(void)n; | |
@@ -3566,14 +4139,17 @@ | |
Tensor VariableType::bmm(const Tensor & self, const Tensor & mat2) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& mat2_ = unpack(mat2, "mat2", 1); | |
- auto flags = Function::flags({ self, mat2 }); | |
- auto grad_fn = std::make_shared<BmmBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<BmmBackward> grad_fn; | |
+ auto flags = compute_flags({ self, mat2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<BmmBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, mat2 }); | |
grad_fn->mat2_ = SavedVariable(mat2, nullptr); | |
grad_fn->self_ = SavedVariable(self, nullptr); | |
} | |
auto ret = as_variable(baseType->bmm(self_, mat2_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, mat2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "bmm", { self, mat2 }, ret ); | |
(void)n; | |
@@ -3584,9 +4160,12 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& batch1_ = unpack(batch1, "batch1", 1); | |
auto& batch2_ = unpack(batch2, "batch2", 2); | |
- auto flags = Function::flags({ self, batch1, batch2 }); | |
- auto grad_fn = std::make_shared<AddbmmBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AddbmmBackward> grad_fn; | |
+ auto flags = compute_flags({ self, batch1, batch2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddbmmBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, batch1, batch2 }); | |
grad_fn->beta = beta; | |
grad_fn->batch1_argsize_0 = batch1.size(0); | |
grad_fn->batch1_argsize_1 = batch1.size(1); | |
@@ -3596,7 +4175,7 @@ | |
grad_fn->batch1_ = SavedVariable(batch1, nullptr); | |
} | |
auto ret = as_variable(baseType->s_addbmm(self_, batch1_, batch2_, beta, alpha)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, batch1, batch2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "addbmm", { self, batch1, batch2 }, ret ); | |
setattr(n, jit::stringToSymbol("beta"), beta); | |
@@ -3608,11 +4187,13 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& batch1_ = unpack(batch1, "batch1", 1); | |
auto& batch2_ = unpack(batch2, "batch2", 2); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, batch1, batch2 }); | |
- auto grad_fn = std::make_shared<AddbmmBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<AddbmmBackward> grad_fn; | |
+ auto flags = compute_flags({ self, batch1, batch2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddbmmBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, batch1, batch2 }); | |
grad_fn->beta = beta; | |
grad_fn->batch1_argsize_0 = batch1.size(0); | |
grad_fn->batch1_argsize_1 = batch1.size(1); | |
@@ -3622,8 +4203,8 @@ | |
grad_fn->batch1_ = SavedVariable(batch1, nullptr); | |
} | |
baseType->m_addbmm_(self_, batch1_, batch2_, beta, alpha); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, batch1, batch2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "addbmm_", { self, batch1, batch2 }, self ); | |
setattr(n, jit::stringToSymbol("beta"), beta); | |
@@ -3635,16 +4216,19 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& batch1_ = unpack(batch1, "batch1", 1); | |
auto& batch2_ = unpack(batch2, "batch2", 2); | |
- auto flags = Function::flags({ self, batch1, batch2 }); | |
- auto grad_fn = std::make_shared<BaddbmmBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<BaddbmmBackward> grad_fn; | |
+ auto flags = compute_flags({ self, batch1, batch2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<BaddbmmBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, batch1, batch2 }); | |
grad_fn->beta = beta; | |
grad_fn->batch2_ = SavedVariable(batch2, nullptr); | |
grad_fn->alpha = alpha; | |
grad_fn->batch1_ = SavedVariable(batch1, nullptr); | |
} | |
auto ret = as_variable(baseType->s_baddbmm(self_, batch1_, batch2_, beta, alpha)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, batch1, batch2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "baddbmm", { self, batch1, batch2 }, ret ); | |
setattr(n, jit::stringToSymbol("beta"), beta); | |
@@ -3656,19 +4240,21 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& batch1_ = unpack(batch1, "batch1", 1); | |
auto& batch2_ = unpack(batch2, "batch2", 2); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, batch1, batch2 }); | |
- auto grad_fn = std::make_shared<BaddbmmBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<BaddbmmBackward> grad_fn; | |
+ auto flags = compute_flags({ self, batch1, batch2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<BaddbmmBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, batch1, batch2 }); | |
grad_fn->beta = beta; | |
grad_fn->batch2_ = SavedVariable(batch2, nullptr); | |
grad_fn->alpha = alpha; | |
grad_fn->batch1_ = SavedVariable(batch1, nullptr); | |
} | |
baseType->m_baddbmm_(self_, batch1_, batch2_, beta, alpha); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, batch1, batch2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "baddbmm_", { self, batch1, batch2 }, self ); | |
setattr(n, jit::stringToSymbol("beta"), beta); | |
@@ -3680,15 +4266,18 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& tensor1_ = unpack(tensor1, "tensor1", 1); | |
auto& tensor2_ = unpack(tensor2, "tensor2", 2); | |
- auto flags = Function::flags({ self, tensor1, tensor2 }); | |
- auto grad_fn = std::make_shared<AddcmulBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AddcmulBackward> grad_fn; | |
+ auto flags = compute_flags({ self, tensor1, tensor2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddcmulBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, tensor1, tensor2 }); | |
grad_fn->tensor2_ = SavedVariable(tensor2, nullptr); | |
grad_fn->value = value; | |
grad_fn->tensor1_ = SavedVariable(tensor1, nullptr); | |
} | |
auto ret = as_variable(baseType->s_addcmul(self_, tensor1_, tensor2_, value)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, tensor1, tensor2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "addcmul", { self, tensor1, tensor2 }, ret ); | |
setattr(n, jit::stringToSymbol("value"), value); | |
@@ -3699,18 +4288,20 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& tensor1_ = unpack(tensor1, "tensor1", 1); | |
auto& tensor2_ = unpack(tensor2, "tensor2", 2); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, tensor1, tensor2 }); | |
- auto grad_fn = std::make_shared<AddcmulBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<AddcmulBackward> grad_fn; | |
+ auto flags = compute_flags({ self, tensor1, tensor2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddcmulBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, tensor1, tensor2 }); | |
grad_fn->tensor2_ = SavedVariable(tensor2, nullptr); | |
grad_fn->value = value; | |
grad_fn->tensor1_ = SavedVariable(tensor1, nullptr); | |
} | |
baseType->s_m_addcmul_(self_, tensor1_, tensor2_, value); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, tensor1, tensor2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "addcmul_", { self, tensor1, tensor2 }, self ); | |
setattr(n, jit::stringToSymbol("value"), value); | |
@@ -3721,15 +4312,18 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& tensor1_ = unpack(tensor1, "tensor1", 1); | |
auto& tensor2_ = unpack(tensor2, "tensor2", 2); | |
- auto flags = Function::flags({ self, tensor1, tensor2 }); | |
- auto grad_fn = std::make_shared<AddcdivBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AddcdivBackward> grad_fn; | |
+ auto flags = compute_flags({ self, tensor1, tensor2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddcdivBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, tensor1, tensor2 }); | |
grad_fn->tensor2_ = SavedVariable(tensor2, nullptr); | |
grad_fn->value = value; | |
grad_fn->tensor1_ = SavedVariable(tensor1, nullptr); | |
} | |
auto ret = as_variable(baseType->s_addcdiv(self_, tensor1_, tensor2_, value)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, tensor1, tensor2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "addcdiv", { self, tensor1, tensor2 }, ret ); | |
setattr(n, jit::stringToSymbol("value"), value); | |
@@ -3740,18 +4334,20 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& tensor1_ = unpack(tensor1, "tensor1", 1); | |
auto& tensor2_ = unpack(tensor2, "tensor2", 2); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self, tensor1, tensor2 }); | |
- auto grad_fn = std::make_shared<AddcdivBackward>(); | |
- if (flags.is_executable) { | |
+ check_inplace(self); | |
+ std::shared_ptr<AddcdivBackward> grad_fn; | |
+ auto flags = compute_flags({ self, tensor1, tensor2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AddcdivBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, tensor1, tensor2 }); | |
grad_fn->tensor2_ = SavedVariable(tensor2, nullptr); | |
grad_fn->value = value; | |
grad_fn->tensor1_ = SavedVariable(tensor1, nullptr); | |
} | |
baseType->s_m_addcdiv_(self_, tensor1_, tensor2_, value); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, tensor1, tensor2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "addcdiv_", { self, tensor1, tensor2 }, self ); | |
setattr(n, jit::stringToSymbol("value"), value); | |
@@ -3761,14 +4357,17 @@ | |
std::tuple<Tensor,Tensor> VariableType::gesv(const Tensor & self, const Tensor & A) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& A_ = unpack(A, "A", 1); | |
- auto flags = Function::flags({ self, A }); | |
- auto grad_fn = std::make_shared<GesvBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<GesvBackward> grad_fn; | |
+ auto flags = compute_flags({ self, A }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GesvBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, A }); | |
grad_fn->A_ = SavedVariable(A, nullptr); | |
} | |
auto ret = as_variable(baseType->gesv(self_, A_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
+ if (grad_fn) { | |
auto& solution = std::get<0>(ret); | |
grad_fn->solution_ = SavedVariable(solution, grad_fn.get()); | |
} | |
@@ -3781,10 +4380,15 @@ | |
std::tuple<Tensor,Tensor> VariableType::gels(const Tensor & self, const Tensor & A) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& A_ = unpack(A, "A", 1); | |
- auto flags = Function::flags({ self, A }); | |
- auto grad_fn = std::make_shared<GelsBackward>(); | |
+ std::shared_ptr<GelsBackward> grad_fn; | |
+ auto flags = compute_flags({ self, A }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GelsBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, A }); | |
+ } | |
auto ret = as_variable(baseType->gels(self_, A_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ self, A })) { | |
jit::Node *n = jit::tracer::recordTrace( "gels", { self, A }, ret ); | |
(void)n; | |
@@ -3794,12 +4398,18 @@ | |
std::tuple<Tensor,Tensor> VariableType::trtrs(const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& A_ = unpack(A, "A", 1); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<TrtrsBackward>(); | |
+ check_no_requires_grad(A, "A"); | |
+ std::shared_ptr<TrtrsBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TrtrsBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->trtrs(self_, A_, upper, transpose, unitriangular)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self })) { | |
- jit::Node *n = jit::tracer::recordTrace( "trtrs", { self }, ret ); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, A })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "trtrs", { self, A }, ret ); | |
setattr(n, jit::stringToSymbol("upper"), upper); | |
setattr(n, jit::stringToSymbol("transpose"), transpose); | |
setattr(n, jit::stringToSymbol("unitriangular"), unitriangular); | |
@@ -3808,10 +4418,15 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::symeig(const Tensor & self, bool eigenvectors, bool upper) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SymeigBackward>(); | |
+ std::shared_ptr<SymeigBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SymeigBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->symeig(self_, eigenvectors, upper)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "symeig", { self }, ret ); | |
setattr(n, jit::stringToSymbol("eigenvectors"), eigenvectors); | |
@@ -3821,10 +4436,15 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::eig(const Tensor & self, bool eigenvectors) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<EigBackward>(); | |
+ std::shared_ptr<EigBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<EigBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->eig(self_, eigenvectors)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "eig", { self }, ret ); | |
setattr(n, jit::stringToSymbol("eigenvectors"), eigenvectors); | |
@@ -3833,10 +4453,15 @@ | |
} | |
std::tuple<Tensor,Tensor,Tensor> VariableType::svd(const Tensor & self, bool some) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<SvdBackward>(); | |
+ std::shared_ptr<SvdBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SvdBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->svd(self_, some)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "svd", { self }, ret ); | |
setattr(n, jit::stringToSymbol("some"), some); | |
@@ -3845,11 +4470,16 @@ | |
} | |
Tensor VariableType::inverse(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<InverseBackward>(); | |
+ std::shared_ptr<InverseBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<InverseBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->inverse(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& output = ret; | |
grad_fn->output_ = SavedVariable(output, grad_fn.get()); | |
} | |
@@ -3861,14 +4491,17 @@ | |
} | |
Tensor VariableType::potrf(const Tensor & self, bool upper) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<PotrfBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<PotrfBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<PotrfBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->upper = upper; | |
} | |
auto ret = as_variable(baseType->potrf(self_, upper)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& output = ret; | |
grad_fn->output_ = SavedVariable(output, grad_fn.get()); | |
} | |
@@ -3881,10 +4514,15 @@ | |
Tensor VariableType::potrs(const Tensor & self, const Tensor & input2, bool upper) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& input2_ = unpack(input2, "input2", 1); | |
- auto flags = Function::flags({ self, input2 }); | |
- auto grad_fn = std::make_shared<PotrsBackward>(); | |
+ std::shared_ptr<PotrsBackward> grad_fn; | |
+ auto flags = compute_flags({ self, input2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<PotrsBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, input2 }); | |
+ } | |
auto ret = as_variable(baseType->potrs(self_, input2_, upper)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, input2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "potrs", { self, input2 }, ret ); | |
setattr(n, jit::stringToSymbol("upper"), upper); | |
@@ -3893,10 +4531,15 @@ | |
} | |
Tensor VariableType::potri(const Tensor & self, bool upper) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<PotriBackward>(); | |
+ std::shared_ptr<PotriBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<PotriBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->potri(self_, upper)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "potri", { self }, ret ); | |
setattr(n, jit::stringToSymbol("upper"), upper); | |
@@ -3905,10 +4548,15 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::pstrf(const Tensor & self, bool upper, Scalar tol) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<PstrfBackward>(); | |
+ std::shared_ptr<PstrfBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<PstrfBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->pstrf(self_, upper, tol)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "pstrf", { self }, ret ); | |
setattr(n, jit::stringToSymbol("upper"), upper); | |
@@ -3918,10 +4566,15 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::qr(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<QrBackward>(); | |
+ std::shared_ptr<QrBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<QrBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->qr(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "qr", { self }, ret ); | |
(void)n; | |
@@ -3930,10 +4583,15 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::geqrf(const Tensor & self) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<GeqrfBackward>(); | |
+ std::shared_ptr<GeqrfBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GeqrfBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->geqrf(self_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "geqrf", { self }, ret ); | |
(void)n; | |
@@ -3943,10 +4601,15 @@ | |
Tensor VariableType::orgqr(const Tensor & self, const Tensor & input2) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& input2_ = unpack(input2, "input2", 1); | |
- auto flags = Function::flags({ self, input2 }); | |
- auto grad_fn = std::make_shared<OrgqrBackward>(); | |
+ std::shared_ptr<OrgqrBackward> grad_fn; | |
+ auto flags = compute_flags({ self, input2 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<OrgqrBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, input2 }); | |
+ } | |
auto ret = as_variable(baseType->orgqr(self_, input2_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, input2 })) { | |
jit::Node *n = jit::tracer::recordTrace( "orgqr", { self, input2 }, ret ); | |
(void)n; | |
@@ -3957,10 +4620,15 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& input2_ = unpack(input2, "input2", 1); | |
auto& input3_ = unpack(input3, "input3", 2); | |
- auto flags = Function::flags({ self, input2, input3 }); | |
- auto grad_fn = std::make_shared<OrmqrBackward>(); | |
+ std::shared_ptr<OrmqrBackward> grad_fn; | |
+ auto flags = compute_flags({ self, input2, input3 }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<OrmqrBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self, input2, input3 }); | |
+ } | |
auto ret = as_variable(baseType->ormqr(self_, input2_, input3_, left, transpose)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self, input2, input3 })) { | |
jit::Node *n = jit::tracer::recordTrace( "ormqr", { self, input2, input3 }, ret ); | |
setattr(n, jit::stringToSymbol("left"), left); | |
@@ -3971,12 +4639,18 @@ | |
std::tuple<Tensor,Tensor> VariableType::btrifact(const Tensor & self, const Tensor & info, bool pivot) const { | |
auto& self_ = unpack(self, "self", 0); | |
auto& info_ = unpack(info, "info", 1); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<BtrifactBackward>(); | |
+ check_no_requires_grad(info, "info"); | |
+ std::shared_ptr<BtrifactBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<BtrifactBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->btrifact(self_, info_, pivot)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self })) { | |
- jit::Node *n = jit::tracer::recordTrace( "btrifact", { self }, ret ); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, info })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "btrifact", { self, info }, ret ); | |
setattr(n, jit::stringToSymbol("pivot"), pivot); | |
} | |
return std::tuple<Tensor,Tensor>(std::move(ret)); | |
@@ -3985,12 +4659,19 @@ | |
auto& self_ = unpack(self, "self", 0); | |
auto& LU_data_ = unpack(LU_data, "LU_data", 1); | |
auto& LU_pivots_ = unpack(LU_pivots, "LU_pivots", 2); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<BtrisolveBackward>(); | |
+ check_no_requires_grad(LU_data, "LU_data"); | |
+ check_no_requires_grad(LU_pivots, "LU_pivots"); | |
+ std::shared_ptr<BtrisolveBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<BtrisolveBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
auto ret = as_variable(baseType->btrisolve(self_, LU_data_, LU_pivots_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ self })) { | |
- jit::Node *n = jit::tracer::recordTrace( "btrisolve", { self }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ self, LU_data, LU_pivots })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "btrisolve", { self, LU_data, LU_pivots }, ret ); | |
(void)n; | |
} | |
return Tensor(std::move(ret)); | |
@@ -4030,13 +4711,17 @@ | |
} | |
Tensor & VariableType::m_log_normal_(Tensor & self, double mean, double std, Generator * generator) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<LogNormalBackward>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<LogNormalBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LogNormalBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_log_normal_(self_, mean, std, generator); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
return self; | |
} | |
Tensor & VariableType::m_exponential_(Tensor & self, double lambd, Generator * generator) const { | |
@@ -4050,13 +4735,17 @@ | |
} | |
Tensor & VariableType::m_geometric_(Tensor & self, double p, Generator * generator) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto& pImpl = static_cast<VariableImpl&>(*self.get()); | |
- check_inplace(pImpl); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<GeometricBackward>(); | |
+ check_inplace(self); | |
+ std::shared_ptr<GeometricBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GeometricBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
+ } | |
baseType->m_geometric_(self_, p, generator); | |
- pImpl.version_counter.increment(); | |
- wrap_output(self, std::move(flags), grad_fn); | |
+ increment_version(self); | |
+ set_flags(static_cast<Variable&>(self), flags, grad_fn); | |
return self; | |
} | |
int64_t VariableType::m_size(const Tensor & self, int64_t dim) const { | |
@@ -4084,15 +4773,18 @@ | |
} | |
Tensor VariableType::_unnarrow(const Tensor & self, int64_t dimension, int64_t offset, int64_t dimSize) const { | |
auto& self_ = unpack(self, "self", 0); | |
- auto flags = Function::flags({ self }); | |
- auto grad_fn = std::make_shared<UnnarrowBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<UnnarrowBackward> grad_fn; | |
+ auto flags = compute_flags({ self }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<UnnarrowBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ self }); | |
grad_fn->self_argsize_dimension = self.size(dimension); | |
grad_fn->dimension = dimension; | |
grad_fn->offset = offset; | |
} | |
auto ret = as_variable(baseType->_unnarrow(self_, dimension, offset, dimSize)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ self })) { | |
jit::Node *n = jit::tracer::recordTrace( "_unnarrow", { self }, ret ); | |
setattr(n, jit::stringToSymbol("dimension"), dimension); | |
@@ -4106,14 +4798,17 @@ | |
} | |
Tensor VariableType::cat(TensorList tensors, int64_t dim) const { | |
auto tensors_ = unpack(tensors, "tensors", 0); | |
- auto flags = Function::flags({ tensors }); | |
- auto grad_fn = std::make_shared<CatBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<CatBackward> grad_fn; | |
+ auto flags = compute_flags({ tensors }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<CatBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ tensors }); | |
grad_fn->tensors_argsizes_dim = to_arg_sizes(tensors, dim); | |
grad_fn->dim = dim; | |
} | |
auto ret = as_variable(baseType->cat(tensors_, dim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ tensors })) { | |
jit::Node *n = jit::tracer::recordTrace( "cat", { tensors }, ret ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
@@ -4124,18 +4819,23 @@ | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack(target, "target", 1); | |
auto weight_ = unpack_opt(weight, "weight", 2); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<BinaryCrossEntropyBackward>(); | |
- if (flags.is_executable) { | |
+ check_no_requires_grad(target, "target"); | |
+ check_no_requires_grad(weight, "weight"); | |
+ std::shared_ptr<BinaryCrossEntropyBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<BinaryCrossEntropyBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->target_ = SavedVariable(target, nullptr); | |
grad_fn->weight_ = SavedVariable(weight, nullptr); | |
grad_fn->size_average = size_average; | |
} | |
auto ret = as_variable(baseType->binary_cross_entropy_forward(input_, target_, weight_, size_average)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "binary_cross_entropy", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target, weight })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "binary_cross_entropy", { input, target, weight }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
} | |
return Tensor(std::move(ret)); | |
@@ -4144,30 +4844,39 @@ | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack(target, "target", 1); | |
auto weight_ = unpack_opt(weight, "weight", 2); | |
- auto flags = Function::flags({ input, target, weight }); | |
- auto grad_fn = std::make_shared<Error>("binary_cross_entropy_backward is not differentiable"); | |
+ std::shared_ptr<Error> grad_fn; | |
+ auto flags = compute_flags({ input, target, weight }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Error>("the derivative for binary_cross_entropy_backward is not implemented"); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input, target, weight }); | |
+ } | |
auto ret = as_variable(baseType->binary_cross_entropy_backward(input_, target_, weight_, size_average)); | |
- wrap_output(ret, std::move(flags), std::move(grad_fn)); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ input, target, weight })) { | |
jit::Node *n = jit::tracer::recordTrace( "binary_cross_entropy_backward", { input, target, weight }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
} | |
- return ret; | |
+ return Tensor(std::move(ret)); | |
} | |
Tensor VariableType::kl_div(const Tensor & input, const Tensor & target, bool size_average) const { | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack(target, "target", 1); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<KlDivBackward>(); | |
- if (flags.is_executable) { | |
+ check_no_requires_grad(target, "target"); | |
+ std::shared_ptr<KlDivBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<KlDivBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->target_ = SavedVariable(target, nullptr); | |
grad_fn->size_average = size_average; | |
} | |
auto ret = as_variable(baseType->kl_div_forward(input_, target_, size_average)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "kl_div", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "kl_div", { input, target }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
} | |
return Tensor(std::move(ret)); | |
@@ -4175,12 +4884,18 @@ | |
Tensor VariableType::kl_div_backward(const Tensor & input, const Tensor & target, bool size_average) const { | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack(target, "target", 1); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<KlDivBackwardBackward>(); | |
+ check_no_requires_grad(target, "target"); | |
+ std::shared_ptr<KlDivBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<KlDivBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
+ } | |
auto ret = as_variable(baseType->kl_div_backward(input_, target_, size_average)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "kl_div_backward", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "kl_div_backward", { input, target }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
} | |
return Tensor(std::move(ret)); | |
@@ -4188,17 +4903,21 @@ | |
Tensor VariableType::l1_loss(const Tensor & input, const Tensor & target, bool size_average) const { | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack(target, "target", 1); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<L1LossBackward>(); | |
- if (flags.is_executable) { | |
+ check_no_requires_grad(target, "target"); | |
+ std::shared_ptr<L1LossBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<L1LossBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->target_ = SavedVariable(target, nullptr); | |
grad_fn->size_average = size_average; | |
} | |
auto ret = as_variable(baseType->l1_loss_forward(input_, target_, size_average)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "l1_loss", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "l1_loss", { input, target }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
} | |
return Tensor(std::move(ret)); | |
@@ -4206,12 +4925,18 @@ | |
Tensor VariableType::l1_loss_backward(const Tensor & input, const Tensor & target, bool size_average) const { | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack(target, "target", 1); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<L1LossBackwardBackward>(); | |
+ check_no_requires_grad(target, "target"); | |
+ std::shared_ptr<L1LossBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<L1LossBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
+ } | |
auto ret = as_variable(baseType->l1_loss_backward(input_, target_, size_average)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "l1_loss_backward", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "l1_loss_backward", { input, target }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
} | |
return Tensor(std::move(ret)); | |
@@ -4219,18 +4944,22 @@ | |
Tensor VariableType::mse_loss(const Tensor & input, const Tensor & target, bool size_average, bool reduce) const { | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack(target, "target", 1); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<MseLossBackward>(); | |
- if (flags.is_executable) { | |
+ check_no_requires_grad(target, "target"); | |
+ std::shared_ptr<MseLossBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MseLossBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->target_ = SavedVariable(target, nullptr); | |
grad_fn->size_average = size_average; | |
grad_fn->reduce = reduce; | |
} | |
auto ret = as_variable(baseType->mse_loss_forward(input_, target_, size_average, reduce)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "mse_loss", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "mse_loss", { input, target }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
setattr(n, jit::stringToSymbol("reduce"), reduce); | |
} | |
@@ -4240,9 +4969,13 @@ | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
auto& target_ = unpack(target, "target", 2); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<MseLossBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ check_no_requires_grad(target, "target"); | |
+ std::shared_ptr<MseLossBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MseLossBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->target_ = SavedVariable(target, nullptr); | |
@@ -4250,9 +4983,9 @@ | |
grad_fn->reduce = reduce; | |
} | |
auto ret = as_variable(baseType->mse_loss_backward(grad_output_, input_, target_, size_average, reduce)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ grad_output, input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "mse_loss_backward", { grad_output, input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ grad_output, input, target })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "mse_loss_backward", { grad_output, input, target }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
setattr(n, jit::stringToSymbol("reduce"), reduce); | |
} | |
@@ -4262,9 +4995,13 @@ | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack_long(target, "target", 1); | |
auto weight_ = unpack_opt(weight, "weight", 4); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<MultiMarginLossBackward>(); | |
- if (flags.is_executable) { | |
+ check_no_requires_grad(weight, "weight"); | |
+ std::shared_ptr<MultiMarginLossBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MultiMarginLossBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->target_ = SavedVariable(target, nullptr); | |
grad_fn->p = p; | |
@@ -4273,9 +5010,9 @@ | |
grad_fn->size_average = size_average; | |
} | |
auto ret = as_variable(baseType->multi_margin_loss_forward(input_, target_, p, margin, weight_, size_average)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "multi_margin_loss", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target, weight })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "multi_margin_loss", { input, target, weight }, ret ); | |
setattr(n, jit::stringToSymbol("p"), p); | |
setattr(n, jit::stringToSymbol("margin"), margin); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
@@ -4286,35 +5023,43 @@ | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack_long(target, "target", 1); | |
auto weight_ = unpack_opt(weight, "weight", 4); | |
- auto flags = Function::flags({ input, target, weight }); | |
- auto grad_fn = std::make_shared<Error>("multi_margin_loss_backward is not differentiable"); | |
+ std::shared_ptr<Error> grad_fn; | |
+ auto flags = compute_flags({ input, target, weight }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Error>("the derivative for multi_margin_loss_backward is not implemented"); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input, target, weight }); | |
+ } | |
auto ret = as_variable(baseType->multi_margin_loss_backward(input_, target_, p, margin, weight_, size_average)); | |
- wrap_output(ret, std::move(flags), std::move(grad_fn)); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ input, target, weight })) { | |
jit::Node *n = jit::tracer::recordTrace( "multi_margin_loss_backward", { input, target, weight }, ret ); | |
setattr(n, jit::stringToSymbol("p"), p); | |
setattr(n, jit::stringToSymbol("margin"), margin); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
} | |
- return ret; | |
+ return Tensor(std::move(ret)); | |
} | |
Tensor VariableType::multilabel_margin_loss(const Tensor & input, const Tensor & target, bool size_average) const { | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack_long(target, "target", 1); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<MultilabelMarginLossBackward>(); | |
auto is_target = tensor(); | |
auto& is_target_ = static_cast<VariableImpl*>(is_target.get())->data; | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MultilabelMarginLossBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MultilabelMarginLossBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->target_ = SavedVariable(target, nullptr); | |
grad_fn->size_average = size_average; | |
grad_fn->is_target_ = SavedVariable(is_target, nullptr); | |
} | |
auto ret = as_variable(baseType->multilabel_margin_loss_forward(input_, target_, size_average, is_target_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "multilabel_margin_loss", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "multilabel_margin_loss", { input, target }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
} | |
return Tensor(std::move(ret)); | |
@@ -4323,25 +5068,34 @@ | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack_long(target, "target", 1); | |
auto& is_target_ = unpack(is_target, "is_target", 3); | |
- auto flags = Function::flags({ input, target, is_target }); | |
- auto grad_fn = std::make_shared<Error>("multilabel_margin_loss_backward is not differentiable"); | |
+ std::shared_ptr<Error> grad_fn; | |
+ auto flags = compute_flags({ input, target, is_target }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Error>("the derivative for multilabel_margin_loss_backward is not implemented"); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input, target, is_target }); | |
+ } | |
auto ret = as_variable(baseType->multilabel_margin_loss_backward(input_, target_, size_average, is_target_)); | |
- wrap_output(ret, std::move(flags), std::move(grad_fn)); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ input, target, is_target })) { | |
jit::Node *n = jit::tracer::recordTrace( "multilabel_margin_loss_backward", { input, target, is_target }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
} | |
- return ret; | |
+ return Tensor(std::move(ret)); | |
} | |
Tensor VariableType::nll_loss(const Tensor & input, const Tensor & target, const Tensor & weight, bool size_average, int64_t ignore_index) const { | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack_long(target, "target", 1); | |
auto weight_ = unpack_opt(weight, "weight", 2); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<NllLossBackward>(); | |
auto total_weight = tensor(); | |
auto& total_weight_ = static_cast<VariableImpl*>(total_weight.get())->data; | |
- if (flags.is_executable) { | |
+ check_no_requires_grad(weight, "weight"); | |
+ std::shared_ptr<NllLossBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<NllLossBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->target_ = SavedVariable(target, nullptr); | |
grad_fn->weight_ = SavedVariable(weight, nullptr); | |
@@ -4350,9 +5104,9 @@ | |
grad_fn->total_weight_ = SavedVariable(total_weight, nullptr); | |
} | |
auto ret = as_variable(baseType->nll_loss_forward(input_, target_, weight_, size_average, ignore_index, total_weight_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "nll_loss", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target, weight })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "nll_loss", { input, target, weight }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
setattr(n, jit::stringToSymbol("ignore_index"), ignore_index); | |
} | |
@@ -4363,12 +5117,19 @@ | |
auto& target_ = unpack_long(target, "target", 1); | |
auto weight_ = unpack_opt(weight, "weight", 2); | |
auto& total_weight_ = unpack(total_weight, "total_weight", 5); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<NllLossBackwardBackward>(); | |
+ check_no_requires_grad(weight, "weight"); | |
+ check_no_requires_grad(total_weight, "total_weight"); | |
+ std::shared_ptr<NllLossBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<NllLossBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
+ } | |
auto ret = as_variable(baseType->nll_loss_backward(input_, target_, weight_, size_average, ignore_index, total_weight_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "nll_loss_backward", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target, weight, total_weight })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "nll_loss_backward", { input, target, weight, total_weight }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
setattr(n, jit::stringToSymbol("ignore_index"), ignore_index); | |
} | |
@@ -4378,11 +5139,15 @@ | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack_long(target, "target", 1); | |
auto weight_ = unpack_opt(weight, "weight", 2); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<NllLoss2DBackward>(); | |
auto total_weight = tensor(); | |
auto& total_weight_ = static_cast<VariableImpl*>(total_weight.get())->data; | |
- if (flags.is_executable) { | |
+ check_no_requires_grad(weight, "weight"); | |
+ std::shared_ptr<NllLoss2DBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<NllLoss2DBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->target_ = SavedVariable(target, nullptr); | |
grad_fn->weight_ = SavedVariable(weight, nullptr); | |
@@ -4391,9 +5156,9 @@ | |
grad_fn->total_weight_ = SavedVariable(total_weight, nullptr); | |
} | |
auto ret = as_variable(baseType->nll_loss2d_forward(input_, target_, weight_, size_average, ignore_index, total_weight_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "nll_loss2d", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target, weight })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "nll_loss2d", { input, target, weight }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
setattr(n, jit::stringToSymbol("ignore_index"), ignore_index); | |
} | |
@@ -4404,12 +5169,19 @@ | |
auto& target_ = unpack_long(target, "target", 1); | |
auto weight_ = unpack_opt(weight, "weight", 2); | |
auto& total_weight_ = unpack(total_weight, "total_weight", 5); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<NllLoss2DBackwardBackward>(); | |
+ check_no_requires_grad(weight, "weight"); | |
+ check_no_requires_grad(total_weight, "total_weight"); | |
+ std::shared_ptr<NllLoss2DBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<NllLoss2DBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
+ } | |
auto ret = as_variable(baseType->nll_loss2d_backward(input_, target_, weight_, size_average, ignore_index, total_weight_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "nll_loss2d_backward", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target, weight, total_weight })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "nll_loss2d_backward", { input, target, weight, total_weight }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
setattr(n, jit::stringToSymbol("ignore_index"), ignore_index); | |
} | |
@@ -4418,17 +5190,21 @@ | |
Tensor VariableType::smooth_l1_loss(const Tensor & input, const Tensor & target, bool size_average) const { | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack(target, "target", 1); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<SmoothL1LossBackward>(); | |
- if (flags.is_executable) { | |
+ check_no_requires_grad(target, "target"); | |
+ std::shared_ptr<SmoothL1LossBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SmoothL1LossBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->target_ = SavedVariable(target, nullptr); | |
grad_fn->size_average = size_average; | |
} | |
auto ret = as_variable(baseType->smooth_l1_loss_forward(input_, target_, size_average)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "smooth_l1_loss", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "smooth_l1_loss", { input, target }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
} | |
return Tensor(std::move(ret)); | |
@@ -4436,17 +5212,21 @@ | |
Tensor VariableType::smooth_l1_loss_backward(const Tensor & input, const Tensor & target, bool size_average) const { | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack(target, "target", 1); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<SmoothL1LossBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ check_no_requires_grad(target, "target"); | |
+ std::shared_ptr<SmoothL1LossBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SmoothL1LossBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->target_ = SavedVariable(target, nullptr); | |
grad_fn->size_average = size_average; | |
} | |
auto ret = as_variable(baseType->smooth_l1_loss_backward(input_, target_, size_average)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "smooth_l1_loss_backward", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "smooth_l1_loss_backward", { input, target }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
} | |
return Tensor(std::move(ret)); | |
@@ -4454,17 +5234,21 @@ | |
Tensor VariableType::soft_margin_loss(const Tensor & input, const Tensor & target, bool size_average) const { | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack(target, "target", 1); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<SoftMarginLossBackward>(); | |
- if (flags.is_executable) { | |
+ check_no_requires_grad(target, "target"); | |
+ std::shared_ptr<SoftMarginLossBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SoftMarginLossBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->target_ = SavedVariable(target, nullptr); | |
grad_fn->size_average = size_average; | |
} | |
auto ret = as_variable(baseType->soft_margin_loss_forward(input_, target_, size_average)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "soft_margin_loss", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "soft_margin_loss", { input, target }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
} | |
return Tensor(std::move(ret)); | |
@@ -4472,32 +5256,40 @@ | |
Tensor VariableType::soft_margin_loss_backward(const Tensor & input, const Tensor & target, bool size_average) const { | |
auto& input_ = unpack(input, "input", 0); | |
auto& target_ = unpack(target, "target", 1); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<SoftMarginLossBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ check_no_requires_grad(target, "target"); | |
+ std::shared_ptr<SoftMarginLossBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SoftMarginLossBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->target_ = SavedVariable(target, nullptr); | |
grad_fn->size_average = size_average; | |
} | |
auto ret = as_variable(baseType->soft_margin_loss_backward(input_, target_, size_average)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "soft_margin_loss_backward", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, target })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "soft_margin_loss_backward", { input, target }, ret ); | |
setattr(n, jit::stringToSymbol("size_average"), size_average); | |
} | |
return Tensor(std::move(ret)); | |
} | |
Tensor VariableType::elu(const Tensor & input, Scalar alpha, bool inplace) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<EluBackward>(); | |
- if (flags.is_executable) { | |
+ if (inplace) check_inplace(input); | |
+ std::shared_ptr<EluBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<EluBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->alpha = alpha; | |
} | |
auto ret = as_variable(baseType->elu_forward(input_, alpha, inplace)); | |
if (inplace) increment_version(input); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& output = ret; | |
grad_fn->output_ = SavedVariable(output, grad_fn.get()); | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
@@ -4513,23 +5305,27 @@ | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
auto& output_ = unpack(output, "output", 4); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<EluBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ if (inplace) check_inplace(input); | |
+ std::shared_ptr<EluBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<EluBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->alpha = alpha; | |
grad_fn->inplace = inplace; | |
grad_fn->output_ = SavedVariable(output, nullptr); | |
} | |
auto ret = as_variable(baseType->elu_backward(grad_output_, input_, alpha, inplace, output_)); | |
if (inplace) increment_version(input); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
auto& grad_input = ret; | |
grad_fn->grad_input_ = SavedVariable(grad_input, grad_fn.get()); | |
} | |
- if (jit::tracer::isTracing({ grad_output, input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "elu_backward", { grad_output, input }, ret ); | |
+ if (jit::tracer::isTracing({ grad_output, input, output })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "elu_backward", { grad_output, input, output }, ret ); | |
setattr(n, jit::stringToSymbol("alpha"), alpha); | |
setattr(n, jit::stringToSymbol("inplace"), inplace); | |
} | |
@@ -4537,14 +5333,17 @@ | |
} | |
Tensor VariableType::glu(const Tensor & input, int64_t dim) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<GluBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<GluBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GluBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->dim = dim; | |
} | |
auto ret = as_variable(baseType->glu_forward(input_, dim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ input })) { | |
jit::Node *n = jit::tracer::recordTrace( "glu", { input }, ret ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
@@ -4554,15 +5353,18 @@ | |
Tensor VariableType::glu_backward(const Tensor & grad_output, const Tensor & input, int64_t dim) const { | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<GluBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<GluBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<GluBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->dim = dim; | |
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr); | |
} | |
auto ret = as_variable(baseType->glu_backward(grad_output_, input_, dim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input })) { | |
jit::Node *n = jit::tracer::recordTrace( "glu_backward", { grad_output, input }, ret ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
@@ -4571,14 +5373,17 @@ | |
} | |
Tensor VariableType::hardshrink(const Tensor & input, Scalar lambd) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<HardshrinkBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<HardshrinkBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<HardshrinkBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->lambd = lambd; | |
} | |
auto ret = as_variable(baseType->hardshrink_forward(input_, lambd)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ input })) { | |
jit::Node *n = jit::tracer::recordTrace( "hardshrink", { input }, ret ); | |
setattr(n, jit::stringToSymbol("lambd"), lambd); | |
@@ -4588,14 +5393,17 @@ | |
Tensor VariableType::hardshrink_backward(const Tensor & grad_output, const Tensor & input, Scalar lambd) const { | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<HardshrinkBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<HardshrinkBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<HardshrinkBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->lambd = lambd; | |
} | |
auto ret = as_variable(baseType->hardshrink_backward(grad_output_, input_, lambd)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input })) { | |
jit::Node *n = jit::tracer::recordTrace( "hardshrink_backward", { grad_output, input }, ret ); | |
setattr(n, jit::stringToSymbol("lambd"), lambd); | |
@@ -4604,16 +5412,20 @@ | |
} | |
Tensor VariableType::hardtanh(const Tensor & input, Scalar min_val, Scalar max_val, bool inplace) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<HardtanhBackward>(); | |
- if (flags.is_executable) { | |
+ if (inplace) check_inplace(input); | |
+ std::shared_ptr<HardtanhBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<HardtanhBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->min_val = min_val; | |
grad_fn->max_val = max_val; | |
} | |
auto ret = as_variable(baseType->hardtanh_forward(input_, min_val, max_val, inplace)); | |
if (inplace) increment_version(input); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
if (jit::tracer::isTracing({ input })) { | |
@@ -4627,16 +5439,20 @@ | |
Tensor VariableType::hardtanh_backward(const Tensor & grad_output, const Tensor & input, Scalar min_val, Scalar max_val, bool inplace) const { | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<HardtanhBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ if (inplace) check_inplace(input); | |
+ std::shared_ptr<HardtanhBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<HardtanhBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->min_val = min_val; | |
grad_fn->max_val = max_val; | |
} | |
auto ret = as_variable(baseType->hardtanh_backward(grad_output_, input_, min_val, max_val, inplace)); | |
if (inplace) increment_version(input); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
if (jit::tracer::isTracing({ grad_output, input })) { | |
@@ -4649,15 +5465,19 @@ | |
} | |
Tensor VariableType::leaky_relu(const Tensor & input, Scalar negative_slope, bool inplace) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<LeakyReluBackward>(); | |
- if (flags.is_executable) { | |
+ if (inplace) check_inplace(input); | |
+ std::shared_ptr<LeakyReluBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LeakyReluBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->negative_slope = negative_slope; | |
} | |
auto ret = as_variable(baseType->leaky_relu_forward(input_, negative_slope, inplace)); | |
if (inplace) increment_version(input); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
if (jit::tracer::isTracing({ input })) { | |
@@ -4670,15 +5490,19 @@ | |
Tensor VariableType::leaky_relu_backward(const Tensor & grad_output, const Tensor & input, Scalar negative_slope, bool inplace) const { | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<LeakyReluBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ if (inplace) check_inplace(input); | |
+ std::shared_ptr<LeakyReluBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LeakyReluBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->negative_slope = negative_slope; | |
} | |
auto ret = as_variable(baseType->leaky_relu_backward(grad_output_, input_, negative_slope, inplace)); | |
if (inplace) increment_version(input); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
if (jit::tracer::isTracing({ grad_output, input })) { | |
@@ -4690,16 +5514,19 @@ | |
} | |
Tensor VariableType::log_sigmoid(const Tensor & input) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<LogSigmoidBackward>(); | |
auto buffer = tensor(); | |
auto& buffer_ = static_cast<VariableImpl*>(buffer.get())->data; | |
- if (flags.is_executable) { | |
+ std::shared_ptr<LogSigmoidBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LogSigmoidBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->buffer_ = SavedVariable(buffer, nullptr); | |
} | |
auto ret = as_variable(baseType->log_sigmoid_forward(input_, buffer_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ input })) { | |
jit::Node *n = jit::tracer::recordTrace( "log_sigmoid", { input }, ret ); | |
(void)n; | |
@@ -4710,32 +5537,39 @@ | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
auto& buffer_ = unpack(buffer, "buffer", 2); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<LogSigmoidBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ check_no_requires_grad(buffer, "buffer"); | |
+ std::shared_ptr<LogSigmoidBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LogSigmoidBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->buffer_ = SavedVariable(buffer, nullptr); | |
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr); | |
} | |
auto ret = as_variable(baseType->log_sigmoid_backward(grad_output_, input_, buffer_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ grad_output, input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "log_sigmoid_backward", { grad_output, input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ grad_output, input, buffer })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "log_sigmoid_backward", { grad_output, input, buffer }, ret ); | |
(void)n; | |
} | |
return Tensor(std::move(ret)); | |
} | |
Tensor VariableType::log_softmax(const Tensor & input, int64_t dim) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<LogSoftmaxBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<LogSoftmaxBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LogSoftmaxBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->dim = dim; | |
} | |
auto ret = as_variable(baseType->log_softmax_forward(input_, dim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& output = ret; | |
grad_fn->output_ = SavedVariable(output, grad_fn.get()); | |
} | |
@@ -4749,17 +5583,20 @@ | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
auto& output_ = unpack(output, "output", 3); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<LogSoftmaxBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<LogSoftmaxBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<LogSoftmaxBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->dim = dim; | |
grad_fn->output_ = SavedVariable(output, nullptr); | |
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr); | |
} | |
auto ret = as_variable(baseType->log_softmax_backward(grad_output_, input_, dim, output_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ grad_output, input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "log_softmax_backward", { grad_output, input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ grad_output, input, output })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "log_softmax_backward", { grad_output, input, output }, ret ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
} | |
return Tensor(std::move(ret)); | |
@@ -4767,14 +5604,17 @@ | |
Tensor VariableType::prelu(const Tensor & input, const Tensor & weight) const { | |
auto& input_ = unpack(input, "input", 0); | |
auto& weight_ = unpack(weight, "weight", 1); | |
- auto flags = Function::flags({ input, weight }); | |
- auto grad_fn = std::make_shared<PreluBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<PreluBackward> grad_fn; | |
+ auto flags = compute_flags({ input, weight }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<PreluBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input, weight }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->weight_ = SavedVariable(weight, nullptr); | |
} | |
auto ret = as_variable(baseType->prelu_forward(input_, weight_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ input, weight })) { | |
jit::Node *n = jit::tracer::recordTrace( "prelu", { input, weight }, ret ); | |
(void)n; | |
@@ -4785,15 +5625,18 @@ | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
auto& weight_ = unpack(weight, "weight", 2); | |
- auto flags = Function::flags({ grad_output, input, weight }); | |
- auto grad_fn = std::make_shared<PreluBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<PreluBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input, weight }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<PreluBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight }); | |
grad_fn->grad_output_info = grad_output; | |
grad_fn->input_info = input; | |
grad_fn->weight_info = weight; | |
} | |
auto ret = as_variable(baseType->prelu_backward(grad_output_, input_, weight_, output_mask)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input, weight })) { | |
jit::Node *n = jit::tracer::recordTrace( "prelu_backward", { grad_output, input, weight }, ret ); | |
setattr(n, jit::stringToSymbol("output_mask"), output_mask); | |
@@ -4802,11 +5645,15 @@ | |
} | |
Tensor VariableType::rrelu(const Tensor & input, Scalar lower, Scalar upper, bool training, bool inplace, Generator * generator) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<RreluBackward>(); | |
auto noise = tensor(); | |
auto& noise_ = static_cast<VariableImpl*>(noise.get())->data; | |
- if (flags.is_executable) { | |
+ if (inplace) check_inplace(input); | |
+ std::shared_ptr<RreluBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<RreluBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->lower = lower; | |
grad_fn->upper = upper; | |
grad_fn->training = training; | |
@@ -4814,8 +5661,8 @@ | |
} | |
auto ret = as_variable(baseType->rrelu_forward(input_, lower, upper, training, inplace, generator, noise_)); | |
if (inplace) increment_version(input); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
return Tensor(std::move(ret)); | |
@@ -4824,9 +5671,14 @@ | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
auto& noise_ = unpack(noise, "noise", 6); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<RreluBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ if (inplace) check_inplace(input); | |
+ check_no_requires_grad(noise, "noise"); | |
+ std::shared_ptr<RreluBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<RreluBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->lower = lower; | |
grad_fn->upper = upper; | |
grad_fn->training = training; | |
@@ -4834,12 +5686,12 @@ | |
} | |
auto ret = as_variable(baseType->rrelu_backward(grad_output_, input_, lower, upper, training, inplace, noise_)); | |
if (inplace) increment_version(input); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
- if (jit::tracer::isTracing({ grad_output, input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "rrelu_backward", { grad_output, input }, ret ); | |
+ if (jit::tracer::isTracing({ grad_output, input, noise })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "rrelu_backward", { grad_output, input, noise }, ret ); | |
setattr(n, jit::stringToSymbol("lower"), lower); | |
setattr(n, jit::stringToSymbol("upper"), upper); | |
setattr(n, jit::stringToSymbol("training"), training); | |
@@ -4849,15 +5701,18 @@ | |
} | |
Tensor VariableType::softmax(const Tensor & input, int64_t dim) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<SoftmaxBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SoftmaxBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SoftmaxBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->dim = dim; | |
} | |
auto ret = as_variable(baseType->softmax_forward(input_, dim)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& output = ret; | |
grad_fn->output_ = SavedVariable(output, grad_fn.get()); | |
} | |
@@ -4871,34 +5726,40 @@ | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
auto& output_ = unpack(output, "output", 3); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<SoftmaxBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SoftmaxBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SoftmaxBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->dim = dim; | |
grad_fn->output_ = SavedVariable(output, nullptr); | |
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr); | |
} | |
auto ret = as_variable(baseType->softmax_backward(grad_output_, input_, dim, output_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ grad_output, input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "softmax_backward", { grad_output, input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ grad_output, input, output })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "softmax_backward", { grad_output, input, output }, ret ); | |
setattr(n, jit::stringToSymbol("dim"), dim); | |
} | |
return Tensor(std::move(ret)); | |
} | |
Tensor VariableType::softplus(const Tensor & input, Scalar beta, Scalar threshold) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<SoftplusBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SoftplusBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SoftplusBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->beta = beta; | |
grad_fn->threshold = threshold; | |
} | |
auto ret = as_variable(baseType->softplus_forward(input_, beta, threshold)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& output = ret; | |
grad_fn->output_ = SavedVariable(output, grad_fn.get()); | |
} | |
@@ -4913,9 +5774,12 @@ | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
auto& output_ = unpack(output, "output", 4); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<SoftplusBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SoftplusBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SoftplusBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->beta = beta; | |
grad_fn->threshold = threshold; | |
@@ -4923,9 +5787,9 @@ | |
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr); | |
} | |
auto ret = as_variable(baseType->softplus_backward(grad_output_, input_, beta, threshold, output_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ grad_output, input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "softplus_backward", { grad_output, input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ grad_output, input, output })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "softplus_backward", { grad_output, input, output }, ret ); | |
setattr(n, jit::stringToSymbol("beta"), beta); | |
setattr(n, jit::stringToSymbol("threshold"), threshold); | |
} | |
@@ -4933,14 +5797,17 @@ | |
} | |
Tensor VariableType::softshrink(const Tensor & input, Scalar lambd) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<SoftshrinkBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SoftshrinkBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SoftshrinkBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->lambd = lambd; | |
} | |
auto ret = as_variable(baseType->softshrink_forward(input_, lambd)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ input })) { | |
jit::Node *n = jit::tracer::recordTrace( "softshrink", { input }, ret ); | |
setattr(n, jit::stringToSymbol("lambd"), lambd); | |
@@ -4950,14 +5817,17 @@ | |
Tensor VariableType::softshrink_backward(const Tensor & grad_output, const Tensor & input, Scalar lambd) const { | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<SoftshrinkBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SoftshrinkBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SoftshrinkBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->lambd = lambd; | |
} | |
auto ret = as_variable(baseType->softshrink_backward(grad_output_, input_, lambd)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input })) { | |
jit::Node *n = jit::tracer::recordTrace( "softshrink_backward", { grad_output, input }, ret ); | |
setattr(n, jit::stringToSymbol("lambd"), lambd); | |
@@ -4966,16 +5836,20 @@ | |
} | |
Tensor VariableType::threshold(const Tensor & input, Scalar threshold, Scalar value, bool inplace) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<ThresholdBackward>(); | |
- if (flags.is_executable) { | |
+ if (inplace) check_inplace(input); | |
+ std::shared_ptr<ThresholdBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ThresholdBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->threshold = threshold; | |
grad_fn->value = value; | |
} | |
auto ret = as_variable(baseType->threshold_forward(input_, threshold, value, inplace)); | |
if (inplace) increment_version(input); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
if (jit::tracer::isTracing({ input })) { | |
@@ -4989,16 +5863,20 @@ | |
Tensor VariableType::threshold_backward(const Tensor & grad_output, const Tensor & input, Scalar threshold, Scalar value, bool inplace) const { | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<ThresholdBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ if (inplace) check_inplace(input); | |
+ std::shared_ptr<ThresholdBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<ThresholdBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->threshold = threshold; | |
grad_fn->value = value; | |
} | |
auto ret = as_variable(baseType->threshold_backward(grad_output_, input_, threshold, value, inplace)); | |
if (inplace) increment_version(input); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
grad_fn->input_ = SavedVariable(input, grad_fn.get()); | |
} | |
if (jit::tracer::isTracing({ grad_output, input })) { | |
@@ -5011,14 +5889,17 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::adaptive_max_pool2d(const Tensor & input, IntList output_size) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<AdaptiveMaxPool2DBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AdaptiveMaxPool2DBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AdaptiveMaxPool2DBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
} | |
auto ret = as_variable(baseType->adaptive_max_pool2d_forward(input_, output_size)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
+ if (grad_fn) { | |
auto& indices = std::get<1>(ret); | |
grad_fn->indices_ = SavedVariable(indices, grad_fn.get()); | |
} | |
@@ -5032,21 +5913,29 @@ | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
auto& indices_ = unpack_long(indices, "indices", 2); | |
- auto flags = Function::flags({ grad_output, input, indices }); | |
- auto grad_fn = std::make_shared<Error>("adaptive_max_pool2d_backward is not differentiable"); | |
+ std::shared_ptr<Error> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input, indices }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Error>("the derivative for adaptive_max_pool2d_backward is not implemented"); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, indices }); | |
+ } | |
auto ret = as_variable(baseType->adaptive_max_pool2d_backward(grad_output_, input_, indices_)); | |
- wrap_output(ret, std::move(flags), std::move(grad_fn)); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input, indices })) { | |
jit::Node *n = jit::tracer::recordTrace( "adaptive_max_pool2d_backward", { grad_output, input, indices }, ret ); | |
(void)n; | |
} | |
- return ret; | |
+ return Tensor(std::move(ret)); | |
} | |
Tensor VariableType::avg_pool2d(const Tensor & input, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<AvgPool2DBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AvgPool2DBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AvgPool2DBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->kernel_size = kernel_size; | |
grad_fn->stride = stride; | |
@@ -5055,7 +5944,7 @@ | |
grad_fn->count_include_pad = count_include_pad; | |
} | |
auto ret = as_variable(baseType->avg_pool2d_forward(input_, kernel_size, stride, padding, ceil_mode, count_include_pad)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ input })) { | |
jit::Node *n = jit::tracer::recordTrace( "avg_pool2d", { input }, ret ); | |
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size); | |
@@ -5069,9 +5958,12 @@ | |
Tensor VariableType::avg_pool2d_backward(const Tensor & grad_output, const Tensor & input, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) const { | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<AvgPool2DBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AvgPool2DBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AvgPool2DBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->kernel_size = kernel_size; | |
grad_fn->stride = stride; | |
grad_fn->padding = padding; | |
@@ -5080,7 +5972,7 @@ | |
grad_fn->input_info = input; | |
} | |
auto ret = as_variable(baseType->avg_pool2d_backward(grad_output_, input_, kernel_size, stride, padding, ceil_mode, count_include_pad)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input })) { | |
jit::Node *n = jit::tracer::recordTrace( "avg_pool2d_backward", { grad_output, input }, ret ); | |
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size); | |
@@ -5093,9 +5985,12 @@ | |
} | |
Tensor VariableType::avg_pool3d(const Tensor & input, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<AvgPool3DBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AvgPool3DBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AvgPool3DBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->kernel_size = kernel_size; | |
grad_fn->stride = stride; | |
@@ -5104,7 +5999,7 @@ | |
grad_fn->count_include_pad = count_include_pad; | |
} | |
auto ret = as_variable(baseType->avg_pool3d_forward(input_, kernel_size, stride, padding, ceil_mode, count_include_pad)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ input })) { | |
jit::Node *n = jit::tracer::recordTrace( "avg_pool3d", { input }, ret ); | |
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size); | |
@@ -5118,9 +6013,12 @@ | |
Tensor VariableType::avg_pool3d_backward(const Tensor & grad_output, const Tensor & input, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) const { | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<AvgPool3DBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<AvgPool3DBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<AvgPool3DBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->kernel_size = kernel_size; | |
grad_fn->stride = stride; | |
grad_fn->padding = padding; | |
@@ -5129,7 +6027,7 @@ | |
grad_fn->input_info = input; | |
} | |
auto ret = as_variable(baseType->avg_pool3d_backward(grad_output_, input_, kernel_size, stride, padding, ceil_mode, count_include_pad)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input })) { | |
jit::Node *n = jit::tracer::recordTrace( "avg_pool3d_backward", { grad_output, input }, ret ); | |
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size); | |
@@ -5142,9 +6040,12 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::max_pool2d(const Tensor & input, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<MaxPool2DBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MaxPool2DBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MaxPool2DBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->kernel_size = kernel_size; | |
grad_fn->stride = stride; | |
@@ -5153,8 +6054,8 @@ | |
grad_fn->ceil_mode = ceil_mode; | |
} | |
auto ret = as_variable(baseType->max_pool2d_forward(input_, kernel_size, stride, padding, dilation, ceil_mode)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
+ if (grad_fn) { | |
auto& indices = std::get<1>(ret); | |
grad_fn->indices_ = SavedVariable(indices, grad_fn.get()); | |
} | |
@@ -5172,16 +6073,19 @@ | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
auto& indices_ = unpack_long(indices, "indices", 7); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<MaxPool2DBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MaxPool2DBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MaxPool2DBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->indices_ = SavedVariable(indices, nullptr); | |
grad_fn->input_info = input; | |
} | |
auto ret = as_variable(baseType->max_pool2d_backward(grad_output_, input_, kernel_size, stride, padding, dilation, ceil_mode, indices_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ grad_output, input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "max_pool2d_backward", { grad_output, input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ grad_output, input, indices })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "max_pool2d_backward", { grad_output, input, indices }, ret ); | |
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size); | |
setattr(n, jit::stringToSymbol("stride"), stride); | |
setattr(n, jit::stringToSymbol("padding"), padding); | |
@@ -5192,9 +6096,12 @@ | |
} | |
std::tuple<Tensor,Tensor> VariableType::max_pool3d(const Tensor & input, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<MaxPool3DBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MaxPool3DBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MaxPool3DBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->kernel_size = kernel_size; | |
grad_fn->stride = stride; | |
@@ -5203,8 +6110,8 @@ | |
grad_fn->ceil_mode = ceil_mode; | |
} | |
auto ret = as_variable(baseType->max_pool3d_forward(input_, kernel_size, stride, padding, dilation, ceil_mode)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
+ if (grad_fn) { | |
auto& indices = std::get<1>(ret); | |
grad_fn->indices_ = SavedVariable(indices, grad_fn.get()); | |
} | |
@@ -5222,10 +6129,15 @@ | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
auto& indices_ = unpack_long(indices, "indices", 7); | |
- auto flags = Function::flags({ grad_output, input, indices }); | |
- auto grad_fn = std::make_shared<Error>("max_pool3d_backward is not differentiable"); | |
+ std::shared_ptr<Error> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input, indices }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Error>("the derivative for max_pool3d_backward is not implemented"); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, indices }); | |
+ } | |
auto ret = as_variable(baseType->max_pool3d_backward(grad_output_, input_, kernel_size, stride, padding, dilation, ceil_mode, indices_)); | |
- wrap_output(ret, std::move(flags), std::move(grad_fn)); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input, indices })) { | |
jit::Node *n = jit::tracer::recordTrace( "max_pool3d_backward", { grad_output, input, indices }, ret ); | |
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size); | |
@@ -5234,22 +6146,25 @@ | |
setattr(n, jit::stringToSymbol("dilation"), dilation); | |
setattr(n, jit::stringToSymbol("ceil_mode"), ceil_mode); | |
} | |
- return ret; | |
+ return Tensor(std::move(ret)); | |
} | |
Tensor VariableType::max_unpool2d(const Tensor & input, const Tensor & indices, IntList output_size) const { | |
auto& input_ = unpack(input, "input", 0); | |
auto& indices_ = unpack_long(indices, "indices", 1); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<MaxUnpool2DBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MaxUnpool2DBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MaxUnpool2DBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->indices_ = SavedVariable(indices, nullptr); | |
grad_fn->output_size = output_size; | |
} | |
auto ret = as_variable(baseType->max_unpool2d_forward(input_, indices_, output_size)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "max_unpool2d", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, indices })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "max_unpool2d", { input, indices }, ret ); | |
setattr(n, jit::stringToSymbol("output_size"), output_size); | |
} | |
return Tensor(std::move(ret)); | |
@@ -5258,17 +6173,20 @@ | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
auto& indices_ = unpack_long(indices, "indices", 2); | |
- auto flags = Function::flags({ grad_output, input }); | |
- auto grad_fn = std::make_shared<MaxUnpool2DBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MaxUnpool2DBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MaxUnpool2DBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input }); | |
grad_fn->indices_ = SavedVariable(indices, nullptr); | |
grad_fn->output_size = output_size; | |
grad_fn->input_info = input; | |
} | |
auto ret = as_variable(baseType->max_unpool2d_backward(grad_output_, input_, indices_, output_size)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ grad_output, input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "max_unpool2d_backward", { grad_output, input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ grad_output, input, indices })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "max_unpool2d_backward", { grad_output, input, indices }, ret ); | |
setattr(n, jit::stringToSymbol("output_size"), output_size); | |
} | |
return Tensor(std::move(ret)); | |
@@ -5276,9 +6194,12 @@ | |
Tensor VariableType::max_unpool3d(const Tensor & input, const Tensor & indices, IntList output_size, IntList stride, IntList padding) const { | |
auto& input_ = unpack(input, "input", 0); | |
auto& indices_ = unpack_long(indices, "indices", 1); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<MaxUnpool3DBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<MaxUnpool3DBackward> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<MaxUnpool3DBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
grad_fn->input_ = SavedVariable(input, nullptr); | |
grad_fn->indices_ = SavedVariable(indices, nullptr); | |
grad_fn->output_size = output_size; | |
@@ -5286,9 +6207,9 @@ | |
grad_fn->padding = padding; | |
} | |
auto ret = as_variable(baseType->max_unpool3d_forward(input_, indices_, output_size, stride, padding)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (jit::tracer::isTracing({ input })) { | |
- jit::Node *n = jit::tracer::recordTrace( "max_unpool3d", { input }, ret ); | |
+ set_flags(ret, flags, grad_fn); | |
+ if (jit::tracer::isTracing({ input, indices })) { | |
+ jit::Node *n = jit::tracer::recordTrace( "max_unpool3d", { input, indices }, ret ); | |
setattr(n, jit::stringToSymbol("output_size"), output_size); | |
setattr(n, jit::stringToSymbol("stride"), stride); | |
setattr(n, jit::stringToSymbol("padding"), padding); | |
@@ -5299,25 +6220,35 @@ | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
auto& indices_ = unpack_long(indices, "indices", 2); | |
- auto flags = Function::flags({ grad_output, input, indices }); | |
- auto grad_fn = std::make_shared<Error>("max_unpool3d_backward is not differentiable"); | |
+ std::shared_ptr<Error> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input, indices }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Error>("the derivative for max_unpool3d_backward is not implemented"); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, indices }); | |
+ } | |
auto ret = as_variable(baseType->max_unpool3d_backward(grad_output_, input_, indices_, output_size, stride, padding)); | |
- wrap_output(ret, std::move(flags), std::move(grad_fn)); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input, indices })) { | |
jit::Node *n = jit::tracer::recordTrace( "max_unpool3d_backward", { grad_output, input, indices }, ret ); | |
setattr(n, jit::stringToSymbol("output_size"), output_size); | |
setattr(n, jit::stringToSymbol("stride"), stride); | |
setattr(n, jit::stringToSymbol("padding"), padding); | |
} | |
- return ret; | |
+ return Tensor(std::move(ret)); | |
} | |
Tensor VariableType::_sigmoid(const Tensor & input) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<SigmoidBackward1>(); | |
+ std::shared_ptr<SigmoidBackward1> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SigmoidBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
+ } | |
auto ret = as_variable(baseType->_sigmoid_forward(input_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& output = ret; | |
grad_fn->output_ = SavedVariable(output, grad_fn.get()); | |
} | |
@@ -5330,14 +6261,17 @@ | |
Tensor VariableType::_sigmoid_backward(const Tensor & grad_output, const Tensor & output) const { | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& output_ = unpack(output, "output", 1); | |
- auto flags = Function::flags({ grad_output, output }); | |
- auto grad_fn = std::make_shared<SigmoidBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<SigmoidBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, output }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<SigmoidBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, output }); | |
grad_fn->output_ = SavedVariable(output, nullptr); | |
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr); | |
} | |
auto ret = as_variable(baseType->_sigmoid_backward(grad_output_, output_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, output })) { | |
jit::Node *n = jit::tracer::recordTrace( "_sigmoid_backward", { grad_output, output }, ret ); | |
(void)n; | |
@@ -5346,11 +6280,16 @@ | |
} | |
Tensor VariableType::_tanh(const Tensor & input) const { | |
auto& input_ = unpack(input, "input", 0); | |
- auto flags = Function::flags({ input }); | |
- auto grad_fn = std::make_shared<TanhBackward1>(); | |
+ std::shared_ptr<TanhBackward1> grad_fn; | |
+ auto flags = compute_flags({ input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TanhBackward1>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ input }); | |
+ } | |
auto ret = as_variable(baseType->_tanh_forward(input_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
- if (flags.is_executable) { | |
+ set_flags(ret, flags, grad_fn); | |
+ if (grad_fn) { | |
auto& output = ret; | |
grad_fn->output_ = SavedVariable(output, grad_fn.get()); | |
} | |
@@ -5363,14 +6302,17 @@ | |
Tensor VariableType::_tanh_backward(const Tensor & grad_output, const Tensor & output) const { | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& output_ = unpack(output, "output", 1); | |
- auto flags = Function::flags({ grad_output, output }); | |
- auto grad_fn = std::make_shared<TanhBackwardBackward>(); | |
- if (flags.is_executable) { | |
+ std::shared_ptr<TanhBackwardBackward> grad_fn; | |
+ auto flags = compute_flags({ grad_output, output }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<TanhBackwardBackward>(); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, output }); | |
grad_fn->output_ = SavedVariable(output, nullptr); | |
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr); | |
} | |
auto ret = as_variable(baseType->_tanh_backward(grad_output_, output_)); | |
- wrap_output(ret, std::move(flags), grad_fn); | |
+ set_flags(ret, flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, output })) { | |
jit::Node *n = jit::tracer::recordTrace( "_tanh_backward", { grad_output, output }, ret ); | |
(void)n; | |
@@ -5388,17 +6330,22 @@ | |
auto& running_var_ = unpack(running_var, "running_var", 4); | |
auto& save_mean_ = unpack(save_mean, "save_mean", 7); | |
auto& save_std_ = unpack(save_std, "save_std", 8); | |
- auto flags = Function::flags({ grad_output, input, weight, running_mean, running_var, save_mean, save_std }); | |
- auto grad_fn = std::make_shared<Error>("batch_norm_backward is not differentiable"); | |
+ std::shared_ptr<Error> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input, weight, running_mean, running_var, save_mean, save_std }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Error>("the derivative for batch_norm_backward is not implemented"); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight, running_mean, running_var, save_mean, save_std }); | |
+ } | |
auto ret = as_variable(baseType->batch_norm_backward(grad_output_, input_, weight_, running_mean_, running_var_, training, eps, save_mean_, save_std_, output_mask)); | |
- wrap_output(ret, std::move(flags), std::move(grad_fn)); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input, weight, running_mean, running_var, save_mean, save_std })) { | |
jit::Node *n = jit::tracer::recordTrace( "batch_norm_backward", { grad_output, input, weight, running_mean, running_var, save_mean, save_std }, ret ); | |
setattr(n, jit::stringToSymbol("training"), training); | |
setattr(n, jit::stringToSymbol("eps"), eps); | |
setattr(n, jit::stringToSymbol("output_mask"), output_mask); | |
} | |
- return ret; | |
+ return std::tuple<Tensor,Tensor,Tensor>(std::move(ret)); | |
} | |
Tensor VariableType::conv_transpose2d(const Tensor & input, const Tensor & weight, IntList kernel_size, const Tensor & bias, IntList stride, IntList padding, IntList output_padding, IntList dilation) const { | |
throw std::runtime_error("conv_transpose2d: NYI"); | |
@@ -5409,10 +6356,15 @@ | |
auto& weight_ = unpack(weight, "weight", 2); | |
auto& columns_ = unpack(columns, "columns", 8); | |
auto& ones_ = unpack(ones, "ones", 9); | |
- auto flags = Function::flags({ grad_output, input, weight, columns, ones }); | |
- auto grad_fn = std::make_shared<Error>("conv_transpose2d_backward is not differentiable"); | |
+ std::shared_ptr<Error> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input, weight, columns, ones }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Error>("the derivative for conv_transpose2d_backward is not implemented"); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight, columns, ones }); | |
+ } | |
auto ret = as_variable(baseType->conv_transpose2d_backward(grad_output_, input_, weight_, kernel_size, stride, padding, output_padding, dilation, columns_, ones_, output_mask)); | |
- wrap_output(ret, std::move(flags), std::move(grad_fn)); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input, weight, columns, ones })) { | |
jit::Node *n = jit::tracer::recordTrace( "conv_transpose2d_backward", { grad_output, input, weight, columns, ones }, ret ); | |
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size); | |
@@ -5422,7 +6374,7 @@ | |
setattr(n, jit::stringToSymbol("dilation"), dilation); | |
setattr(n, jit::stringToSymbol("output_mask"), output_mask); | |
} | |
- return ret; | |
+ return std::tuple<Tensor,Tensor,Tensor>(std::move(ret)); | |
} | |
Tensor VariableType::conv_transpose3d(const Tensor & input, const Tensor & weight, const Tensor & bias, IntList stride, IntList padding, IntList output_padding, IntList dilation) const { | |
throw std::runtime_error("conv_transpose3d: NYI"); | |
@@ -5433,10 +6385,15 @@ | |
auto& weight_ = unpack(weight, "weight", 2); | |
auto& finput_ = unpack(finput, "finput", 7); | |
auto& fgrad_input_ = unpack(fgrad_input, "fgrad_input", 8); | |
- auto flags = Function::flags({ grad_output, input, weight, finput, fgrad_input }); | |
- auto grad_fn = std::make_shared<Error>("conv_transpose3d_backward is not differentiable"); | |
+ std::shared_ptr<Error> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input, weight, finput, fgrad_input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Error>("the derivative for conv_transpose3d_backward is not implemented"); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight, finput, fgrad_input }); | |
+ } | |
auto ret = as_variable(baseType->conv_transpose3d_backward(grad_output_, input_, weight_, stride, padding, output_padding, dilation, finput_, fgrad_input_, output_mask)); | |
- wrap_output(ret, std::move(flags), std::move(grad_fn)); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input, weight, finput, fgrad_input })) { | |
jit::Node *n = jit::tracer::recordTrace( "conv_transpose3d_backward", { grad_output, input, weight, finput, fgrad_input }, ret ); | |
setattr(n, jit::stringToSymbol("stride"), stride); | |
@@ -5445,7 +6402,7 @@ | |
setattr(n, jit::stringToSymbol("dilation"), dilation); | |
setattr(n, jit::stringToSymbol("output_mask"), output_mask); | |
} | |
- return ret; | |
+ return std::tuple<Tensor,Tensor,Tensor>(std::move(ret)); | |
} | |
Tensor VariableType::conv2d(const Tensor & input, const Tensor & weight, IntList kernel_size, const Tensor & bias, IntList stride, IntList padding) const { | |
throw std::runtime_error("conv2d: NYI"); | |
@@ -5456,10 +6413,15 @@ | |
auto& weight_ = unpack(weight, "weight", 2); | |
auto& finput_ = unpack(finput, "finput", 6); | |
auto& fgrad_input_ = unpack(fgrad_input, "fgrad_input", 7); | |
- auto flags = Function::flags({ grad_output, input, weight, finput, fgrad_input }); | |
- auto grad_fn = std::make_shared<Error>("conv2d_backward is not differentiable"); | |
+ std::shared_ptr<Error> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input, weight, finput, fgrad_input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Error>("the derivative for conv2d_backward is not implemented"); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight, finput, fgrad_input }); | |
+ } | |
auto ret = as_variable(baseType->conv2d_backward(grad_output_, input_, weight_, kernel_size, stride, padding, finput_, fgrad_input_, output_mask)); | |
- wrap_output(ret, std::move(flags), std::move(grad_fn)); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input, weight, finput, fgrad_input })) { | |
jit::Node *n = jit::tracer::recordTrace( "conv2d_backward", { grad_output, input, weight, finput, fgrad_input }, ret ); | |
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size); | |
@@ -5467,7 +6429,7 @@ | |
setattr(n, jit::stringToSymbol("padding"), padding); | |
setattr(n, jit::stringToSymbol("output_mask"), output_mask); | |
} | |
- return ret; | |
+ return std::tuple<Tensor,Tensor,Tensor>(std::move(ret)); | |
} | |
Tensor VariableType::conv_depthwise2d(const Tensor & input, const Tensor & weight, IntList kernel_size, const Tensor & bias, IntList stride, IntList padding, IntList dilation) const { | |
throw std::runtime_error("conv_depthwise2d: NYI"); | |
@@ -5476,10 +6438,15 @@ | |
auto& grad_output_ = unpack(grad_output, "grad_output", 0); | |
auto& input_ = unpack(input, "input", 1); | |
auto& weight_ = unpack(weight, "weight", 2); | |
- auto flags = Function::flags({ grad_output, input, weight }); | |
- auto grad_fn = std::make_shared<Error>("conv_depthwise2d_backward is not differentiable"); | |
+ std::shared_ptr<Error> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input, weight }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Error>("the derivative for conv_depthwise2d_backward is not implemented"); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight }); | |
+ } | |
auto ret = as_variable(baseType->conv_depthwise2d_backward(grad_output_, input_, weight_, kernel_size, stride, padding, dilation, output_mask)); | |
- wrap_output(ret, std::move(flags), std::move(grad_fn)); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input, weight })) { | |
jit::Node *n = jit::tracer::recordTrace( "conv_depthwise2d_backward", { grad_output, input, weight }, ret ); | |
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size); | |
@@ -5488,7 +6455,7 @@ | |
setattr(n, jit::stringToSymbol("dilation"), dilation); | |
setattr(n, jit::stringToSymbol("output_mask"), output_mask); | |
} | |
- return ret; | |
+ return std::tuple<Tensor,Tensor>(std::move(ret)); | |
} | |
Tensor VariableType::conv3d(const Tensor & input, const Tensor & weight, IntList kernel_size, const Tensor & bias, IntList stride, IntList padding) const { | |
throw std::runtime_error("conv3d: NYI"); | |
@@ -5499,10 +6466,15 @@ | |
auto& weight_ = unpack(weight, "weight", 2); | |
auto& finput_ = unpack(finput, "finput", 6); | |
auto& fgrad_input_ = unpack(fgrad_input, "fgrad_input", 7); | |
- auto flags = Function::flags({ grad_output, input, weight, finput, fgrad_input }); | |
- auto grad_fn = std::make_shared<Error>("conv3d_backward is not differentiable"); | |
+ std::shared_ptr<Error> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input, weight, finput, fgrad_input }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Error>("the derivative for conv3d_backward is not implemented"); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight, finput, fgrad_input }); | |
+ } | |
auto ret = as_variable(baseType->conv3d_backward(grad_output_, input_, weight_, kernel_size, stride, padding, finput_, fgrad_input_, output_mask)); | |
- wrap_output(ret, std::move(flags), std::move(grad_fn)); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input, weight, finput, fgrad_input })) { | |
jit::Node *n = jit::tracer::recordTrace( "conv3d_backward", { grad_output, input, weight, finput, fgrad_input }, ret ); | |
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size); | |
@@ -5510,7 +6482,7 @@ | |
setattr(n, jit::stringToSymbol("padding"), padding); | |
setattr(n, jit::stringToSymbol("output_mask"), output_mask); | |
} | |
- return ret; | |
+ return std::tuple<Tensor,Tensor,Tensor>(std::move(ret)); | |
} | |
Tensor VariableType::conv_dilated2d(const Tensor & input, const Tensor & weight, IntList kernel_size, const Tensor & bias, IntList stride, IntList padding, IntList dilation) const { | |
throw std::runtime_error("conv_dilated2d: NYI"); | |
@@ -5521,10 +6493,15 @@ | |
auto& weight_ = unpack(weight, "weight", 2); | |
auto& columns_ = unpack(columns, "columns", 7); | |
auto& ones_ = unpack(ones, "ones", 8); | |
- auto flags = Function::flags({ grad_output, input, weight, columns, ones }); | |
- auto grad_fn = std::make_shared<Error>("conv_dilated2d_backward is not differentiable"); | |
+ std::shared_ptr<Error> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input, weight, columns, ones }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Error>("the derivative for conv_dilated2d_backward is not implemented"); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight, columns, ones }); | |
+ } | |
auto ret = as_variable(baseType->conv_dilated2d_backward(grad_output_, input_, weight_, kernel_size, stride, padding, dilation, columns_, ones_, output_mask)); | |
- wrap_output(ret, std::move(flags), std::move(grad_fn)); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input, weight, columns, ones })) { | |
jit::Node *n = jit::tracer::recordTrace( "conv_dilated2d_backward", { grad_output, input, weight, columns, ones }, ret ); | |
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size); | |
@@ -5533,7 +6510,7 @@ | |
setattr(n, jit::stringToSymbol("dilation"), dilation); | |
setattr(n, jit::stringToSymbol("output_mask"), output_mask); | |
} | |
- return ret; | |
+ return std::tuple<Tensor,Tensor,Tensor>(std::move(ret)); | |
} | |
Tensor VariableType::conv_dilated3d(const Tensor & input, const Tensor & weight, IntList kernel_size, const Tensor & bias, IntList stride, IntList padding, IntList dilation) const { | |
throw std::runtime_error("conv_dilated3d: NYI"); | |
@@ -5544,10 +6521,15 @@ | |
auto& weight_ = unpack(weight, "weight", 2); | |
auto& columns_ = unpack(columns, "columns", 7); | |
auto& ones_ = unpack(ones, "ones", 8); | |
- auto flags = Function::flags({ grad_output, input, weight, columns, ones }); | |
- auto grad_fn = std::make_shared<Error>("conv_dilated3d_backward is not differentiable"); | |
+ std::shared_ptr<Error> grad_fn; | |
+ auto flags = compute_flags({ grad_output, input, weight, columns, ones }); | |
+ if (flags.requires_grad) { | |
+ grad_fn = std::make_shared<Error>("the derivative for conv_dilated3d_backward is not implemented"); | |
+ grad_fn->is_executable = true; | |
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight, columns, ones }); | |
+ } | |
auto ret = as_variable(baseType->conv_dilated3d_backward(grad_output_, input_, weight_, kernel_size, stride, padding, dilation, columns_, ones_, output_mask)); | |
- wrap_output(ret, std::move(flags), std::move(grad_fn)); | |
+ set_flags(std::get<0>(ret), flags, grad_fn); | |
if (jit::tracer::isTracing({ grad_output, input, weight, columns, ones })) { | |
jit::Node *n = jit::tracer::recordTrace( "conv_dilated3d_backward", { grad_output, input, weight, columns, ones }, ret ); | |
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size); | |
@@ -5556,7 +6538,7 @@ | |
setattr(n, jit::stringToSymbol("dilation"), dilation); | |
setattr(n, jit::stringToSymbol("output_mask"), output_mask); | |
} | |
- return ret; | |
+ return std::tuple<Tensor,Tensor,Tensor>(std::move(ret)); | |
} | |
std::vector<Tensor> VariableType::split(Tensor self, int64_t split_size, int64_t dim) const { | |
throw std::runtime_error("split: NYI"); |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment