Skip to content

Instantly share code, notes, and snippets.

@colesbury
Created October 21, 2017 11:32
Show Gist options
  • Save colesbury/1b9d4ed6a05c0e49fc0c95eab8f997f5 to your computer and use it in GitHub Desktop.
Save colesbury/1b9d4ed6a05c0e49fc0c95eab8f997f5 to your computer and use it in GitHub Desktop.
--- VariableType.cpp 2017-10-21 04:32:34.984221040 -0700
+++ ./torch/csrc/autograd/generated/VariableType.cpp 2017-10-21 04:32:46.720255617 -0700
@@ -173,12 +173,61 @@
return make_variable(std::move(tensor));
}
-static void check_inplace(const VariableImpl& pImpl) {
- if (pImpl.requires_grad && !pImpl.grad_fn) {
+struct VariableFlags { // per-op autograd flags derived from the op's inputs (see compute_flags)
+ bool requires_grad;
+ bool is_volatile;
+};
+
+template<typename T>
+static VariableFlags compute_flags_tmpl(T tensors) { // ORs the flags of all defined inputs; volatile overrides requires_grad
+ VariableFlags flags = {false, false};
+ for (const Tensor& tensor : tensors) { // std::reference_wrapper<const Tensor> elements convert implicitly to const Tensor&
+ auto& var = static_cast<const Variable&>(tensor);
+ if (var.defined()) { // undefined inputs contribute nothing
+ flags.requires_grad |= var.requires_grad();
+ flags.is_volatile |= var.is_volatile();
+ }
+ }
+ flags.requires_grad &= !flags.is_volatile; // any volatile input forces requires_grad to false
+ return flags;
+}
+
+using TensorRef = std::reference_wrapper<const Tensor>; // non-owning handle: building a list of these does not copy Tensors
+using TensorRefList = std::initializer_list<TensorRef>; // lets call sites write compute_flags({ a, b })
+
+static VariableFlags compute_flags(const TensorRefList& tensors) {
+ return compute_flags_tmpl(tensors);
+}
+
+static VariableFlags compute_flags(TensorList tensors) {
+ return compute_flags_tmpl(tensors);
+}
+
+static void check_no_requires_grad(const Tensor& tensor, const char* name) { // throws std::runtime_error if a defined input requires grad
+ auto& var = static_cast<const Variable&>(tensor);
+ if (var.defined() && var.requires_grad()) {
+ std::string msg = "the derivative for '";
+ msg += name;
+ msg += "' is not implemented";
+ throw std::runtime_error(msg);
+ }
+}
+
+static function_list compute_next_functions(const std::initializer_list<Tensor>& tensors) { // NOTE(review): unlike compute_flags, {a, b} here copies each Tensor handle; consider TensorRefList if Function::flags can accept it
+ return Function::flags(tensors).next_functions;
+}
+
+static function_list compute_next_functions(TensorList tensors) {
+ return Function::flags(tensors).next_functions;
+}
+
+static void check_inplace(const Tensor& tensor) {
+ auto& var = static_cast<const Variable&>(tensor);
+ if (var.requires_grad() && !var.grad_fn()) {
at::runtime_error(
"a leaf Variable that requires grad has been used in an in-place operation.");
}
- auto live_refs = pImpl.version_counter.live_refs();
+ auto live_refs = var.version_counter().live_refs();
if (live_refs > 1) {
at::runtime_error(
"in-place operations can be only used on variables that don't share "
@@ -187,29 +236,13 @@
}
}
-static void wrap_output(VariableImpl& pImpl, FunctionFlags flags, std::shared_ptr<Function> grad_fn) {
- // Hooks up the grad_fn and sets the flags of the function output. This only
- // supports a single differentiable output.
- pImpl.requires_grad = flags.is_executable;
- pImpl.is_volatile = flags.is_volatile;
- if (!flags.is_volatile) {
- pImpl.output_nr = grad_fn->num_inputs++;
- grad_fn->set_flags(std::move(flags));
- pImpl.grad_fn = std::move(grad_fn);
- }
-}
-
-static void wrap_output(Tensor& t, FunctionFlags flags, std::shared_ptr<Function> grad_fn) {
- auto pImpl = static_cast<VariableImpl*>(t.get());
- wrap_output(*pImpl, std::move(flags), std::move(grad_fn));
-}
-
-static void wrap_output(std::tuple<Variable, Variable>& t, FunctionFlags flags, std::shared_ptr<Function> grad_fn) {
- wrap_output(std::get<0>(t), std::move(flags), std::move(grad_fn));
+static void set_flags(Variable& var, VariableFlags flags, std::shared_ptr<Function> grad_fn) { // copies flags onto var and installs grad_fn (nullptr clears history)
+ var.requires_grad() = flags.requires_grad;
+ var.is_volatile() = flags.is_volatile;
+ if (grad_fn) {
+ var.output_nr() = grad_fn->num_inputs++; // claim the next output slot of grad_fn for var
}
+ var.grad_fn() = std::move(grad_fn); // unconditional: an in-place result that no longer requires grad must drop its stale grad_fn (the removed wrap_output overwrote it on every non-volatile call)
-
-static void wrap_output(std::tuple<Variable, Variable, Variable>& t, FunctionFlags flags, std::shared_ptr<Function> grad_fn) {
- wrap_output(std::get<0>(t), std::move(flags), std::move(grad_fn));
}
static void increment_version(const Tensor & t) {
@@ -234,30 +267,29 @@
// it automatically
auto& src_ = unpack_any(src, "src", 0);
auto& dst_ = unpack(dst, "dst", 1);
- auto& pImpl = static_cast<VariableImpl&>(*dst.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ src });
- baseType->s_copy(src_, dst_);
- pImpl.version_counter.increment();
- if (isFloatingPoint(dst.type().scalarType())) {
- if (isFloatingPoint(src.type().scalarType())) {
- // TODO: handle type conversions
- wrap_output(pImpl, std::move(flags), std::make_shared<Identity>());
- } else {
+ check_inplace(dst);
+ std::shared_ptr<Identity> grad_fn;
+ auto flags = compute_flags({ src });
+ flags.requires_grad &= isFloatingPoint(dst.type().scalarType());
+ if (flags.requires_grad) {
// TODO: handle type conversions
- wrap_output(pImpl, std::move(flags), std::make_shared<Identity>());
- }
+ grad_fn = std::make_shared<Identity>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ src });
}
+ baseType->s_copy(src_, dst_);
+ increment_version(dst);
+ set_flags(static_cast<Variable&>(dst), flags, std::move(grad_fn));
}
Tensor & VariableType::m_resize_(Tensor & self, IntList size) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- if (pImpl.grad_fn) {
+ check_inplace(self);
+ auto& self_var = static_cast<Variable&>(self);
+ if (self_var.grad_fn()) {
at::runtime_error("cannot resize non-leaf variables");
}
- if (pImpl.requires_grad) {
+ if (self_var.requires_grad()) {
at::runtime_error("cannot resize variables which require grad");
}
baseType->m_resize_(self_, size);
@@ -333,15 +365,20 @@
}
Tensor & VariableType::m_fill_(Tensor & self, Scalar value) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ });
- auto grad_fn = std::make_shared<FillBackward>();
+ check_inplace(self);
+ check_no_requires_grad(self, "self"); // NOTE(review): new check; the removed code did not reject non-leaf Variables that require grad. Confirm this restriction is intended
+ std::shared_ptr<FillBackward> grad_fn;
+ auto flags = compute_flags({ }); // no differentiable inputs: always {false, false}, so the branch below never runs
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<FillBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ });
+ }
baseType->m_fill_(self_, value);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ })) {
- jit::Node *n = jit::tracer::recordTrace( "fill_", { }, self );
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
+ if (jit::tracer::isTracing({ self })) {
+ jit::Node *n = jit::tracer::recordTrace( "fill_", { self }, self );
setattr(n, jit::stringToSymbol("value"), value);
}
return self;
@@ -363,18 +400,20 @@
Tensor & VariableType::s_m_masked_fill_(Tensor & self, const Tensor & mask, Scalar value) const {
auto& self_ = unpack(self, "self", 0);
auto& mask_ = unpack_byte(mask, "mask", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<MaskedFillBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<MaskedFillBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MaskedFillBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->mask_ = SavedVariable(mask, nullptr);
}
baseType->s_m_masked_fill_(self_, mask_, value);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self })) {
- jit::Node *n = jit::tracer::recordTrace( "masked_fill_", { self }, self );
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
+ if (jit::tracer::isTracing({ self, mask })) {
+ jit::Node *n = jit::tracer::recordTrace( "masked_fill_", { self, mask }, self );
setattr(n, jit::stringToSymbol("value"), value);
}
return self;
@@ -383,19 +422,21 @@
auto& self_ = unpack(self, "self", 0);
auto& mask_ = unpack_byte(mask, "mask", 1);
auto& source_ = unpack(source, "source", 2);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, source });
- auto grad_fn = std::make_shared<MaskedScatterBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<MaskedScatterBackward> grad_fn;
+ auto flags = compute_flags({ self, source });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MaskedScatterBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, source });
grad_fn->mask_ = SavedVariable(mask, nullptr);
grad_fn->source_sizes = source.sizes();
}
baseType->s_m_masked_scatter_(self_, mask_, source_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self, source })) {
- jit::Node *n = jit::tracer::recordTrace( "masked_scatter_", { self, source }, self );
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
+ if (jit::tracer::isTracing({ self, mask, source })) {
+ jit::Node *n = jit::tracer::recordTrace( "masked_scatter_", { self, mask, source }, self );
(void)n;
}
return self;
@@ -403,31 +444,37 @@
Tensor VariableType::s_masked_select(const Tensor & self, const Tensor & mask) const {
auto& self_ = unpack(self, "self", 0);
auto& mask_ = unpack_byte(mask, "mask", 1);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<MaskedSelectBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<MaskedSelectBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MaskedSelectBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_info = self;
grad_fn->mask_ = SavedVariable(mask, nullptr);
}
auto ret = as_variable(baseType->s_masked_select(self_, mask_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self })) {
- jit::Node *n = jit::tracer::recordTrace( "masked_select", { self }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ self, mask })) {
+ jit::Node *n = jit::tracer::recordTrace( "masked_select", { self, mask }, ret );
(void)n;
}
return Tensor(std::move(ret));
}
Tensor VariableType::transpose(const Tensor & self, int64_t dim0, int64_t dim1) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TransposeBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<TransposeBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TransposeBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->dim0 = dim0;
grad_fn->dim1 = dim1;
}
auto ret = as_variable(baseType->transpose(self_, dim0, dim1));
take_version_counter(ret, self);
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "transpose", { self }, ret );
setattr(n, jit::stringToSymbol("dim0"), dim0);
@@ -437,17 +484,19 @@
}
Tensor & VariableType::m_transpose_(Tensor & self, int64_t dim0, int64_t dim1) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TransposeBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<TransposeBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TransposeBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->dim0 = dim0;
grad_fn->dim1 = dim1;
}
baseType->m_transpose_(self_, dim0, dim1);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "transpose_", { self }, self );
setattr(n, jit::stringToSymbol("dim0"), dim0);
@@ -457,11 +506,16 @@
}
Tensor VariableType::t(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TBackward>();
+ std::shared_ptr<TBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->t(self_));
take_version_counter(ret, self);
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "t", { self }, ret );
(void)n;
@@ -470,13 +524,17 @@
}
Tensor & VariableType::m_t_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TBackward>();
+ check_inplace(self);
+ std::shared_ptr<TBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_t_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "t_", { self }, self );
(void)n;
@@ -485,15 +543,18 @@
}
Tensor VariableType::squeeze(const Tensor & self, int64_t dim) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SqueezeBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<SqueezeBackward1> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SqueezeBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_argsize_dim = self.size(dim);
grad_fn->dim = dim;
}
auto ret = as_variable(baseType->squeeze(self_, dim));
take_version_counter(ret, self);
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "squeeze", { self }, ret );
setattr(n, jit::stringToSymbol("dim"), dim);
@@ -502,14 +563,17 @@
}
Tensor VariableType::squeeze(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SqueezeBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<SqueezeBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SqueezeBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
}
auto ret = as_variable(baseType->squeeze(self_));
take_version_counter(ret, self);
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "squeeze", { self }, ret );
(void)n;
@@ -518,17 +582,19 @@
}
Tensor & VariableType::m_squeeze_(Tensor & self, int64_t dim) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SqueezeBackward1>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<SqueezeBackward1> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SqueezeBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_argsize_dim = self.size(dim);
grad_fn->dim = dim;
}
baseType->m_squeeze_(self_, dim);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "squeeze_", { self }, self );
setattr(n, jit::stringToSymbol("dim"), dim);
@@ -537,16 +603,18 @@
}
Tensor & VariableType::m_squeeze_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SqueezeBackward0>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<SqueezeBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SqueezeBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
}
baseType->m_squeeze_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "squeeze_", { self }, self );
(void)n;
@@ -555,14 +623,17 @@
}
Tensor VariableType::unsqueeze(const Tensor & self, int64_t dim) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<UnsqueezeBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<UnsqueezeBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<UnsqueezeBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->dim = dim;
}
auto ret = as_variable(baseType->unsqueeze(self_, dim));
take_version_counter(ret, self);
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "unsqueeze", { self }, ret );
setattr(n, jit::stringToSymbol("dim"), dim);
@@ -571,16 +642,18 @@
}
Tensor & VariableType::m_unsqueeze_(Tensor & self, int64_t dim) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<UnsqueezeBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<UnsqueezeBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<UnsqueezeBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->dim = dim;
}
baseType->m_unsqueeze_(self_, dim);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "unsqueeze_", { self }, self );
setattr(n, jit::stringToSymbol("dim"), dim);
@@ -589,10 +662,15 @@
}
Tensor VariableType::nonzero(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<NonzeroBackward>();
+ std::shared_ptr<NonzeroBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<NonzeroBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->nonzero(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "nonzero", { self }, ret );
(void)n;
@@ -601,10 +679,15 @@
}
Tensor VariableType::m_clone(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<CloneBackward>();
+ std::shared_ptr<CloneBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<CloneBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->m_clone(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "clone", { self }, ret );
(void)n;
@@ -613,14 +696,17 @@
}
Tensor VariableType::m_view(const Tensor & self, IntList size) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<ViewBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<ViewBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ViewBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
}
auto ret = as_variable(baseType->m_view(self_, size));
take_version_counter(ret, self);
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "view", { self }, ret );
setattr(n, jit::stringToSymbol("size"), size);
@@ -629,14 +715,17 @@
}
Tensor VariableType::m_expand(const Tensor & self, IntList size) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<ExpandBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<ExpandBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ExpandBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
}
auto ret = as_variable(baseType->m_expand(self_, size));
take_version_counter(ret, self);
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "expand", { self }, ret );
setattr(n, jit::stringToSymbol("size"), size);
@@ -646,18 +735,21 @@
Tensor VariableType::index_select(const Tensor & self, int64_t dim, const Tensor & index) const {
auto& self_ = unpack(self, "self", 0);
auto& index_ = unpack_long(index, "index", 2);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<IndexSelectBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<IndexSelectBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<IndexSelectBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
grad_fn->dim = dim;
grad_fn->index_ = SavedVariable(index, nullptr);
}
auto ret = as_variable(baseType->index_select(self_, dim, index_));
take_version_counter(ret, self);
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self })) {
- jit::Node *n = jit::tracer::recordTrace( "index_select", { self }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ self, index })) {
+ jit::Node *n = jit::tracer::recordTrace( "index_select", { self, index }, ret );
setattr(n, jit::stringToSymbol("dim"), dim);
}
return Tensor(std::move(ret));
@@ -666,19 +758,21 @@
auto& self_ = unpack(self, "self", 0);
auto& index_ = unpack_long(index, "index", 2);
auto& source_ = unpack(source, "source", 3);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, source });
- auto grad_fn = std::make_shared<IndexCopyBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<IndexCopyBackward> grad_fn;
+ auto flags = compute_flags({ self, source });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<IndexCopyBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, source });
grad_fn->dim = dim;
grad_fn->index_ = SavedVariable(index, nullptr);
}
baseType->m_index_copy_(self_, dim, index_, source_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self, source })) {
- jit::Node *n = jit::tracer::recordTrace( "index_copy_", { self, source }, self );
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
+ if (jit::tracer::isTracing({ self, index, source })) {
+ jit::Node *n = jit::tracer::recordTrace( "index_copy_", { self, index, source }, self );
setattr(n, jit::stringToSymbol("dim"), dim);
}
return self;
@@ -687,19 +781,21 @@
auto& self_ = unpack(self, "self", 0);
auto& index_ = unpack_long(index, "index", 2);
auto& source_ = unpack(source, "source", 3);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, source });
- auto grad_fn = std::make_shared<IndexAddBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<IndexAddBackward> grad_fn;
+ auto flags = compute_flags({ self, source });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<IndexAddBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, source });
grad_fn->dim = dim;
grad_fn->index_ = SavedVariable(index, nullptr);
}
baseType->m_index_add_(self_, dim, index_, source_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self, source })) {
- jit::Node *n = jit::tracer::recordTrace( "index_add_", { self, source }, self );
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
+ if (jit::tracer::isTracing({ self, index, source })) {
+ jit::Node *n = jit::tracer::recordTrace( "index_add_", { self, index, source }, self );
setattr(n, jit::stringToSymbol("dim"), dim);
}
return self;
@@ -707,19 +803,21 @@
Tensor & VariableType::m_index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar value) const {
auto& self_ = unpack(self, "self", 0);
auto& index_ = unpack_long(index, "index", 2);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<IndexFillBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<IndexFillBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<IndexFillBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->dim = dim;
grad_fn->index_ = SavedVariable(index, nullptr);
}
baseType->m_index_fill_(self_, dim, index_, value);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self })) {
- jit::Node *n = jit::tracer::recordTrace( "index_fill_", { self }, self );
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
+ if (jit::tracer::isTracing({ self, index })) {
+ jit::Node *n = jit::tracer::recordTrace( "index_fill_", { self, index }, self );
setattr(n, jit::stringToSymbol("dim"), dim);
setattr(n, jit::stringToSymbol("value"), value);
}
@@ -727,16 +825,19 @@
}
Tensor VariableType::m_narrow(const Tensor & self, int64_t dimension, int64_t start, int64_t length) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<NarrowBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<NarrowBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<NarrowBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_argsize_dimension = self.size(dimension);
grad_fn->dimension = dimension;
grad_fn->start = start;
}
auto ret = as_variable(baseType->m_narrow(self_, dimension, start, length));
take_version_counter(ret, self);
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "narrow", { self }, ret );
setattr(n, jit::stringToSymbol("dimension"), dimension);
@@ -747,16 +848,19 @@
}
Tensor VariableType::m_unfold(const Tensor & self, int64_t dimension, int64_t size, int64_t step) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<UnfoldBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<UnfoldBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<UnfoldBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
grad_fn->dimension = dimension;
grad_fn->size = size;
grad_fn->step = step;
}
auto ret = as_variable(baseType->m_unfold(self_, dimension, size, step));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "unfold", { self }, ret );
setattr(n, jit::stringToSymbol("dimension"), dimension);
@@ -775,19 +879,21 @@
auto& self_ = unpack(self, "self", 0);
auto& index_ = unpack_long(index, "index", 2);
auto& src_ = unpack(src, "src", 3);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, src });
- auto grad_fn = std::make_shared<ScatterBackward0>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<ScatterBackward0> grad_fn;
+ auto flags = compute_flags({ self, src });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ScatterBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, src });
grad_fn->dim = dim;
grad_fn->index_ = SavedVariable(index, nullptr);
}
baseType->m_scatter_(self_, dim, index_, src_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self, src })) {
- jit::Node *n = jit::tracer::recordTrace( "scatter_", { self, src }, self );
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
+ if (jit::tracer::isTracing({ self, index, src })) {
+ jit::Node *n = jit::tracer::recordTrace( "scatter_", { self, index, src }, self );
setattr(n, jit::stringToSymbol("dim"), dim);
}
return self;
@@ -795,19 +901,21 @@
Tensor & VariableType::m_scatter_(Tensor & self, int64_t dim, const Tensor & index, Scalar value) const {
auto& self_ = unpack(self, "self", 0);
auto& index_ = unpack_long(index, "index", 2);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<ScatterBackward1>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<ScatterBackward1> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ScatterBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->dim = dim;
grad_fn->index_ = SavedVariable(index, nullptr);
}
baseType->m_scatter_(self_, dim, index_, value);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self })) {
- jit::Node *n = jit::tracer::recordTrace( "scatter_", { self }, self );
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
+ if (jit::tracer::isTracing({ self, index })) {
+ jit::Node *n = jit::tracer::recordTrace( "scatter_", { self, index }, self );
setattr(n, jit::stringToSymbol("dim"), dim);
setattr(n, jit::stringToSymbol("value"), value);
}
@@ -817,19 +925,21 @@
auto& self_ = unpack(self, "self", 0);
auto& index_ = unpack_long(index, "index", 2);
auto& src_ = unpack(src, "src", 3);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, src });
- auto grad_fn = std::make_shared<ScatterAddBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<ScatterAddBackward> grad_fn;
+ auto flags = compute_flags({ self, src });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ScatterAddBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, src });
grad_fn->dim = dim;
grad_fn->index_ = SavedVariable(index, nullptr);
}
baseType->m_scatter_add_(self_, dim, index_, src_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self, src })) {
- jit::Node *n = jit::tracer::recordTrace( "scatter_add_", { self, src }, self );
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
+ if (jit::tracer::isTracing({ self, index, src })) {
+ jit::Node *n = jit::tracer::recordTrace( "scatter_add_", { self, index, src }, self );
setattr(n, jit::stringToSymbol("dim"), dim);
}
return self;
@@ -837,17 +947,20 @@
Tensor VariableType::gather(const Tensor & self, int64_t dim, const Tensor & index) const {
auto& self_ = unpack(self, "self", 0);
auto& index_ = unpack_long(index, "index", 2);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<GatherBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<GatherBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GatherBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
grad_fn->dim = dim;
grad_fn->index_ = SavedVariable(index, nullptr);
}
auto ret = as_variable(baseType->gather(self_, dim, index_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self })) {
- jit::Node *n = jit::tracer::recordTrace( "gather", { self }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ self, index })) {
+ jit::Node *n = jit::tracer::recordTrace( "gather", { self, index }, ret );
setattr(n, jit::stringToSymbol("dim"), dim);
}
return Tensor(std::move(ret));
@@ -973,13 +1086,16 @@
}
Tensor VariableType::lt(const Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<LtBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<LtBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LtBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_info = self;
}
auto ret = as_variable(baseType->lt(self_, other));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "lt", { self }, ret );
setattr(n, jit::stringToSymbol("other"), other);
@@ -989,14 +1105,17 @@
Tensor VariableType::s_lt(const Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<LtBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<LtBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LtBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_info = self;
grad_fn->other_info = other;
}
auto ret = as_variable(baseType->s_lt(self_, other_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "lt", { self, other }, ret );
(void)n;
@@ -1005,16 +1124,18 @@
}
Tensor & VariableType::m_lt_(Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<LtBackward0>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<LtBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LtBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_info = self;
}
baseType->m_lt_(self_, other);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "lt_", { self }, self );
setattr(n, jit::stringToSymbol("other"), other);
@@ -1024,17 +1145,19 @@
Tensor & VariableType::s_m_lt_(Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<LtBackward1>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<LtBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LtBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_info = self;
grad_fn->other_info = other;
}
baseType->s_m_lt_(self_, other_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "lt_", { self, other }, self );
(void)n;
@@ -1043,13 +1166,16 @@
}
Tensor VariableType::gt(const Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<GtBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<GtBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GtBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_info = self;
}
auto ret = as_variable(baseType->gt(self_, other));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "gt", { self }, ret );
setattr(n, jit::stringToSymbol("other"), other);
@@ -1059,14 +1185,17 @@
Tensor VariableType::s_gt(const Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<GtBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<GtBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GtBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_info = self;
grad_fn->other_info = other;
}
auto ret = as_variable(baseType->s_gt(self_, other_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "gt", { self, other }, ret );
(void)n;
@@ -1075,16 +1204,18 @@
}
Tensor & VariableType::m_gt_(Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<GtBackward0>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<GtBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GtBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_info = self;
}
baseType->m_gt_(self_, other);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "gt_", { self }, self );
setattr(n, jit::stringToSymbol("other"), other);
@@ -1094,17 +1225,19 @@
Tensor & VariableType::s_m_gt_(Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<GtBackward1>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<GtBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GtBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_info = self;
grad_fn->other_info = other;
}
baseType->s_m_gt_(self_, other_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "gt_", { self, other }, self );
(void)n;
@@ -1113,13 +1246,16 @@
}
Tensor VariableType::le(const Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<LeBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<LeBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LeBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_info = self;
}
auto ret = as_variable(baseType->le(self_, other));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "le", { self }, ret );
setattr(n, jit::stringToSymbol("other"), other);
@@ -1129,14 +1265,17 @@
Tensor VariableType::s_le(const Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<LeBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<LeBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LeBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_info = self;
grad_fn->other_info = other;
}
auto ret = as_variable(baseType->s_le(self_, other_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "le", { self, other }, ret );
(void)n;
@@ -1145,16 +1284,18 @@
}
Tensor & VariableType::m_le_(Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<LeBackward0>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<LeBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LeBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_info = self;
}
baseType->m_le_(self_, other);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "le_", { self }, self );
setattr(n, jit::stringToSymbol("other"), other);
@@ -1164,17 +1305,19 @@
Tensor & VariableType::s_m_le_(Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<LeBackward1>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<LeBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LeBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_info = self;
grad_fn->other_info = other;
}
baseType->s_m_le_(self_, other_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "le_", { self, other }, self );
(void)n;
@@ -1183,13 +1326,16 @@
}
Tensor VariableType::ge(const Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<GeBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<GeBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GeBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_info = self;
}
auto ret = as_variable(baseType->ge(self_, other));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "ge", { self }, ret );
setattr(n, jit::stringToSymbol("other"), other);
@@ -1199,14 +1345,17 @@
Tensor VariableType::s_ge(const Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<GeBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<GeBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GeBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_info = self;
grad_fn->other_info = other;
}
auto ret = as_variable(baseType->s_ge(self_, other_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "ge", { self, other }, ret );
(void)n;
@@ -1215,16 +1364,18 @@
}
Tensor & VariableType::m_ge_(Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<GeBackward0>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<GeBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GeBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_info = self;
}
baseType->m_ge_(self_, other);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "ge_", { self }, self );
setattr(n, jit::stringToSymbol("other"), other);
@@ -1234,17 +1385,19 @@
Tensor & VariableType::s_m_ge_(Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<GeBackward1>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<GeBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GeBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_info = self;
grad_fn->other_info = other;
}
baseType->s_m_ge_(self_, other_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "ge_", { self, other }, self );
(void)n;
@@ -1253,13 +1406,16 @@
}
Tensor VariableType::eq(const Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<EqBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<EqBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<EqBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_info = self;
}
auto ret = as_variable(baseType->eq(self_, other));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "eq", { self }, ret );
setattr(n, jit::stringToSymbol("other"), other);
@@ -1269,14 +1425,17 @@
Tensor VariableType::s_eq(const Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<EqBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<EqBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<EqBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_info = self;
grad_fn->other_info = other;
}
auto ret = as_variable(baseType->s_eq(self_, other_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "eq", { self, other }, ret );
(void)n;
@@ -1285,16 +1444,18 @@
}
Tensor & VariableType::m_eq_(Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<EqBackward0>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<EqBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<EqBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_info = self;
}
baseType->m_eq_(self_, other);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "eq_", { self }, self );
setattr(n, jit::stringToSymbol("other"), other);
@@ -1304,17 +1465,19 @@
Tensor & VariableType::s_m_eq_(Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<EqBackward1>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<EqBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<EqBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_info = self;
grad_fn->other_info = other;
}
baseType->s_m_eq_(self_, other_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "eq_", { self, other }, self );
(void)n;
@@ -1323,13 +1486,16 @@
}
Tensor VariableType::ne(const Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<NeBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<NeBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<NeBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_info = self;
}
auto ret = as_variable(baseType->ne(self_, other));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "ne", { self }, ret );
setattr(n, jit::stringToSymbol("other"), other);
@@ -1339,14 +1505,17 @@
Tensor VariableType::s_ne(const Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<NeBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<NeBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<NeBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_info = self;
grad_fn->other_info = other;
}
auto ret = as_variable(baseType->s_ne(self_, other_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "ne", { self, other }, ret );
(void)n;
@@ -1355,16 +1524,18 @@
}
Tensor & VariableType::m_ne_(Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<NeBackward0>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<NeBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<NeBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_info = self;
}
baseType->m_ne_(self_, other);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "ne_", { self }, self );
setattr(n, jit::stringToSymbol("other"), other);
@@ -1374,17 +1545,19 @@
Tensor & VariableType::s_m_ne_(Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<NeBackward1>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<NeBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<NeBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_info = self;
grad_fn->other_info = other;
}
baseType->s_m_ne_(self_, other_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "ne_", { self, other }, self );
(void)n;
@@ -1393,16 +1566,19 @@
}
std::tuple<Tensor,Tensor> VariableType::min(const Tensor & self, int64_t dim, bool keepdim) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<MinBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<MinBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MinBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
grad_fn->dim = dim;
grad_fn->keepdim = keepdim;
}
auto ret = as_variable(baseType->min(self_, dim, keepdim));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(std::get<0>(ret), flags, grad_fn);
+ if (grad_fn) {
auto& min_indices = std::get<1>(ret);
grad_fn->min_indices_ = SavedVariable(min_indices, grad_fn.get());
}
@@ -1416,14 +1592,17 @@
Tensor VariableType::s_min(const Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<MinBackward2>();
- if (flags.is_executable) {
+ std::shared_ptr<MinBackward2> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MinBackward2>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_ = SavedVariable(self, nullptr);
grad_fn->other_ = SavedVariable(other, nullptr);
}
auto ret = as_variable(baseType->s_min(self_, other_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "min", { self, other }, ret );
(void)n;
@@ -1432,14 +1611,17 @@
}
Scalar VariableType::min(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<MinBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<MinBackward1> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MinBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->min(self_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& result = ret;
grad_fn->result_ = SavedVariable(result, grad_fn.get());
}
@@ -1451,16 +1633,19 @@
}
std::tuple<Tensor,Tensor> VariableType::max(const Tensor & self, int64_t dim, bool keepdim) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<MaxBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<MaxBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MaxBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
grad_fn->dim = dim;
grad_fn->keepdim = keepdim;
}
auto ret = as_variable(baseType->max(self_, dim, keepdim));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(std::get<0>(ret), flags, grad_fn);
+ if (grad_fn) {
auto& max_indices = std::get<1>(ret);
grad_fn->max_indices_ = SavedVariable(max_indices, grad_fn.get());
}
@@ -1474,14 +1659,17 @@
Tensor VariableType::s_max(const Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<MaxBackward2>();
- if (flags.is_executable) {
+ std::shared_ptr<MaxBackward2> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MaxBackward2>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_ = SavedVariable(self, nullptr);
grad_fn->other_ = SavedVariable(other, nullptr);
}
auto ret = as_variable(baseType->s_max(self_, other_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "max", { self, other }, ret );
(void)n;
@@ -1490,14 +1678,17 @@
}
Scalar VariableType::max(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<MaxBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<MaxBackward1> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MaxBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->max(self_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& result = ret;
grad_fn->result_ = SavedVariable(result, grad_fn.get());
}
@@ -1509,16 +1700,19 @@
}
std::tuple<Tensor,Tensor> VariableType::kthvalue(const Tensor & self, int64_t k, int64_t dim, bool keepdim) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<KthvalueBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<KthvalueBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<KthvalueBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
grad_fn->dim = dim;
grad_fn->keepdim = keepdim;
}
auto ret = as_variable(baseType->kthvalue(self_, k, dim, keepdim));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(std::get<0>(ret), flags, grad_fn);
+ if (grad_fn) {
auto& indices = std::get<1>(ret);
grad_fn->indices_ = SavedVariable(indices, grad_fn.get());
}
@@ -1532,16 +1726,19 @@
}
std::tuple<Tensor,Tensor> VariableType::mode(const Tensor & self, int64_t dim, bool keepdim) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<ModeBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<ModeBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ModeBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
grad_fn->dim = dim;
grad_fn->keepdim = keepdim;
}
auto ret = as_variable(baseType->mode(self_, dim, keepdim));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(std::get<0>(ret), flags, grad_fn);
+ if (grad_fn) {
auto& indices = std::get<1>(ret);
grad_fn->indices_ = SavedVariable(indices, grad_fn.get());
}
@@ -1554,16 +1751,19 @@
}
std::tuple<Tensor,Tensor> VariableType::median(const Tensor & self, int64_t dim, bool keepdim) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<MedianBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<MedianBackward1> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MedianBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
grad_fn->dim = dim;
grad_fn->keepdim = keepdim;
}
auto ret = as_variable(baseType->median(self_, dim, keepdim));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(std::get<0>(ret), flags, grad_fn);
+ if (grad_fn) {
auto& indices = std::get<1>(ret);
grad_fn->indices_ = SavedVariable(indices, grad_fn.get());
}
@@ -1576,14 +1776,17 @@
}
Scalar VariableType::median(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<MedianBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<MedianBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MedianBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->median(self_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& result = ret;
grad_fn->result_ = SavedVariable(result, grad_fn.get());
}
@@ -1595,15 +1798,18 @@
}
std::tuple<Tensor,Tensor> VariableType::sort(const Tensor & self, int64_t dim, bool descending) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SortBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<SortBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SortBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
grad_fn->dim = dim;
}
auto ret = as_variable(baseType->sort(self_, dim, descending));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(std::get<0>(ret), flags, grad_fn);
+ if (grad_fn) {
auto& indices = std::get<1>(ret);
grad_fn->indices_ = SavedVariable(indices, grad_fn.get());
}
@@ -1616,15 +1822,18 @@
}
std::tuple<Tensor,Tensor> VariableType::topk(const Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TopkBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<TopkBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TopkBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
grad_fn->dim = dim;
}
auto ret = as_variable(baseType->topk(self_, k, dim, largest, sorted));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(std::get<0>(ret), flags, grad_fn);
+ if (grad_fn) {
auto& indices = std::get<1>(ret);
grad_fn->indices_ = SavedVariable(indices, grad_fn.get());
}
@@ -1651,13 +1860,16 @@
}
Tensor VariableType::abs(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<AbsBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AbsBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AbsBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->abs(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "abs", { self }, ret );
(void)n;
@@ -1666,16 +1878,18 @@
}
Tensor & VariableType::m_abs_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<AbsBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<AbsBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AbsBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_abs_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "abs_", { self }, self );
(void)n;
@@ -1684,14 +1898,18 @@
}
Tensor & VariableType::m_sigmoid_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SigmoidBackward0>();
+ check_inplace(self);
+ std::shared_ptr<SigmoidBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SigmoidBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_sigmoid_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
+ if (grad_fn) {
grad_fn->result_ = SavedVariable(self, grad_fn.get());
}
if (jit::tracer::isTracing({ self })) {
@@ -1702,11 +1920,16 @@
}
Tensor VariableType::sigmoid(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SigmoidBackward0>();
+ std::shared_ptr<SigmoidBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SigmoidBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->sigmoid(self_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& result = ret;
grad_fn->result_ = SavedVariable(result, grad_fn.get());
}
@@ -1718,16 +1941,18 @@
}
Tensor & VariableType::m_log_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<LogBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<LogBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LogBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_log_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "log_", { self }, self );
(void)n;
@@ -1736,13 +1961,16 @@
}
Tensor VariableType::log(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<LogBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<LogBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LogBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->log(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "log", { self }, ret );
(void)n;
@@ -1751,16 +1979,18 @@
}
Tensor & VariableType::m_log1p_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<Log1PBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<Log1PBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Log1PBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_log1p_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "log1p_", { self }, self );
(void)n;
@@ -1769,13 +1999,16 @@
}
Tensor VariableType::log1p(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<Log1PBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<Log1PBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Log1PBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->log1p(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "log1p", { self }, ret );
(void)n;
@@ -1784,10 +2017,15 @@
}
Tensor VariableType::lgamma(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<LgammaBackward>();
+ std::shared_ptr<LgammaBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LgammaBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->lgamma(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "lgamma", { self }, ret );
(void)n;
@@ -1796,13 +2034,17 @@
}
Tensor & VariableType::m_lgamma_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<LgammaBackward>();
+ check_inplace(self);
+ std::shared_ptr<LgammaBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LgammaBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_lgamma_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "lgamma_", { self }, self );
(void)n;
@@ -1811,14 +2053,18 @@
}
Tensor & VariableType::m_exp_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<ExpBackward>();
+ check_inplace(self);
+ std::shared_ptr<ExpBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ExpBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_exp_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
+ if (grad_fn) {
grad_fn->result_ = SavedVariable(self, grad_fn.get());
}
if (jit::tracer::isTracing({ self })) {
@@ -1829,11 +2075,16 @@
}
Tensor VariableType::exp(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<ExpBackward>();
+ std::shared_ptr<ExpBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ExpBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->exp(self_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& result = ret;
grad_fn->result_ = SavedVariable(result, grad_fn.get());
}
@@ -1845,16 +2096,18 @@
}
Tensor & VariableType::m_cos_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<CosBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<CosBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<CosBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_cos_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "cos_", { self }, self );
(void)n;
@@ -1863,13 +2116,16 @@
}
Tensor VariableType::cos(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<CosBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<CosBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<CosBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->cos(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "cos", { self }, ret );
(void)n;
@@ -1878,16 +2134,18 @@
}
Tensor & VariableType::m_acos_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<AcosBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<AcosBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AcosBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_acos_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "acos_", { self }, self );
(void)n;
@@ -1896,13 +2154,16 @@
}
Tensor VariableType::acos(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<AcosBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AcosBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AcosBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->acos(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "acos", { self }, ret );
(void)n;
@@ -1911,16 +2172,18 @@
}
Tensor & VariableType::m_cosh_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<CoshBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<CoshBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<CoshBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_cosh_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "cosh_", { self }, self );
(void)n;
@@ -1929,13 +2192,16 @@
}
Tensor VariableType::cosh(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<CoshBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<CoshBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<CoshBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->cosh(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "cosh", { self }, ret );
(void)n;
@@ -1944,16 +2210,18 @@
}
Tensor & VariableType::m_sin_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SinBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<SinBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SinBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_sin_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "sin_", { self }, self );
(void)n;
@@ -1962,13 +2230,16 @@
}
Tensor VariableType::sin(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SinBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<SinBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SinBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->sin(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "sin", { self }, ret );
(void)n;
@@ -1977,16 +2248,18 @@
}
Tensor & VariableType::m_asin_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<AsinBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<AsinBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AsinBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_asin_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "asin_", { self }, self );
(void)n;
@@ -1995,13 +2268,16 @@
}
Tensor VariableType::asin(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<AsinBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AsinBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AsinBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->asin(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "asin", { self }, ret );
(void)n;
@@ -2010,16 +2286,18 @@
}
Tensor & VariableType::m_sinh_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SinhBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<SinhBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SinhBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_sinh_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "sinh_", { self }, self );
(void)n;
@@ -2028,13 +2306,16 @@
}
Tensor VariableType::sinh(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SinhBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<SinhBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SinhBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->sinh(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "sinh", { self }, ret );
(void)n;
@@ -2043,16 +2324,18 @@
}
Tensor & VariableType::m_tan_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TanBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<TanBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TanBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_tan_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "tan_", { self }, self );
(void)n;
@@ -2061,13 +2344,16 @@
}
Tensor VariableType::tan(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TanBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<TanBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TanBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->tan(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "tan", { self }, ret );
(void)n;
@@ -2076,16 +2362,18 @@
}
Tensor & VariableType::m_atan_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<AtanBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<AtanBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AtanBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_atan_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "atan_", { self }, self );
(void)n;
@@ -2094,13 +2382,16 @@
}
Tensor VariableType::atan(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<AtanBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AtanBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AtanBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->atan(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "atan", { self }, ret );
(void)n;
@@ -2109,14 +2400,18 @@
}
Tensor & VariableType::m_tanh_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TanhBackward0>();
+ check_inplace(self);
+ std::shared_ptr<TanhBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TanhBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_tanh_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
+ if (grad_fn) {
grad_fn->result_ = SavedVariable(self, grad_fn.get());
}
if (jit::tracer::isTracing({ self })) {
@@ -2127,11 +2422,16 @@
}
Tensor VariableType::tanh(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TanhBackward0>();
+ std::shared_ptr<TanhBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TanhBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->tanh(self_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& result = ret;
grad_fn->result_ = SavedVariable(result, grad_fn.get());
}
@@ -2143,16 +2443,18 @@
}
Tensor & VariableType::m_erf_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<ErfBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<ErfBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ErfBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_erf_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "erf_", { self }, self );
(void)n;
@@ -2161,13 +2463,16 @@
}
Tensor VariableType::erf(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<ErfBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<ErfBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ErfBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->erf(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "erf", { self }, ret );
(void)n;
@@ -2176,16 +2481,18 @@
}
Tensor & VariableType::m_erfinv_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<ErfinvBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<ErfinvBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ErfinvBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_erfinv_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "erfinv_", { self }, self );
(void)n;
@@ -2194,13 +2501,16 @@
}
Tensor VariableType::erfinv(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<ErfinvBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<ErfinvBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ErfinvBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->erfinv(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "erfinv", { self }, ret );
(void)n;
@@ -2209,16 +2519,18 @@
}
Tensor & VariableType::m_sqrt_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SqrtBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<SqrtBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SqrtBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_sqrt_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "sqrt_", { self }, self );
(void)n;
@@ -2227,13 +2539,16 @@
}
Tensor VariableType::sqrt(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SqrtBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<SqrtBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SqrtBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->sqrt(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "sqrt", { self }, ret );
(void)n;
@@ -2242,14 +2557,18 @@
}
Tensor & VariableType::m_rsqrt_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<RsqrtBackward>();
+ check_inplace(self);
+ std::shared_ptr<RsqrtBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<RsqrtBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_rsqrt_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
+ if (grad_fn) {
grad_fn->result_ = SavedVariable(self, grad_fn.get());
}
if (jit::tracer::isTracing({ self })) {
@@ -2260,11 +2579,16 @@
}
Tensor VariableType::rsqrt(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<RsqrtBackward>();
+ std::shared_ptr<RsqrtBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<RsqrtBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->rsqrt(self_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& result = ret;
grad_fn->result_ = SavedVariable(result, grad_fn.get());
}
@@ -2276,13 +2600,17 @@
}
Tensor & VariableType::m_ceil_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<CeilBackward>();
+ check_inplace(self);
+ std::shared_ptr<CeilBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<CeilBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_ceil_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "ceil_", { self }, self );
(void)n;
@@ -2291,10 +2619,15 @@
}
Tensor VariableType::ceil(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<CeilBackward>();
+ std::shared_ptr<CeilBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<CeilBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->ceil(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "ceil", { self }, ret );
(void)n;
@@ -2303,13 +2636,17 @@
}
Tensor & VariableType::m_floor_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<FloorBackward>();
+ check_inplace(self);
+ std::shared_ptr<FloorBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<FloorBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_floor_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "floor_", { self }, self );
(void)n;
@@ -2318,10 +2655,15 @@
}
Tensor VariableType::floor(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<FloorBackward>();
+ std::shared_ptr<FloorBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<FloorBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->floor(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "floor", { self }, ret );
(void)n;
@@ -2330,13 +2672,17 @@
}
Tensor & VariableType::m_round_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<RoundBackward>();
+ check_inplace(self);
+ std::shared_ptr<RoundBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<RoundBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_round_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "round_", { self }, self );
(void)n;
@@ -2345,10 +2691,15 @@
}
Tensor VariableType::round(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<RoundBackward>();
+ std::shared_ptr<RoundBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<RoundBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->round(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "round", { self }, ret );
(void)n;
@@ -2357,13 +2708,17 @@
}
Tensor & VariableType::m_trunc_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TruncBackward>();
+ check_inplace(self);
+ std::shared_ptr<TruncBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TruncBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_trunc_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "trunc_", { self }, self );
(void)n;
@@ -2372,10 +2727,15 @@
}
Tensor VariableType::trunc(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TruncBackward>();
+ std::shared_ptr<TruncBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TruncBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->trunc(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "trunc", { self }, ret );
(void)n;
@@ -2384,13 +2744,17 @@
}
Tensor & VariableType::m_frac_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<FracBackward>();
+ check_inplace(self);
+ std::shared_ptr<FracBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<FracBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_frac_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "frac_", { self }, self );
(void)n;
@@ -2399,10 +2763,15 @@
}
Tensor VariableType::frac(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<FracBackward>();
+ std::shared_ptr<FracBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<FracBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->frac(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "frac", { self }, ret );
(void)n;
@@ -2411,16 +2780,19 @@
}
Tensor VariableType::mean(const Tensor & self, int64_t dim, bool keepdim) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<MeanBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<MeanBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MeanBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
grad_fn->self_argsize_dim = self.size(dim);
grad_fn->dim = dim;
grad_fn->keepdim = keepdim;
}
auto ret = as_variable(baseType->mean(self_, dim, keepdim));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "mean", { self }, ret );
setattr(n, jit::stringToSymbol("dim"), dim);
@@ -2430,14 +2802,17 @@
}
Scalar VariableType::mean(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<MeanBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<MeanBackward1> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MeanBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->mean(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "mean", { self }, ret );
(void)n;
@@ -2458,16 +2833,19 @@
}
Tensor VariableType::norm(const Tensor & self, Scalar p, int64_t dim, bool keepdim) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<NormBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<NormBackward1> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<NormBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
grad_fn->p = p;
grad_fn->dim = dim;
grad_fn->keepdim = keepdim;
}
auto ret = as_variable(baseType->norm(self_, p, dim, keepdim));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "norm", { self }, ret );
setattr(n, jit::stringToSymbol("p"), p);
@@ -2478,14 +2856,17 @@
}
Scalar VariableType::norm(const Tensor & self, Scalar p) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<NormBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<NormBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<NormBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
grad_fn->p = p;
}
auto ret = as_variable(baseType->norm(self_, p));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "norm", { self }, ret );
setattr(n, jit::stringToSymbol("p"), p);
@@ -2501,15 +2882,18 @@
Scalar VariableType::s_dist(const Tensor & self, const Tensor & other, Scalar p) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<DistBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<DistBackward> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<DistBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_ = SavedVariable(self, nullptr);
grad_fn->other_ = SavedVariable(other, nullptr);
grad_fn->p = p;
}
auto ret = as_variable(baseType->s_dist(self_, other_, p));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "dist", { self, other }, ret );
setattr(n, jit::stringToSymbol("p"), p);
@@ -2518,13 +2902,16 @@
}
Tensor VariableType::reciprocal(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<ReciprocalBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<ReciprocalBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ReciprocalBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->reciprocal(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "reciprocal", { self }, ret );
(void)n;
@@ -2533,16 +2920,18 @@
}
Tensor & VariableType::m_reciprocal_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<ReciprocalBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<ReciprocalBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ReciprocalBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->m_reciprocal_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "reciprocal_", { self }, self );
(void)n;
@@ -2551,10 +2940,15 @@
}
Tensor VariableType::neg(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<NegBackward>();
+ std::shared_ptr<NegBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<NegBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->neg(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "neg", { self }, ret );
(void)n;
@@ -2563,13 +2957,17 @@
}
Tensor & VariableType::m_neg_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<NegBackward>();
+ check_inplace(self);
+ std::shared_ptr<NegBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<NegBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_neg_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "neg_", { self }, self );
(void)n;
@@ -2579,14 +2977,17 @@
Tensor VariableType::s_atan2(const Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<Atan2Backward>();
- if (flags.is_executable) {
+ std::shared_ptr<Atan2Backward> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Atan2Backward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_ = SavedVariable(self, nullptr);
grad_fn->other_ = SavedVariable(other, nullptr);
}
auto ret = as_variable(baseType->s_atan2(self_, other_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "atan2", { self, other }, ret );
(void)n;
@@ -2596,17 +2997,19 @@
Tensor & VariableType::s_m_atan2_(Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<Atan2Backward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<Atan2Backward> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Atan2Backward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
grad_fn->other_ = SavedVariable(other, nullptr);
}
baseType->s_m_atan2_(self_, other_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "atan2_", { self, other }, self );
(void)n;
@@ -2615,14 +3018,17 @@
}
Tensor VariableType::pow(const Tensor & self, Scalar exponent) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<PowBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<PowBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<PowBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self, nullptr);
grad_fn->exponent = exponent;
}
auto ret = as_variable(baseType->pow(self_, exponent));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "pow", { self }, ret );
setattr(n, jit::stringToSymbol("exponent"), exponent);
@@ -2632,14 +3038,17 @@
Tensor VariableType::s_pow(const Tensor & self, const Tensor & exponent) const {
auto& self_ = unpack(self, "self", 0);
auto& exponent_ = unpack(exponent, "exponent", 1);
- auto flags = Function::flags({ self, exponent });
- auto grad_fn = std::make_shared<PowBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<PowBackward1> grad_fn;
+ auto flags = compute_flags({ self, exponent });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<PowBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, exponent });
grad_fn->self_ = SavedVariable(self, nullptr);
grad_fn->exponent_ = SavedVariable(exponent, nullptr);
}
auto ret = as_variable(baseType->s_pow(self_, exponent_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, exponent })) {
jit::Node *n = jit::tracer::recordTrace( "pow", { self, exponent }, ret );
(void)n;
@@ -2648,17 +3057,19 @@
}
Tensor & VariableType::m_pow_(Tensor & self, Scalar exponent) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<PowBackward0>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<PowBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<PowBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
grad_fn->exponent = exponent;
}
baseType->m_pow_(self_, exponent);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "pow_", { self }, self );
setattr(n, jit::stringToSymbol("exponent"), exponent);
@@ -2668,17 +3079,19 @@
Tensor & VariableType::s_m_pow_(Tensor & self, const Tensor & exponent) const {
auto& self_ = unpack(self, "self", 0);
auto& exponent_ = unpack(exponent, "exponent", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, exponent });
- auto grad_fn = std::make_shared<PowBackward1>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<PowBackward1> grad_fn;
+ auto flags = compute_flags({ self, exponent });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<PowBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, exponent });
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
grad_fn->exponent_ = SavedVariable(exponent, nullptr);
}
baseType->s_m_pow_(self_, exponent_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, exponent })) {
jit::Node *n = jit::tracer::recordTrace( "pow_", { self, exponent }, self );
(void)n;
@@ -2688,13 +3101,16 @@
Tensor VariableType::s_lerp(const Tensor & self, const Tensor & end, Scalar weight) const {
auto& self_ = unpack(self, "self", 0);
auto& end_ = unpack(end, "end", 1);
- auto flags = Function::flags({ self, end });
- auto grad_fn = std::make_shared<LerpBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<LerpBackward> grad_fn;
+ auto flags = compute_flags({ self, end });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LerpBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, end });
grad_fn->weight = weight;
}
auto ret = as_variable(baseType->s_lerp(self_, end_, weight));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, end })) {
jit::Node *n = jit::tracer::recordTrace( "lerp", { self, end }, ret );
setattr(n, jit::stringToSymbol("weight"), weight);
@@ -2704,16 +3120,18 @@
Tensor & VariableType::s_m_lerp_(Tensor & self, const Tensor & end, Scalar weight) const {
auto& self_ = unpack(self, "self", 0);
auto& end_ = unpack(end, "end", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, end });
- auto grad_fn = std::make_shared<LerpBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<LerpBackward> grad_fn;
+ auto flags = compute_flags({ self, end });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LerpBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, end });
grad_fn->weight = weight;
}
baseType->s_m_lerp_(self_, end_, weight);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, end })) {
jit::Node *n = jit::tracer::recordTrace( "lerp_", { self, end }, self );
setattr(n, jit::stringToSymbol("weight"), weight);
@@ -2728,10 +3146,15 @@
}
Tensor VariableType::histc(const Tensor & self, int64_t bins, Scalar min, Scalar max) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<HistcBackward>();
+ std::shared_ptr<HistcBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<HistcBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->histc(self_, bins, min, max));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "histc", { self }, ret );
setattr(n, jit::stringToSymbol("bins"), bins);
@@ -2742,13 +3165,17 @@
}
Tensor & VariableType::m_zero_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<ZeroBackward>();
+ check_inplace(self);
+ std::shared_ptr<ZeroBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ZeroBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_zero_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "zero_", { self }, self );
(void)n;
@@ -2757,15 +3184,18 @@
}
Tensor VariableType::sum(const Tensor & self, int64_t dim, bool keepdim) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SumBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<SumBackward1> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SumBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
grad_fn->dim = dim;
grad_fn->keepdim = keepdim;
}
auto ret = as_variable(baseType->sum(self_, dim, keepdim));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "sum", { self }, ret );
setattr(n, jit::stringToSymbol("dim"), dim);
@@ -2775,13 +3205,16 @@
}
Scalar VariableType::sum(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SumBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<SumBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SumBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
}
auto ret = as_variable(baseType->sum(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "sum", { self }, ret );
(void)n;
@@ -2796,13 +3229,16 @@
}
Tensor VariableType::cumsum(const Tensor & self, int64_t dim) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<CumsumBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<CumsumBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<CumsumBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->dim = dim;
}
auto ret = as_variable(baseType->cumsum(self_, dim));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "cumsum", { self }, ret );
setattr(n, jit::stringToSymbol("dim"), dim);
@@ -2814,10 +3250,15 @@
}
Tensor VariableType::sign(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SignBackward>();
+ std::shared_ptr<SignBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SignBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->sign(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "sign", { self }, ret );
(void)n;
@@ -2826,13 +3267,17 @@
}
Tensor & VariableType::m_sign_(Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SignBackward>();
+ check_inplace(self);
+ std::shared_ptr<SignBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SignBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_sign_(self_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "sign_", { self }, self );
(void)n;
@@ -2841,13 +3286,16 @@
}
Scalar VariableType::trace(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TraceBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<TraceBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TraceBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_sizes = self.sizes();
}
auto ret = as_variable(baseType->trace(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "trace", { self }, ret );
(void)n;
@@ -2856,10 +3304,15 @@
}
Tensor VariableType::add(const Tensor & self, Scalar other, Scalar alpha) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<AddBackward0>();
+ std::shared_ptr<AddBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->add(self_, other, alpha));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "add", { self }, ret );
setattr(n, jit::stringToSymbol("other"), other);
@@ -2870,13 +3323,16 @@
Tensor VariableType::s_add(const Tensor & self, const Tensor & other, Scalar alpha) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<AddBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<AddBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->alpha = alpha;
}
auto ret = as_variable(baseType->s_add(self_, other_, alpha));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "add", { self, other }, ret );
setattr(n, jit::stringToSymbol("alpha"), alpha);
@@ -2888,13 +3344,17 @@
}
Tensor & VariableType::m_add_(Tensor & self, Scalar other, Scalar alpha) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<AddBackward0>();
+ check_inplace(self);
+ std::shared_ptr<AddBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_add_(self_, other, alpha);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "add_", { self }, self );
setattr(n, jit::stringToSymbol("other"), other);
@@ -2905,16 +3365,18 @@
Tensor & VariableType::s_m_add_(Tensor & self, const Tensor & other, Scalar alpha) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<AddBackward1>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<AddBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->alpha = alpha;
}
baseType->s_m_add_(self_, other_, alpha);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "add_", { self, other }, self );
setattr(n, jit::stringToSymbol("alpha"), alpha);
@@ -2926,10 +3388,15 @@
}
Tensor VariableType::sub(const Tensor & self, Scalar other, Scalar alpha) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SubBackward0>();
+ std::shared_ptr<SubBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SubBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->sub(self_, other, alpha));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "sub", { self }, ret );
setattr(n, jit::stringToSymbol("other"), other);
@@ -2940,13 +3407,16 @@
Tensor VariableType::s_sub(const Tensor & self, const Tensor & other, Scalar alpha) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<SubBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<SubBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SubBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->alpha = alpha;
}
auto ret = as_variable(baseType->s_sub(self_, other_, alpha));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "sub", { self, other }, ret );
setattr(n, jit::stringToSymbol("alpha"), alpha);
@@ -2955,13 +3425,17 @@
}
Tensor & VariableType::m_sub_(Tensor & self, Scalar other, Scalar alpha) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SubBackward0>();
+ check_inplace(self);
+ std::shared_ptr<SubBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SubBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_sub_(self_, other, alpha);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "sub_", { self }, self );
setattr(n, jit::stringToSymbol("other"), other);
@@ -2972,16 +3446,18 @@
Tensor & VariableType::s_m_sub_(Tensor & self, const Tensor & other, Scalar alpha) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<SubBackward1>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<SubBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SubBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->alpha = alpha;
}
baseType->s_m_sub_(self_, other_, alpha);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "sub_", { self, other }, self );
setattr(n, jit::stringToSymbol("alpha"), alpha);
@@ -2990,13 +3466,16 @@
}
Tensor VariableType::mul(const Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<MulBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<MulBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MulBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->other = other;
}
auto ret = as_variable(baseType->mul(self_, other));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "mul", { self }, ret );
setattr(n, jit::stringToSymbol("other"), other);
@@ -3006,14 +3485,17 @@
Tensor VariableType::s_mul(const Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<MulBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<MulBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MulBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->other_ = SavedVariable(other, nullptr);
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->s_mul(self_, other_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "mul", { self, other }, ret );
(void)n;
@@ -3022,16 +3504,18 @@
}
Tensor & VariableType::m_mul_(Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<MulBackward0>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<MulBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MulBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->other = other;
}
baseType->m_mul_(self_, other);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "mul_", { self }, self );
setattr(n, jit::stringToSymbol("other"), other);
@@ -3041,17 +3525,19 @@
Tensor & VariableType::s_m_mul_(Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<MulBackward1>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<MulBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MulBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->other_ = SavedVariable(other, nullptr);
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->s_m_mul_(self_, other_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "mul_", { self, other }, self );
(void)n;
@@ -3060,13 +3546,16 @@
}
Tensor VariableType::div(const Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<DivBackward0>();
- if (flags.is_executable) {
+ std::shared_ptr<DivBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<DivBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->other = other;
}
auto ret = as_variable(baseType->div(self_, other));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "div", { self }, ret );
setattr(n, jit::stringToSymbol("other"), other);
@@ -3076,14 +3565,17 @@
Tensor VariableType::s_div(const Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<DivBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<DivBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<DivBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->other_ = SavedVariable(other, nullptr);
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->s_div(self_, other_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "div", { self, other }, ret );
(void)n;
@@ -3092,16 +3584,18 @@
}
Tensor & VariableType::m_div_(Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<DivBackward0>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<DivBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<DivBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->other = other;
}
baseType->m_div_(self_, other);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "div_", { self }, self );
setattr(n, jit::stringToSymbol("other"), other);
@@ -3111,17 +3605,19 @@
Tensor & VariableType::s_m_div_(Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<DivBackward1>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<DivBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<DivBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->other_ = SavedVariable(other, nullptr);
grad_fn->self_ = SavedVariable(self.clone(), nullptr);
}
baseType->s_m_div_(self_, other_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "div_", { self, other }, self );
(void)n;
@@ -3130,10 +3626,15 @@
}
Tensor VariableType::fmod(const Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<FmodBackward0>();
+ std::shared_ptr<FmodBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<FmodBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->fmod(self_, other));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "fmod", { self }, ret );
setattr(n, jit::stringToSymbol("other"), other);
@@ -3143,13 +3644,16 @@
Tensor VariableType::s_fmod(const Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<FmodBackward1>();
- if (flags.is_executable) {
+ std::shared_ptr<FmodBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<FmodBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->other_ = SavedVariable(other, nullptr);
}
auto ret = as_variable(baseType->s_fmod(self_, other_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "fmod", { self, other }, ret );
(void)n;
@@ -3158,13 +3662,17 @@
}
Tensor & VariableType::m_fmod_(Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<FmodBackward0>();
+ check_inplace(self);
+ std::shared_ptr<FmodBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<FmodBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_fmod_(self_, other);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "fmod_", { self }, self );
setattr(n, jit::stringToSymbol("other"), other);
@@ -3174,16 +3682,18 @@
Tensor & VariableType::s_m_fmod_(Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<FmodBackward1>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<FmodBackward1> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<FmodBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->other_ = SavedVariable(other, nullptr);
}
baseType->s_m_fmod_(self_, other_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "fmod_", { self, other }, self );
(void)n;
@@ -3192,10 +3702,15 @@
}
Tensor VariableType::remainder(const Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<RemainderBackward0>();
+ std::shared_ptr<RemainderBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<RemainderBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->remainder(self_, other));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "remainder", { self }, ret );
setattr(n, jit::stringToSymbol("other"), other);
@@ -3205,25 +3720,35 @@
Tensor VariableType::s_remainder(const Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<RemainderBackward1>();
+ check_no_requires_grad(other, "other");
+ std::shared_ptr<RemainderBackward1> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<RemainderBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->s_remainder(self_, other_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self })) {
- jit::Node *n = jit::tracer::recordTrace( "remainder", { self }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ self, other })) {
+ jit::Node *n = jit::tracer::recordTrace( "remainder", { self, other }, ret );
(void)n;
}
return Tensor(std::move(ret));
}
Tensor & VariableType::m_remainder_(Tensor & self, Scalar other) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<RemainderBackward0>();
+ check_inplace(self);
+ std::shared_ptr<RemainderBackward0> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<RemainderBackward0>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_remainder_(self_, other);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "remainder_", { self }, self );
setattr(n, jit::stringToSymbol("other"), other);
@@ -3233,15 +3758,20 @@
Tensor & VariableType::s_m_remainder_(Tensor & self, const Tensor & other) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<RemainderBackward1>();
+ check_inplace(self);
+ check_no_requires_grad(other, "other");
+ std::shared_ptr<RemainderBackward1> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<RemainderBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->s_m_remainder_(self_, other_);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self })) {
- jit::Node *n = jit::tracer::recordTrace( "remainder_", { self }, self );
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
+ if (jit::tracer::isTracing({ self, other })) {
+ jit::Node *n = jit::tracer::recordTrace( "remainder_", { self, other }, self );
(void)n;
}
return self;
@@ -3261,14 +3791,17 @@
Scalar VariableType::dot(const Tensor & self, const Tensor & tensor) const {
auto& self_ = unpack(self, "self", 0);
auto& tensor_ = unpack(tensor, "tensor", 1);
- auto flags = Function::flags({ self, tensor });
- auto grad_fn = std::make_shared<DotBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<DotBackward> grad_fn;
+ auto flags = compute_flags({ self, tensor });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<DotBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, tensor });
grad_fn->tensor_ = SavedVariable(tensor, nullptr);
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->dot(self_, tensor_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, tensor })) {
jit::Node *n = jit::tracer::recordTrace( "dot", { self, tensor }, ret );
(void)n;
@@ -3277,13 +3810,16 @@
}
Tensor VariableType::tril(const Tensor & self, int64_t diagonal) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TrilBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<TrilBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TrilBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->diagonal = diagonal;
}
auto ret = as_variable(baseType->tril(self_, diagonal));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "tril", { self }, ret );
setattr(n, jit::stringToSymbol("diagonal"), diagonal);
@@ -3292,16 +3828,18 @@
}
Tensor & VariableType::m_tril_(Tensor & self, int64_t diagonal) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TrilBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<TrilBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TrilBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->diagonal = diagonal;
}
baseType->m_tril_(self_, diagonal);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "tril_", { self }, self );
setattr(n, jit::stringToSymbol("diagonal"), diagonal);
@@ -3310,13 +3848,16 @@
}
Tensor VariableType::triu(const Tensor & self, int64_t diagonal) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TriuBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<TriuBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TriuBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->diagonal = diagonal;
}
auto ret = as_variable(baseType->triu(self_, diagonal));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "triu", { self }, ret );
setattr(n, jit::stringToSymbol("diagonal"), diagonal);
@@ -3325,16 +3866,18 @@
}
Tensor & VariableType::m_triu_(Tensor & self, int64_t diagonal) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TriuBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<TriuBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TriuBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->diagonal = diagonal;
}
baseType->m_triu_(self_, diagonal);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "triu_", { self }, self );
setattr(n, jit::stringToSymbol("diagonal"), diagonal);
@@ -3344,15 +3887,18 @@
Tensor VariableType::cross(const Tensor & self, const Tensor & other, int64_t dim) const {
auto& self_ = unpack(self, "self", 0);
auto& other_ = unpack(other, "other", 1);
- auto flags = Function::flags({ self, other });
- auto grad_fn = std::make_shared<CrossBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<CrossBackward> grad_fn;
+ auto flags = compute_flags({ self, other });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<CrossBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, other });
grad_fn->other_ = SavedVariable(other, nullptr);
grad_fn->dim = dim;
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->cross(self_, other_, dim));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, other })) {
jit::Node *n = jit::tracer::recordTrace( "cross", { self, other }, ret );
setattr(n, jit::stringToSymbol("dim"), dim);
@@ -3364,13 +3910,16 @@
}
Tensor VariableType::diag(const Tensor & self, int64_t diagonal) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<DiagBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<DiagBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<DiagBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->diagonal = diagonal;
}
auto ret = as_variable(baseType->diag(self_, diagonal));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "diag", { self }, ret );
setattr(n, jit::stringToSymbol("diagonal"), diagonal);
@@ -3381,16 +3930,19 @@
auto& self_ = unpack(self, "self", 0);
auto& mat1_ = unpack(mat1, "mat1", 1);
auto& mat2_ = unpack(mat2, "mat2", 2);
- auto flags = Function::flags({ self, mat1, mat2 });
- auto grad_fn = std::make_shared<AddmmBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AddmmBackward> grad_fn;
+ auto flags = compute_flags({ self, mat1, mat2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddmmBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, mat1, mat2 });
grad_fn->beta = beta;
grad_fn->mat1_ = SavedVariable(mat1, nullptr);
grad_fn->mat2_ = SavedVariable(mat2, nullptr);
grad_fn->alpha = alpha;
}
auto ret = as_variable(baseType->s_addmm(self_, mat1_, mat2_, beta, alpha));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, mat1, mat2 })) {
jit::Node *n = jit::tracer::recordTrace( "addmm", { self, mat1, mat2 }, ret );
setattr(n, jit::stringToSymbol("beta"), beta);
@@ -3402,19 +3954,21 @@
auto& self_ = unpack(self, "self", 0);
auto& mat1_ = unpack(mat1, "mat1", 1);
auto& mat2_ = unpack(mat2, "mat2", 2);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, mat1, mat2 });
- auto grad_fn = std::make_shared<AddmmBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<AddmmBackward> grad_fn;
+ auto flags = compute_flags({ self, mat1, mat2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddmmBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, mat1, mat2 });
grad_fn->beta = beta;
grad_fn->mat1_ = SavedVariable(mat1, nullptr);
grad_fn->mat2_ = SavedVariable(mat2, nullptr);
grad_fn->alpha = alpha;
}
baseType->m_addmm_(self_, mat1_, mat2_, beta, alpha);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, mat1, mat2 })) {
jit::Node *n = jit::tracer::recordTrace( "addmm_", { self, mat1, mat2 }, self );
setattr(n, jit::stringToSymbol("beta"), beta);
@@ -3426,16 +3980,19 @@
auto& self_ = unpack(self, "self", 0);
auto& mat_ = unpack(mat, "mat", 1);
auto& vec_ = unpack(vec, "vec", 2);
- auto flags = Function::flags({ self, mat, vec });
- auto grad_fn = std::make_shared<AddmvBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AddmvBackward> grad_fn;
+ auto flags = compute_flags({ self, mat, vec });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddmvBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, mat, vec });
grad_fn->beta = beta;
grad_fn->vec_ = SavedVariable(vec, nullptr);
grad_fn->alpha = alpha;
grad_fn->mat_ = SavedVariable(mat, nullptr);
}
auto ret = as_variable(baseType->s_addmv(self_, mat_, vec_, beta, alpha));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, mat, vec })) {
jit::Node *n = jit::tracer::recordTrace( "addmv", { self, mat, vec }, ret );
setattr(n, jit::stringToSymbol("beta"), beta);
@@ -3447,19 +4004,21 @@
auto& self_ = unpack(self, "self", 0);
auto& mat_ = unpack(mat, "mat", 1);
auto& vec_ = unpack(vec, "vec", 2);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, mat, vec });
- auto grad_fn = std::make_shared<AddmvBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<AddmvBackward> grad_fn;
+ auto flags = compute_flags({ self, mat, vec });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddmvBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, mat, vec });
grad_fn->beta = beta;
grad_fn->vec_ = SavedVariable(vec, nullptr);
grad_fn->alpha = alpha;
grad_fn->mat_ = SavedVariable(mat, nullptr);
}
baseType->m_addmv_(self_, mat_, vec_, beta, alpha);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, mat, vec })) {
jit::Node *n = jit::tracer::recordTrace( "addmv_", { self, mat, vec }, self );
setattr(n, jit::stringToSymbol("beta"), beta);
@@ -3471,16 +4030,19 @@
auto& self_ = unpack(self, "self", 0);
auto& vec1_ = unpack(vec1, "vec1", 1);
auto& vec2_ = unpack(vec2, "vec2", 2);
- auto flags = Function::flags({ self, vec1, vec2 });
- auto grad_fn = std::make_shared<AddrBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AddrBackward> grad_fn;
+ auto flags = compute_flags({ self, vec1, vec2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddrBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, vec1, vec2 });
grad_fn->beta = beta;
grad_fn->vec2_ = SavedVariable(vec2, nullptr);
grad_fn->alpha = alpha;
grad_fn->vec1_ = SavedVariable(vec1, nullptr);
}
auto ret = as_variable(baseType->s_addr(self_, vec1_, vec2_, beta, alpha));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, vec1, vec2 })) {
jit::Node *n = jit::tracer::recordTrace( "addr", { self, vec1, vec2 }, ret );
setattr(n, jit::stringToSymbol("beta"), beta);
@@ -3492,19 +4054,21 @@
auto& self_ = unpack(self, "self", 0);
auto& vec1_ = unpack(vec1, "vec1", 1);
auto& vec2_ = unpack(vec2, "vec2", 2);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, vec1, vec2 });
- auto grad_fn = std::make_shared<AddrBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<AddrBackward> grad_fn;
+ auto flags = compute_flags({ self, vec1, vec2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddrBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, vec1, vec2 });
grad_fn->beta = beta;
grad_fn->vec2_ = SavedVariable(vec2, nullptr);
grad_fn->alpha = alpha;
grad_fn->vec1_ = SavedVariable(vec1, nullptr);
}
baseType->m_addr_(self_, vec1_, vec2_, beta, alpha);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, vec1, vec2 })) {
jit::Node *n = jit::tracer::recordTrace( "addr_", { self, vec1, vec2 }, self );
setattr(n, jit::stringToSymbol("beta"), beta);
@@ -3515,14 +4079,17 @@
Tensor VariableType::ger(const Tensor & self, const Tensor & vec2) const {
auto& self_ = unpack(self, "self", 0);
auto& vec2_ = unpack(vec2, "vec2", 1);
- auto flags = Function::flags({ self, vec2 });
- auto grad_fn = std::make_shared<GerBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<GerBackward> grad_fn;
+ auto flags = compute_flags({ self, vec2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GerBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, vec2 });
grad_fn->vec2_ = SavedVariable(vec2, nullptr);
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->ger(self_, vec2_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, vec2 })) {
jit::Node *n = jit::tracer::recordTrace( "ger", { self, vec2 }, ret );
(void)n;
@@ -3532,14 +4099,17 @@
Tensor VariableType::mv(const Tensor & self, const Tensor & vec) const {
auto& self_ = unpack(self, "self", 0);
auto& vec_ = unpack(vec, "vec", 1);
- auto flags = Function::flags({ self, vec });
- auto grad_fn = std::make_shared<MvBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<MvBackward> grad_fn;
+ auto flags = compute_flags({ self, vec });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MvBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, vec });
grad_fn->vec_ = SavedVariable(vec, nullptr);
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->mv(self_, vec_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, vec })) {
jit::Node *n = jit::tracer::recordTrace( "mv", { self, vec }, ret );
(void)n;
@@ -3549,14 +4119,17 @@
Tensor VariableType::mm(const Tensor & self, const Tensor & mat2) const {
auto& self_ = unpack(self, "self", 0);
auto& mat2_ = unpack(mat2, "mat2", 1);
- auto flags = Function::flags({ self, mat2 });
- auto grad_fn = std::make_shared<MmBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<MmBackward> grad_fn;
+ auto flags = compute_flags({ self, mat2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MmBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, mat2 });
grad_fn->mat2_ = SavedVariable(mat2, nullptr);
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->mm(self_, mat2_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, mat2 })) {
jit::Node *n = jit::tracer::recordTrace( "mm", { self, mat2 }, ret );
(void)n;
@@ -3566,14 +4139,17 @@
Tensor VariableType::bmm(const Tensor & self, const Tensor & mat2) const {
auto& self_ = unpack(self, "self", 0);
auto& mat2_ = unpack(mat2, "mat2", 1);
- auto flags = Function::flags({ self, mat2 });
- auto grad_fn = std::make_shared<BmmBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<BmmBackward> grad_fn;
+ auto flags = compute_flags({ self, mat2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<BmmBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, mat2 });
grad_fn->mat2_ = SavedVariable(mat2, nullptr);
grad_fn->self_ = SavedVariable(self, nullptr);
}
auto ret = as_variable(baseType->bmm(self_, mat2_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, mat2 })) {
jit::Node *n = jit::tracer::recordTrace( "bmm", { self, mat2 }, ret );
(void)n;
@@ -3584,9 +4160,12 @@
auto& self_ = unpack(self, "self", 0);
auto& batch1_ = unpack(batch1, "batch1", 1);
auto& batch2_ = unpack(batch2, "batch2", 2);
- auto flags = Function::flags({ self, batch1, batch2 });
- auto grad_fn = std::make_shared<AddbmmBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AddbmmBackward> grad_fn;
+ auto flags = compute_flags({ self, batch1, batch2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddbmmBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, batch1, batch2 });
grad_fn->beta = beta;
grad_fn->batch1_argsize_0 = batch1.size(0);
grad_fn->batch1_argsize_1 = batch1.size(1);
@@ -3596,7 +4175,7 @@
grad_fn->batch1_ = SavedVariable(batch1, nullptr);
}
auto ret = as_variable(baseType->s_addbmm(self_, batch1_, batch2_, beta, alpha));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, batch1, batch2 })) {
jit::Node *n = jit::tracer::recordTrace( "addbmm", { self, batch1, batch2 }, ret );
setattr(n, jit::stringToSymbol("beta"), beta);
@@ -3608,11 +4187,13 @@
auto& self_ = unpack(self, "self", 0);
auto& batch1_ = unpack(batch1, "batch1", 1);
auto& batch2_ = unpack(batch2, "batch2", 2);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, batch1, batch2 });
- auto grad_fn = std::make_shared<AddbmmBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<AddbmmBackward> grad_fn;
+ auto flags = compute_flags({ self, batch1, batch2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddbmmBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, batch1, batch2 });
grad_fn->beta = beta;
grad_fn->batch1_argsize_0 = batch1.size(0);
grad_fn->batch1_argsize_1 = batch1.size(1);
@@ -3622,8 +4203,8 @@
grad_fn->batch1_ = SavedVariable(batch1, nullptr);
}
baseType->m_addbmm_(self_, batch1_, batch2_, beta, alpha);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, batch1, batch2 })) {
jit::Node *n = jit::tracer::recordTrace( "addbmm_", { self, batch1, batch2 }, self );
setattr(n, jit::stringToSymbol("beta"), beta);
@@ -3635,16 +4216,19 @@
auto& self_ = unpack(self, "self", 0);
auto& batch1_ = unpack(batch1, "batch1", 1);
auto& batch2_ = unpack(batch2, "batch2", 2);
- auto flags = Function::flags({ self, batch1, batch2 });
- auto grad_fn = std::make_shared<BaddbmmBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<BaddbmmBackward> grad_fn;
+ auto flags = compute_flags({ self, batch1, batch2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<BaddbmmBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, batch1, batch2 });
grad_fn->beta = beta;
grad_fn->batch2_ = SavedVariable(batch2, nullptr);
grad_fn->alpha = alpha;
grad_fn->batch1_ = SavedVariable(batch1, nullptr);
}
auto ret = as_variable(baseType->s_baddbmm(self_, batch1_, batch2_, beta, alpha));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, batch1, batch2 })) {
jit::Node *n = jit::tracer::recordTrace( "baddbmm", { self, batch1, batch2 }, ret );
setattr(n, jit::stringToSymbol("beta"), beta);
@@ -3656,19 +4240,21 @@
auto& self_ = unpack(self, "self", 0);
auto& batch1_ = unpack(batch1, "batch1", 1);
auto& batch2_ = unpack(batch2, "batch2", 2);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, batch1, batch2 });
- auto grad_fn = std::make_shared<BaddbmmBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<BaddbmmBackward> grad_fn;
+ auto flags = compute_flags({ self, batch1, batch2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<BaddbmmBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, batch1, batch2 });
grad_fn->beta = beta;
grad_fn->batch2_ = SavedVariable(batch2, nullptr);
grad_fn->alpha = alpha;
grad_fn->batch1_ = SavedVariable(batch1, nullptr);
}
baseType->m_baddbmm_(self_, batch1_, batch2_, beta, alpha);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, batch1, batch2 })) {
jit::Node *n = jit::tracer::recordTrace( "baddbmm_", { self, batch1, batch2 }, self );
setattr(n, jit::stringToSymbol("beta"), beta);
@@ -3680,15 +4266,18 @@
auto& self_ = unpack(self, "self", 0);
auto& tensor1_ = unpack(tensor1, "tensor1", 1);
auto& tensor2_ = unpack(tensor2, "tensor2", 2);
- auto flags = Function::flags({ self, tensor1, tensor2 });
- auto grad_fn = std::make_shared<AddcmulBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AddcmulBackward> grad_fn;
+ auto flags = compute_flags({ self, tensor1, tensor2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddcmulBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, tensor1, tensor2 });
grad_fn->tensor2_ = SavedVariable(tensor2, nullptr);
grad_fn->value = value;
grad_fn->tensor1_ = SavedVariable(tensor1, nullptr);
}
auto ret = as_variable(baseType->s_addcmul(self_, tensor1_, tensor2_, value));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, tensor1, tensor2 })) {
jit::Node *n = jit::tracer::recordTrace( "addcmul", { self, tensor1, tensor2 }, ret );
setattr(n, jit::stringToSymbol("value"), value);
@@ -3699,18 +4288,20 @@
auto& self_ = unpack(self, "self", 0);
auto& tensor1_ = unpack(tensor1, "tensor1", 1);
auto& tensor2_ = unpack(tensor2, "tensor2", 2);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, tensor1, tensor2 });
- auto grad_fn = std::make_shared<AddcmulBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<AddcmulBackward> grad_fn;
+ auto flags = compute_flags({ self, tensor1, tensor2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddcmulBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, tensor1, tensor2 });
grad_fn->tensor2_ = SavedVariable(tensor2, nullptr);
grad_fn->value = value;
grad_fn->tensor1_ = SavedVariable(tensor1, nullptr);
}
baseType->s_m_addcmul_(self_, tensor1_, tensor2_, value);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, tensor1, tensor2 })) {
jit::Node *n = jit::tracer::recordTrace( "addcmul_", { self, tensor1, tensor2 }, self );
setattr(n, jit::stringToSymbol("value"), value);
@@ -3721,15 +4312,18 @@
auto& self_ = unpack(self, "self", 0);
auto& tensor1_ = unpack(tensor1, "tensor1", 1);
auto& tensor2_ = unpack(tensor2, "tensor2", 2);
- auto flags = Function::flags({ self, tensor1, tensor2 });
- auto grad_fn = std::make_shared<AddcdivBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AddcdivBackward> grad_fn;
+ auto flags = compute_flags({ self, tensor1, tensor2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddcdivBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, tensor1, tensor2 });
grad_fn->tensor2_ = SavedVariable(tensor2, nullptr);
grad_fn->value = value;
grad_fn->tensor1_ = SavedVariable(tensor1, nullptr);
}
auto ret = as_variable(baseType->s_addcdiv(self_, tensor1_, tensor2_, value));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, tensor1, tensor2 })) {
jit::Node *n = jit::tracer::recordTrace( "addcdiv", { self, tensor1, tensor2 }, ret );
setattr(n, jit::stringToSymbol("value"), value);
@@ -3740,18 +4334,20 @@
auto& self_ = unpack(self, "self", 0);
auto& tensor1_ = unpack(tensor1, "tensor1", 1);
auto& tensor2_ = unpack(tensor2, "tensor2", 2);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self, tensor1, tensor2 });
- auto grad_fn = std::make_shared<AddcdivBackward>();
- if (flags.is_executable) {
+ check_inplace(self);
+ std::shared_ptr<AddcdivBackward> grad_fn;
+ auto flags = compute_flags({ self, tensor1, tensor2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AddcdivBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, tensor1, tensor2 });
grad_fn->tensor2_ = SavedVariable(tensor2, nullptr);
grad_fn->value = value;
grad_fn->tensor1_ = SavedVariable(tensor1, nullptr);
}
baseType->s_m_addcdiv_(self_, tensor1_, tensor2_, value);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
if (jit::tracer::isTracing({ self, tensor1, tensor2 })) {
jit::Node *n = jit::tracer::recordTrace( "addcdiv_", { self, tensor1, tensor2 }, self );
setattr(n, jit::stringToSymbol("value"), value);
@@ -3761,14 +4357,17 @@
std::tuple<Tensor,Tensor> VariableType::gesv(const Tensor & self, const Tensor & A) const {
auto& self_ = unpack(self, "self", 0);
auto& A_ = unpack(A, "A", 1);
- auto flags = Function::flags({ self, A });
- auto grad_fn = std::make_shared<GesvBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<GesvBackward> grad_fn;
+ auto flags = compute_flags({ self, A });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GesvBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, A });
grad_fn->A_ = SavedVariable(A, nullptr);
}
auto ret = as_variable(baseType->gesv(self_, A_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(std::get<0>(ret), flags, grad_fn);
+ if (grad_fn) {
auto& solution = std::get<0>(ret);
grad_fn->solution_ = SavedVariable(solution, grad_fn.get());
}
@@ -3781,10 +4380,15 @@
std::tuple<Tensor,Tensor> VariableType::gels(const Tensor & self, const Tensor & A) const {
auto& self_ = unpack(self, "self", 0);
auto& A_ = unpack(A, "A", 1);
- auto flags = Function::flags({ self, A });
- auto grad_fn = std::make_shared<GelsBackward>();
+ std::shared_ptr<GelsBackward> grad_fn;
+ auto flags = compute_flags({ self, A });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GelsBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, A });
+ }
auto ret = as_variable(baseType->gels(self_, A_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ self, A })) {
jit::Node *n = jit::tracer::recordTrace( "gels", { self, A }, ret );
(void)n;
@@ -3794,12 +4398,18 @@
std::tuple<Tensor,Tensor> VariableType::trtrs(const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular) const {
auto& self_ = unpack(self, "self", 0);
auto& A_ = unpack(A, "A", 1);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<TrtrsBackward>();
+ check_no_requires_grad(A, "A");
+ std::shared_ptr<TrtrsBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TrtrsBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->trtrs(self_, A_, upper, transpose, unitriangular));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self })) {
- jit::Node *n = jit::tracer::recordTrace( "trtrs", { self }, ret );
+ set_flags(std::get<0>(ret), flags, grad_fn);
+ if (jit::tracer::isTracing({ self, A })) {
+ jit::Node *n = jit::tracer::recordTrace( "trtrs", { self, A }, ret );
setattr(n, jit::stringToSymbol("upper"), upper);
setattr(n, jit::stringToSymbol("transpose"), transpose);
setattr(n, jit::stringToSymbol("unitriangular"), unitriangular);
@@ -3808,10 +4418,15 @@
}
std::tuple<Tensor,Tensor> VariableType::symeig(const Tensor & self, bool eigenvectors, bool upper) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SymeigBackward>();
+ std::shared_ptr<SymeigBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SymeigBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->symeig(self_, eigenvectors, upper));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "symeig", { self }, ret );
setattr(n, jit::stringToSymbol("eigenvectors"), eigenvectors);
@@ -3821,10 +4436,15 @@
}
std::tuple<Tensor,Tensor> VariableType::eig(const Tensor & self, bool eigenvectors) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<EigBackward>();
+ std::shared_ptr<EigBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<EigBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->eig(self_, eigenvectors));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "eig", { self }, ret );
setattr(n, jit::stringToSymbol("eigenvectors"), eigenvectors);
@@ -3833,10 +4453,15 @@
}
std::tuple<Tensor,Tensor,Tensor> VariableType::svd(const Tensor & self, bool some) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<SvdBackward>();
+ std::shared_ptr<SvdBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SvdBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->svd(self_, some));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "svd", { self }, ret );
setattr(n, jit::stringToSymbol("some"), some);
@@ -3845,11 +4470,16 @@
}
Tensor VariableType::inverse(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<InverseBackward>();
+ std::shared_ptr<InverseBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<InverseBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->inverse(self_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& output = ret;
grad_fn->output_ = SavedVariable(output, grad_fn.get());
}
@@ -3861,14 +4491,17 @@
}
Tensor VariableType::potrf(const Tensor & self, bool upper) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<PotrfBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<PotrfBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<PotrfBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->upper = upper;
}
auto ret = as_variable(baseType->potrf(self_, upper));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& output = ret;
grad_fn->output_ = SavedVariable(output, grad_fn.get());
}
@@ -3881,10 +4514,15 @@
Tensor VariableType::potrs(const Tensor & self, const Tensor & input2, bool upper) const {
auto& self_ = unpack(self, "self", 0);
auto& input2_ = unpack(input2, "input2", 1);
- auto flags = Function::flags({ self, input2 });
- auto grad_fn = std::make_shared<PotrsBackward>();
+ std::shared_ptr<PotrsBackward> grad_fn;
+ auto flags = compute_flags({ self, input2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<PotrsBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, input2 });
+ }
auto ret = as_variable(baseType->potrs(self_, input2_, upper));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, input2 })) {
jit::Node *n = jit::tracer::recordTrace( "potrs", { self, input2 }, ret );
setattr(n, jit::stringToSymbol("upper"), upper);
@@ -3893,10 +4531,15 @@
}
Tensor VariableType::potri(const Tensor & self, bool upper) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<PotriBackward>();
+ std::shared_ptr<PotriBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<PotriBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->potri(self_, upper));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "potri", { self }, ret );
setattr(n, jit::stringToSymbol("upper"), upper);
@@ -3905,10 +4548,15 @@
}
std::tuple<Tensor,Tensor> VariableType::pstrf(const Tensor & self, bool upper, Scalar tol) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<PstrfBackward>();
+ std::shared_ptr<PstrfBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<PstrfBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->pstrf(self_, upper, tol));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "pstrf", { self }, ret );
setattr(n, jit::stringToSymbol("upper"), upper);
@@ -3918,10 +4566,15 @@
}
std::tuple<Tensor,Tensor> VariableType::qr(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<QrBackward>();
+ std::shared_ptr<QrBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<QrBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->qr(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "qr", { self }, ret );
(void)n;
@@ -3930,10 +4583,15 @@
}
std::tuple<Tensor,Tensor> VariableType::geqrf(const Tensor & self) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<GeqrfBackward>();
+ std::shared_ptr<GeqrfBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GeqrfBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->geqrf(self_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "geqrf", { self }, ret );
(void)n;
@@ -3943,10 +4601,15 @@
Tensor VariableType::orgqr(const Tensor & self, const Tensor & input2) const {
auto& self_ = unpack(self, "self", 0);
auto& input2_ = unpack(input2, "input2", 1);
- auto flags = Function::flags({ self, input2 });
- auto grad_fn = std::make_shared<OrgqrBackward>();
+ std::shared_ptr<OrgqrBackward> grad_fn;
+ auto flags = compute_flags({ self, input2 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<OrgqrBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, input2 });
+ }
auto ret = as_variable(baseType->orgqr(self_, input2_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, input2 })) {
jit::Node *n = jit::tracer::recordTrace( "orgqr", { self, input2 }, ret );
(void)n;
@@ -3957,10 +4620,15 @@
auto& self_ = unpack(self, "self", 0);
auto& input2_ = unpack(input2, "input2", 1);
auto& input3_ = unpack(input3, "input3", 2);
- auto flags = Function::flags({ self, input2, input3 });
- auto grad_fn = std::make_shared<OrmqrBackward>();
+ std::shared_ptr<OrmqrBackward> grad_fn;
+ auto flags = compute_flags({ self, input2, input3 });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<OrmqrBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self, input2, input3 });
+ }
auto ret = as_variable(baseType->ormqr(self_, input2_, input3_, left, transpose));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self, input2, input3 })) {
jit::Node *n = jit::tracer::recordTrace( "ormqr", { self, input2, input3 }, ret );
setattr(n, jit::stringToSymbol("left"), left);
@@ -3971,12 +4639,18 @@
std::tuple<Tensor,Tensor> VariableType::btrifact(const Tensor & self, const Tensor & info, bool pivot) const {
auto& self_ = unpack(self, "self", 0);
auto& info_ = unpack(info, "info", 1);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<BtrifactBackward>();
+ check_no_requires_grad(info, "info");
+ std::shared_ptr<BtrifactBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<BtrifactBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->btrifact(self_, info_, pivot));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self })) {
- jit::Node *n = jit::tracer::recordTrace( "btrifact", { self }, ret );
+ set_flags(std::get<0>(ret), flags, grad_fn);
+ if (jit::tracer::isTracing({ self, info })) {
+ jit::Node *n = jit::tracer::recordTrace( "btrifact", { self, info }, ret );
setattr(n, jit::stringToSymbol("pivot"), pivot);
}
return std::tuple<Tensor,Tensor>(std::move(ret));
@@ -3985,12 +4659,19 @@
auto& self_ = unpack(self, "self", 0);
auto& LU_data_ = unpack(LU_data, "LU_data", 1);
auto& LU_pivots_ = unpack(LU_pivots, "LU_pivots", 2);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<BtrisolveBackward>();
+ check_no_requires_grad(LU_data, "LU_data");
+ check_no_requires_grad(LU_pivots, "LU_pivots");
+ std::shared_ptr<BtrisolveBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<BtrisolveBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
auto ret = as_variable(baseType->btrisolve(self_, LU_data_, LU_pivots_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ self })) {
- jit::Node *n = jit::tracer::recordTrace( "btrisolve", { self }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ self, LU_data, LU_pivots })) {
+ jit::Node *n = jit::tracer::recordTrace( "btrisolve", { self, LU_data, LU_pivots }, ret );
(void)n;
}
return Tensor(std::move(ret));
@@ -4030,13 +4711,17 @@
}
Tensor & VariableType::m_log_normal_(Tensor & self, double mean, double std, Generator * generator) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<LogNormalBackward>();
+ check_inplace(self);
+ std::shared_ptr<LogNormalBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LogNormalBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_log_normal_(self_, mean, std, generator);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
return self;
}
Tensor & VariableType::m_exponential_(Tensor & self, double lambd, Generator * generator) const {
@@ -4050,13 +4735,17 @@
}
Tensor & VariableType::m_geometric_(Tensor & self, double p, Generator * generator) const {
auto& self_ = unpack(self, "self", 0);
- auto& pImpl = static_cast<VariableImpl&>(*self.get());
- check_inplace(pImpl);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<GeometricBackward>();
+ check_inplace(self);
+ std::shared_ptr<GeometricBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GeometricBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
+ }
baseType->m_geometric_(self_, p, generator);
- pImpl.version_counter.increment();
- wrap_output(self, std::move(flags), grad_fn);
+ increment_version(self);
+ set_flags(static_cast<Variable&>(self), flags, grad_fn);
return self;
}
int64_t VariableType::m_size(const Tensor & self, int64_t dim) const {
@@ -4084,15 +4773,18 @@
}
Tensor VariableType::_unnarrow(const Tensor & self, int64_t dimension, int64_t offset, int64_t dimSize) const {
auto& self_ = unpack(self, "self", 0);
- auto flags = Function::flags({ self });
- auto grad_fn = std::make_shared<UnnarrowBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<UnnarrowBackward> grad_fn;
+ auto flags = compute_flags({ self });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<UnnarrowBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ self });
grad_fn->self_argsize_dimension = self.size(dimension);
grad_fn->dimension = dimension;
grad_fn->offset = offset;
}
auto ret = as_variable(baseType->_unnarrow(self_, dimension, offset, dimSize));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ self })) {
jit::Node *n = jit::tracer::recordTrace( "_unnarrow", { self }, ret );
setattr(n, jit::stringToSymbol("dimension"), dimension);
@@ -4106,14 +4798,17 @@
}
Tensor VariableType::cat(TensorList tensors, int64_t dim) const {
auto tensors_ = unpack(tensors, "tensors", 0);
- auto flags = Function::flags({ tensors });
- auto grad_fn = std::make_shared<CatBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<CatBackward> grad_fn;
+ auto flags = compute_flags({ tensors });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<CatBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ tensors });
grad_fn->tensors_argsizes_dim = to_arg_sizes(tensors, dim);
grad_fn->dim = dim;
}
auto ret = as_variable(baseType->cat(tensors_, dim));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ tensors })) {
jit::Node *n = jit::tracer::recordTrace( "cat", { tensors }, ret );
setattr(n, jit::stringToSymbol("dim"), dim);
@@ -4124,18 +4819,23 @@
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack(target, "target", 1);
auto weight_ = unpack_opt(weight, "weight", 2);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<BinaryCrossEntropyBackward>();
- if (flags.is_executable) {
+ check_no_requires_grad(target, "target");
+ check_no_requires_grad(weight, "weight");
+ std::shared_ptr<BinaryCrossEntropyBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<BinaryCrossEntropyBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->target_ = SavedVariable(target, nullptr);
grad_fn->weight_ = SavedVariable(weight, nullptr);
grad_fn->size_average = size_average;
}
auto ret = as_variable(baseType->binary_cross_entropy_forward(input_, target_, weight_, size_average));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "binary_cross_entropy", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target, weight })) {
+ jit::Node *n = jit::tracer::recordTrace( "binary_cross_entropy", { input, target, weight }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
}
return Tensor(std::move(ret));
@@ -4144,30 +4844,39 @@
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack(target, "target", 1);
auto weight_ = unpack_opt(weight, "weight", 2);
- auto flags = Function::flags({ input, target, weight });
- auto grad_fn = std::make_shared<Error>("binary_cross_entropy_backward is not differentiable");
+ std::shared_ptr<Error> grad_fn;
+ auto flags = compute_flags({ input, target, weight });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Error>("the derivative for binary_cross_entropy_backward is not implemented");
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input, target, weight });
+ }
auto ret = as_variable(baseType->binary_cross_entropy_backward(input_, target_, weight_, size_average));
- wrap_output(ret, std::move(flags), std::move(grad_fn));
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ input, target, weight })) {
jit::Node *n = jit::tracer::recordTrace( "binary_cross_entropy_backward", { input, target, weight }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
}
- return ret;
+ return Tensor(std::move(ret));
}
Tensor VariableType::kl_div(const Tensor & input, const Tensor & target, bool size_average) const {
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack(target, "target", 1);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<KlDivBackward>();
- if (flags.is_executable) {
+ check_no_requires_grad(target, "target");
+ std::shared_ptr<KlDivBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<KlDivBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->target_ = SavedVariable(target, nullptr);
grad_fn->size_average = size_average;
}
auto ret = as_variable(baseType->kl_div_forward(input_, target_, size_average));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "kl_div", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target })) {
+ jit::Node *n = jit::tracer::recordTrace( "kl_div", { input, target }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
}
return Tensor(std::move(ret));
@@ -4175,12 +4884,18 @@
Tensor VariableType::kl_div_backward(const Tensor & input, const Tensor & target, bool size_average) const {
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack(target, "target", 1);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<KlDivBackwardBackward>();
+ check_no_requires_grad(target, "target");
+ std::shared_ptr<KlDivBackwardBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<KlDivBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
+ }
auto ret = as_variable(baseType->kl_div_backward(input_, target_, size_average));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "kl_div_backward", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target })) {
+ jit::Node *n = jit::tracer::recordTrace( "kl_div_backward", { input, target }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
}
return Tensor(std::move(ret));
@@ -4188,17 +4903,21 @@
Tensor VariableType::l1_loss(const Tensor & input, const Tensor & target, bool size_average) const {
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack(target, "target", 1);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<L1LossBackward>();
- if (flags.is_executable) {
+ check_no_requires_grad(target, "target");
+ std::shared_ptr<L1LossBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<L1LossBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->target_ = SavedVariable(target, nullptr);
grad_fn->size_average = size_average;
}
auto ret = as_variable(baseType->l1_loss_forward(input_, target_, size_average));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "l1_loss", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target })) {
+ jit::Node *n = jit::tracer::recordTrace( "l1_loss", { input, target }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
}
return Tensor(std::move(ret));
@@ -4206,12 +4925,18 @@
Tensor VariableType::l1_loss_backward(const Tensor & input, const Tensor & target, bool size_average) const {
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack(target, "target", 1);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<L1LossBackwardBackward>();
+ check_no_requires_grad(target, "target");
+ std::shared_ptr<L1LossBackwardBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<L1LossBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
+ }
auto ret = as_variable(baseType->l1_loss_backward(input_, target_, size_average));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "l1_loss_backward", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target })) {
+ jit::Node *n = jit::tracer::recordTrace( "l1_loss_backward", { input, target }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
}
return Tensor(std::move(ret));
@@ -4219,18 +4944,22 @@
Tensor VariableType::mse_loss(const Tensor & input, const Tensor & target, bool size_average, bool reduce) const {
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack(target, "target", 1);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<MseLossBackward>();
- if (flags.is_executable) {
+ check_no_requires_grad(target, "target");
+ std::shared_ptr<MseLossBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MseLossBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->target_ = SavedVariable(target, nullptr);
grad_fn->size_average = size_average;
grad_fn->reduce = reduce;
}
auto ret = as_variable(baseType->mse_loss_forward(input_, target_, size_average, reduce));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "mse_loss", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target })) {
+ jit::Node *n = jit::tracer::recordTrace( "mse_loss", { input, target }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
setattr(n, jit::stringToSymbol("reduce"), reduce);
}
@@ -4240,9 +4969,13 @@
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
auto& target_ = unpack(target, "target", 2);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<MseLossBackwardBackward>();
- if (flags.is_executable) {
+ check_no_requires_grad(target, "target");
+ std::shared_ptr<MseLossBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MseLossBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr);
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->target_ = SavedVariable(target, nullptr);
@@ -4250,9 +4983,9 @@
grad_fn->reduce = reduce;
}
auto ret = as_variable(baseType->mse_loss_backward(grad_output_, input_, target_, size_average, reduce));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ grad_output, input })) {
- jit::Node *n = jit::tracer::recordTrace( "mse_loss_backward", { grad_output, input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ grad_output, input, target })) {
+ jit::Node *n = jit::tracer::recordTrace( "mse_loss_backward", { grad_output, input, target }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
setattr(n, jit::stringToSymbol("reduce"), reduce);
}
@@ -4262,9 +4995,13 @@
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack_long(target, "target", 1);
auto weight_ = unpack_opt(weight, "weight", 4);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<MultiMarginLossBackward>();
- if (flags.is_executable) {
+ check_no_requires_grad(weight, "weight");
+ std::shared_ptr<MultiMarginLossBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MultiMarginLossBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->target_ = SavedVariable(target, nullptr);
grad_fn->p = p;
@@ -4273,9 +5010,9 @@
grad_fn->size_average = size_average;
}
auto ret = as_variable(baseType->multi_margin_loss_forward(input_, target_, p, margin, weight_, size_average));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "multi_margin_loss", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target, weight })) {
+ jit::Node *n = jit::tracer::recordTrace( "multi_margin_loss", { input, target, weight }, ret );
setattr(n, jit::stringToSymbol("p"), p);
setattr(n, jit::stringToSymbol("margin"), margin);
setattr(n, jit::stringToSymbol("size_average"), size_average);
@@ -4286,35 +5023,43 @@
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack_long(target, "target", 1);
auto weight_ = unpack_opt(weight, "weight", 4);
- auto flags = Function::flags({ input, target, weight });
- auto grad_fn = std::make_shared<Error>("multi_margin_loss_backward is not differentiable");
+ std::shared_ptr<Error> grad_fn;
+ auto flags = compute_flags({ input, target, weight });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Error>("the derivative for multi_margin_loss_backward is not implemented");
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input, target, weight });
+ }
auto ret = as_variable(baseType->multi_margin_loss_backward(input_, target_, p, margin, weight_, size_average));
- wrap_output(ret, std::move(flags), std::move(grad_fn));
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ input, target, weight })) {
jit::Node *n = jit::tracer::recordTrace( "multi_margin_loss_backward", { input, target, weight }, ret );
setattr(n, jit::stringToSymbol("p"), p);
setattr(n, jit::stringToSymbol("margin"), margin);
setattr(n, jit::stringToSymbol("size_average"), size_average);
}
- return ret;
+ return Tensor(std::move(ret));
}
Tensor VariableType::multilabel_margin_loss(const Tensor & input, const Tensor & target, bool size_average) const {
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack_long(target, "target", 1);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<MultilabelMarginLossBackward>();
auto is_target = tensor();
auto& is_target_ = static_cast<VariableImpl*>(is_target.get())->data;
- if (flags.is_executable) {
+ std::shared_ptr<MultilabelMarginLossBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MultilabelMarginLossBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->target_ = SavedVariable(target, nullptr);
grad_fn->size_average = size_average;
grad_fn->is_target_ = SavedVariable(is_target, nullptr);
}
auto ret = as_variable(baseType->multilabel_margin_loss_forward(input_, target_, size_average, is_target_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "multilabel_margin_loss", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target })) {
+ jit::Node *n = jit::tracer::recordTrace( "multilabel_margin_loss", { input, target }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
}
return Tensor(std::move(ret));
@@ -4323,25 +5068,34 @@
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack_long(target, "target", 1);
auto& is_target_ = unpack(is_target, "is_target", 3);
- auto flags = Function::flags({ input, target, is_target });
- auto grad_fn = std::make_shared<Error>("multilabel_margin_loss_backward is not differentiable");
+ std::shared_ptr<Error> grad_fn;
+ auto flags = compute_flags({ input, target, is_target });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Error>("the derivative for multilabel_margin_loss_backward is not implemented");
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input, target, is_target });
+ }
auto ret = as_variable(baseType->multilabel_margin_loss_backward(input_, target_, size_average, is_target_));
- wrap_output(ret, std::move(flags), std::move(grad_fn));
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ input, target, is_target })) {
jit::Node *n = jit::tracer::recordTrace( "multilabel_margin_loss_backward", { input, target, is_target }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
}
- return ret;
+ return Tensor(std::move(ret));
}
Tensor VariableType::nll_loss(const Tensor & input, const Tensor & target, const Tensor & weight, bool size_average, int64_t ignore_index) const {
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack_long(target, "target", 1);
auto weight_ = unpack_opt(weight, "weight", 2);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<NllLossBackward>();
auto total_weight = tensor();
auto& total_weight_ = static_cast<VariableImpl*>(total_weight.get())->data;
- if (flags.is_executable) {
+ check_no_requires_grad(weight, "weight");
+ std::shared_ptr<NllLossBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<NllLossBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->target_ = SavedVariable(target, nullptr);
grad_fn->weight_ = SavedVariable(weight, nullptr);
@@ -4350,9 +5104,9 @@
grad_fn->total_weight_ = SavedVariable(total_weight, nullptr);
}
auto ret = as_variable(baseType->nll_loss_forward(input_, target_, weight_, size_average, ignore_index, total_weight_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "nll_loss", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target, weight })) {
+ jit::Node *n = jit::tracer::recordTrace( "nll_loss", { input, target, weight }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
setattr(n, jit::stringToSymbol("ignore_index"), ignore_index);
}
@@ -4363,12 +5117,19 @@
auto& target_ = unpack_long(target, "target", 1);
auto weight_ = unpack_opt(weight, "weight", 2);
auto& total_weight_ = unpack(total_weight, "total_weight", 5);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<NllLossBackwardBackward>();
+ check_no_requires_grad(weight, "weight");
+ check_no_requires_grad(total_weight, "total_weight");
+ std::shared_ptr<NllLossBackwardBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<NllLossBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
+ }
auto ret = as_variable(baseType->nll_loss_backward(input_, target_, weight_, size_average, ignore_index, total_weight_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "nll_loss_backward", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target, weight, total_weight })) {
+ jit::Node *n = jit::tracer::recordTrace( "nll_loss_backward", { input, target, weight, total_weight }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
setattr(n, jit::stringToSymbol("ignore_index"), ignore_index);
}
@@ -4378,11 +5139,15 @@
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack_long(target, "target", 1);
auto weight_ = unpack_opt(weight, "weight", 2);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<NllLoss2DBackward>();
auto total_weight = tensor();
auto& total_weight_ = static_cast<VariableImpl*>(total_weight.get())->data;
- if (flags.is_executable) {
+ check_no_requires_grad(weight, "weight");
+ std::shared_ptr<NllLoss2DBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<NllLoss2DBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->target_ = SavedVariable(target, nullptr);
grad_fn->weight_ = SavedVariable(weight, nullptr);
@@ -4391,9 +5156,9 @@
grad_fn->total_weight_ = SavedVariable(total_weight, nullptr);
}
auto ret = as_variable(baseType->nll_loss2d_forward(input_, target_, weight_, size_average, ignore_index, total_weight_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "nll_loss2d", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target, weight })) {
+ jit::Node *n = jit::tracer::recordTrace( "nll_loss2d", { input, target, weight }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
setattr(n, jit::stringToSymbol("ignore_index"), ignore_index);
}
@@ -4404,12 +5169,19 @@
auto& target_ = unpack_long(target, "target", 1);
auto weight_ = unpack_opt(weight, "weight", 2);
auto& total_weight_ = unpack(total_weight, "total_weight", 5);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<NllLoss2DBackwardBackward>();
+ check_no_requires_grad(weight, "weight");
+ check_no_requires_grad(total_weight, "total_weight");
+ std::shared_ptr<NllLoss2DBackwardBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<NllLoss2DBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
+ }
auto ret = as_variable(baseType->nll_loss2d_backward(input_, target_, weight_, size_average, ignore_index, total_weight_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "nll_loss2d_backward", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target, weight, total_weight })) {
+ jit::Node *n = jit::tracer::recordTrace( "nll_loss2d_backward", { input, target, weight, total_weight }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
setattr(n, jit::stringToSymbol("ignore_index"), ignore_index);
}
@@ -4418,17 +5190,21 @@
Tensor VariableType::smooth_l1_loss(const Tensor & input, const Tensor & target, bool size_average) const {
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack(target, "target", 1);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<SmoothL1LossBackward>();
- if (flags.is_executable) {
+ check_no_requires_grad(target, "target");
+ std::shared_ptr<SmoothL1LossBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SmoothL1LossBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->target_ = SavedVariable(target, nullptr);
grad_fn->size_average = size_average;
}
auto ret = as_variable(baseType->smooth_l1_loss_forward(input_, target_, size_average));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "smooth_l1_loss", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target })) {
+ jit::Node *n = jit::tracer::recordTrace( "smooth_l1_loss", { input, target }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
}
return Tensor(std::move(ret));
@@ -4436,17 +5212,21 @@
Tensor VariableType::smooth_l1_loss_backward(const Tensor & input, const Tensor & target, bool size_average) const {
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack(target, "target", 1);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<SmoothL1LossBackwardBackward>();
- if (flags.is_executable) {
+ check_no_requires_grad(target, "target");
+ std::shared_ptr<SmoothL1LossBackwardBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SmoothL1LossBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->target_ = SavedVariable(target, nullptr);
grad_fn->size_average = size_average;
}
auto ret = as_variable(baseType->smooth_l1_loss_backward(input_, target_, size_average));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "smooth_l1_loss_backward", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target })) {
+ jit::Node *n = jit::tracer::recordTrace( "smooth_l1_loss_backward", { input, target }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
}
return Tensor(std::move(ret));
@@ -4454,17 +5234,21 @@
Tensor VariableType::soft_margin_loss(const Tensor & input, const Tensor & target, bool size_average) const {
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack(target, "target", 1);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<SoftMarginLossBackward>();
- if (flags.is_executable) {
+ check_no_requires_grad(target, "target");
+ std::shared_ptr<SoftMarginLossBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SoftMarginLossBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->target_ = SavedVariable(target, nullptr);
grad_fn->size_average = size_average;
}
auto ret = as_variable(baseType->soft_margin_loss_forward(input_, target_, size_average));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "soft_margin_loss", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target })) {
+ jit::Node *n = jit::tracer::recordTrace( "soft_margin_loss", { input, target }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
}
return Tensor(std::move(ret));
@@ -4472,32 +5256,40 @@
Tensor VariableType::soft_margin_loss_backward(const Tensor & input, const Tensor & target, bool size_average) const {
auto& input_ = unpack(input, "input", 0);
auto& target_ = unpack(target, "target", 1);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<SoftMarginLossBackwardBackward>();
- if (flags.is_executable) {
+ check_no_requires_grad(target, "target");
+ std::shared_ptr<SoftMarginLossBackwardBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SoftMarginLossBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->target_ = SavedVariable(target, nullptr);
grad_fn->size_average = size_average;
}
auto ret = as_variable(baseType->soft_margin_loss_backward(input_, target_, size_average));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "soft_margin_loss_backward", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, target })) {
+ jit::Node *n = jit::tracer::recordTrace( "soft_margin_loss_backward", { input, target }, ret );
setattr(n, jit::stringToSymbol("size_average"), size_average);
}
return Tensor(std::move(ret));
}
Tensor VariableType::elu(const Tensor & input, Scalar alpha, bool inplace) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<EluBackward>();
- if (flags.is_executable) {
+ if (inplace) check_inplace(input);
+ std::shared_ptr<EluBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<EluBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->alpha = alpha;
}
auto ret = as_variable(baseType->elu_forward(input_, alpha, inplace));
if (inplace) increment_version(input);
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& output = ret;
grad_fn->output_ = SavedVariable(output, grad_fn.get());
grad_fn->input_ = SavedVariable(input, grad_fn.get());
@@ -4513,23 +5305,27 @@
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
auto& output_ = unpack(output, "output", 4);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<EluBackwardBackward>();
- if (flags.is_executable) {
+ if (inplace) check_inplace(input);
+ std::shared_ptr<EluBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<EluBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->alpha = alpha;
grad_fn->inplace = inplace;
grad_fn->output_ = SavedVariable(output, nullptr);
}
auto ret = as_variable(baseType->elu_backward(grad_output_, input_, alpha, inplace, output_));
if (inplace) increment_version(input);
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
grad_fn->input_ = SavedVariable(input, grad_fn.get());
auto& grad_input = ret;
grad_fn->grad_input_ = SavedVariable(grad_input, grad_fn.get());
}
- if (jit::tracer::isTracing({ grad_output, input })) {
- jit::Node *n = jit::tracer::recordTrace( "elu_backward", { grad_output, input }, ret );
+ if (jit::tracer::isTracing({ grad_output, input, output })) {
+ jit::Node *n = jit::tracer::recordTrace( "elu_backward", { grad_output, input, output }, ret );
setattr(n, jit::stringToSymbol("alpha"), alpha);
setattr(n, jit::stringToSymbol("inplace"), inplace);
}
@@ -4537,14 +5333,17 @@
}
Tensor VariableType::glu(const Tensor & input, int64_t dim) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<GluBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<GluBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GluBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->dim = dim;
}
auto ret = as_variable(baseType->glu_forward(input_, dim));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ input })) {
jit::Node *n = jit::tracer::recordTrace( "glu", { input }, ret );
setattr(n, jit::stringToSymbol("dim"), dim);
@@ -4554,15 +5353,18 @@
Tensor VariableType::glu_backward(const Tensor & grad_output, const Tensor & input, int64_t dim) const {
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<GluBackwardBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<GluBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<GluBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->dim = dim;
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr);
}
auto ret = as_variable(baseType->glu_backward(grad_output_, input_, dim));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input })) {
jit::Node *n = jit::tracer::recordTrace( "glu_backward", { grad_output, input }, ret );
setattr(n, jit::stringToSymbol("dim"), dim);
@@ -4571,14 +5373,17 @@
}
Tensor VariableType::hardshrink(const Tensor & input, Scalar lambd) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<HardshrinkBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<HardshrinkBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<HardshrinkBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->lambd = lambd;
}
auto ret = as_variable(baseType->hardshrink_forward(input_, lambd));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ input })) {
jit::Node *n = jit::tracer::recordTrace( "hardshrink", { input }, ret );
setattr(n, jit::stringToSymbol("lambd"), lambd);
@@ -4588,14 +5393,17 @@
Tensor VariableType::hardshrink_backward(const Tensor & grad_output, const Tensor & input, Scalar lambd) const {
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<HardshrinkBackwardBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<HardshrinkBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<HardshrinkBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->lambd = lambd;
}
auto ret = as_variable(baseType->hardshrink_backward(grad_output_, input_, lambd));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input })) {
jit::Node *n = jit::tracer::recordTrace( "hardshrink_backward", { grad_output, input }, ret );
setattr(n, jit::stringToSymbol("lambd"), lambd);
@@ -4604,16 +5412,20 @@
}
Tensor VariableType::hardtanh(const Tensor & input, Scalar min_val, Scalar max_val, bool inplace) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<HardtanhBackward>();
- if (flags.is_executable) {
+ if (inplace) check_inplace(input);
+ std::shared_ptr<HardtanhBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<HardtanhBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->min_val = min_val;
grad_fn->max_val = max_val;
}
auto ret = as_variable(baseType->hardtanh_forward(input_, min_val, max_val, inplace));
if (inplace) increment_version(input);
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
grad_fn->input_ = SavedVariable(input, grad_fn.get());
}
if (jit::tracer::isTracing({ input })) {
@@ -4627,16 +5439,20 @@
Tensor VariableType::hardtanh_backward(const Tensor & grad_output, const Tensor & input, Scalar min_val, Scalar max_val, bool inplace) const {
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<HardtanhBackwardBackward>();
- if (flags.is_executable) {
+ if (inplace) check_inplace(input);
+ std::shared_ptr<HardtanhBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<HardtanhBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->min_val = min_val;
grad_fn->max_val = max_val;
}
auto ret = as_variable(baseType->hardtanh_backward(grad_output_, input_, min_val, max_val, inplace));
if (inplace) increment_version(input);
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
grad_fn->input_ = SavedVariable(input, grad_fn.get());
}
if (jit::tracer::isTracing({ grad_output, input })) {
@@ -4649,15 +5465,19 @@
}
Tensor VariableType::leaky_relu(const Tensor & input, Scalar negative_slope, bool inplace) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<LeakyReluBackward>();
- if (flags.is_executable) {
+ if (inplace) check_inplace(input);
+ std::shared_ptr<LeakyReluBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LeakyReluBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->negative_slope = negative_slope;
}
auto ret = as_variable(baseType->leaky_relu_forward(input_, negative_slope, inplace));
if (inplace) increment_version(input);
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
grad_fn->input_ = SavedVariable(input, grad_fn.get());
}
if (jit::tracer::isTracing({ input })) {
@@ -4670,15 +5490,19 @@
Tensor VariableType::leaky_relu_backward(const Tensor & grad_output, const Tensor & input, Scalar negative_slope, bool inplace) const {
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<LeakyReluBackwardBackward>();
- if (flags.is_executable) {
+ if (inplace) check_inplace(input);
+ std::shared_ptr<LeakyReluBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LeakyReluBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->negative_slope = negative_slope;
}
auto ret = as_variable(baseType->leaky_relu_backward(grad_output_, input_, negative_slope, inplace));
if (inplace) increment_version(input);
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
grad_fn->input_ = SavedVariable(input, grad_fn.get());
}
if (jit::tracer::isTracing({ grad_output, input })) {
@@ -4690,16 +5514,19 @@
}
Tensor VariableType::log_sigmoid(const Tensor & input) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<LogSigmoidBackward>();
auto buffer = tensor();
auto& buffer_ = static_cast<VariableImpl*>(buffer.get())->data;
- if (flags.is_executable) {
+ std::shared_ptr<LogSigmoidBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LogSigmoidBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->buffer_ = SavedVariable(buffer, nullptr);
}
auto ret = as_variable(baseType->log_sigmoid_forward(input_, buffer_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ input })) {
jit::Node *n = jit::tracer::recordTrace( "log_sigmoid", { input }, ret );
(void)n;
@@ -4710,32 +5537,39 @@
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
auto& buffer_ = unpack(buffer, "buffer", 2);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<LogSigmoidBackwardBackward>();
- if (flags.is_executable) {
+ check_no_requires_grad(buffer, "buffer");
+ std::shared_ptr<LogSigmoidBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LogSigmoidBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->buffer_ = SavedVariable(buffer, nullptr);
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr);
}
auto ret = as_variable(baseType->log_sigmoid_backward(grad_output_, input_, buffer_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ grad_output, input })) {
- jit::Node *n = jit::tracer::recordTrace( "log_sigmoid_backward", { grad_output, input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ grad_output, input, buffer })) {
+ jit::Node *n = jit::tracer::recordTrace( "log_sigmoid_backward", { grad_output, input, buffer }, ret );
(void)n;
}
return Tensor(std::move(ret));
}
Tensor VariableType::log_softmax(const Tensor & input, int64_t dim) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<LogSoftmaxBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<LogSoftmaxBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LogSoftmaxBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->dim = dim;
}
auto ret = as_variable(baseType->log_softmax_forward(input_, dim));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& output = ret;
grad_fn->output_ = SavedVariable(output, grad_fn.get());
}
@@ -4749,17 +5583,20 @@
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
auto& output_ = unpack(output, "output", 3);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<LogSoftmaxBackwardBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<LogSoftmaxBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<LogSoftmaxBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->dim = dim;
grad_fn->output_ = SavedVariable(output, nullptr);
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr);
}
auto ret = as_variable(baseType->log_softmax_backward(grad_output_, input_, dim, output_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ grad_output, input })) {
- jit::Node *n = jit::tracer::recordTrace( "log_softmax_backward", { grad_output, input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ grad_output, input, output })) {
+ jit::Node *n = jit::tracer::recordTrace( "log_softmax_backward", { grad_output, input, output }, ret );
setattr(n, jit::stringToSymbol("dim"), dim);
}
return Tensor(std::move(ret));
@@ -4767,14 +5604,17 @@
Tensor VariableType::prelu(const Tensor & input, const Tensor & weight) const {
auto& input_ = unpack(input, "input", 0);
auto& weight_ = unpack(weight, "weight", 1);
- auto flags = Function::flags({ input, weight });
- auto grad_fn = std::make_shared<PreluBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<PreluBackward> grad_fn;
+ auto flags = compute_flags({ input, weight });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<PreluBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input, weight });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->weight_ = SavedVariable(weight, nullptr);
}
auto ret = as_variable(baseType->prelu_forward(input_, weight_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ input, weight })) {
jit::Node *n = jit::tracer::recordTrace( "prelu", { input, weight }, ret );
(void)n;
@@ -4785,15 +5625,18 @@
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
auto& weight_ = unpack(weight, "weight", 2);
- auto flags = Function::flags({ grad_output, input, weight });
- auto grad_fn = std::make_shared<PreluBackwardBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<PreluBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input, weight });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<PreluBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight });
grad_fn->grad_output_info = grad_output;
grad_fn->input_info = input;
grad_fn->weight_info = weight;
}
auto ret = as_variable(baseType->prelu_backward(grad_output_, input_, weight_, output_mask));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input, weight })) {
jit::Node *n = jit::tracer::recordTrace( "prelu_backward", { grad_output, input, weight }, ret );
setattr(n, jit::stringToSymbol("output_mask"), output_mask);
@@ -4802,11 +5645,15 @@
}
Tensor VariableType::rrelu(const Tensor & input, Scalar lower, Scalar upper, bool training, bool inplace, Generator * generator) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<RreluBackward>();
auto noise = tensor();
auto& noise_ = static_cast<VariableImpl*>(noise.get())->data;
- if (flags.is_executable) {
+ if (inplace) check_inplace(input);
+ std::shared_ptr<RreluBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<RreluBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->lower = lower;
grad_fn->upper = upper;
grad_fn->training = training;
@@ -4814,8 +5661,8 @@
}
auto ret = as_variable(baseType->rrelu_forward(input_, lower, upper, training, inplace, generator, noise_));
if (inplace) increment_version(input);
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
grad_fn->input_ = SavedVariable(input, grad_fn.get());
}
return Tensor(std::move(ret));
@@ -4824,9 +5671,14 @@
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
auto& noise_ = unpack(noise, "noise", 6);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<RreluBackwardBackward>();
- if (flags.is_executable) {
+ if (inplace) check_inplace(input);
+ check_no_requires_grad(noise, "noise");
+ std::shared_ptr<RreluBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<RreluBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->lower = lower;
grad_fn->upper = upper;
grad_fn->training = training;
@@ -4834,12 +5686,12 @@
}
auto ret = as_variable(baseType->rrelu_backward(grad_output_, input_, lower, upper, training, inplace, noise_));
if (inplace) increment_version(input);
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
grad_fn->input_ = SavedVariable(input, grad_fn.get());
}
- if (jit::tracer::isTracing({ grad_output, input })) {
- jit::Node *n = jit::tracer::recordTrace( "rrelu_backward", { grad_output, input }, ret );
+ if (jit::tracer::isTracing({ grad_output, input, noise })) {
+ jit::Node *n = jit::tracer::recordTrace( "rrelu_backward", { grad_output, input, noise }, ret );
setattr(n, jit::stringToSymbol("lower"), lower);
setattr(n, jit::stringToSymbol("upper"), upper);
setattr(n, jit::stringToSymbol("training"), training);
@@ -4849,15 +5701,18 @@
}
Tensor VariableType::softmax(const Tensor & input, int64_t dim) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<SoftmaxBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<SoftmaxBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SoftmaxBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->dim = dim;
}
auto ret = as_variable(baseType->softmax_forward(input_, dim));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& output = ret;
grad_fn->output_ = SavedVariable(output, grad_fn.get());
}
@@ -4871,34 +5726,40 @@
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
auto& output_ = unpack(output, "output", 3);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<SoftmaxBackwardBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<SoftmaxBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SoftmaxBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->dim = dim;
grad_fn->output_ = SavedVariable(output, nullptr);
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr);
}
auto ret = as_variable(baseType->softmax_backward(grad_output_, input_, dim, output_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ grad_output, input })) {
- jit::Node *n = jit::tracer::recordTrace( "softmax_backward", { grad_output, input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ grad_output, input, output })) {
+ jit::Node *n = jit::tracer::recordTrace( "softmax_backward", { grad_output, input, output }, ret );
setattr(n, jit::stringToSymbol("dim"), dim);
}
return Tensor(std::move(ret));
}
Tensor VariableType::softplus(const Tensor & input, Scalar beta, Scalar threshold) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<SoftplusBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<SoftplusBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SoftplusBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->beta = beta;
grad_fn->threshold = threshold;
}
auto ret = as_variable(baseType->softplus_forward(input_, beta, threshold));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& output = ret;
grad_fn->output_ = SavedVariable(output, grad_fn.get());
}
@@ -4913,9 +5774,12 @@
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
auto& output_ = unpack(output, "output", 4);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<SoftplusBackwardBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<SoftplusBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SoftplusBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->beta = beta;
grad_fn->threshold = threshold;
@@ -4923,9 +5787,9 @@
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr);
}
auto ret = as_variable(baseType->softplus_backward(grad_output_, input_, beta, threshold, output_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ grad_output, input })) {
- jit::Node *n = jit::tracer::recordTrace( "softplus_backward", { grad_output, input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ grad_output, input, output })) {
+ jit::Node *n = jit::tracer::recordTrace( "softplus_backward", { grad_output, input, output }, ret );
setattr(n, jit::stringToSymbol("beta"), beta);
setattr(n, jit::stringToSymbol("threshold"), threshold);
}
@@ -4933,14 +5797,17 @@
}
Tensor VariableType::softshrink(const Tensor & input, Scalar lambd) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<SoftshrinkBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<SoftshrinkBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SoftshrinkBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->lambd = lambd;
}
auto ret = as_variable(baseType->softshrink_forward(input_, lambd));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ input })) {
jit::Node *n = jit::tracer::recordTrace( "softshrink", { input }, ret );
setattr(n, jit::stringToSymbol("lambd"), lambd);
@@ -4950,14 +5817,17 @@
Tensor VariableType::softshrink_backward(const Tensor & grad_output, const Tensor & input, Scalar lambd) const {
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<SoftshrinkBackwardBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<SoftshrinkBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SoftshrinkBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->lambd = lambd;
}
auto ret = as_variable(baseType->softshrink_backward(grad_output_, input_, lambd));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input })) {
jit::Node *n = jit::tracer::recordTrace( "softshrink_backward", { grad_output, input }, ret );
setattr(n, jit::stringToSymbol("lambd"), lambd);
@@ -4966,16 +5836,20 @@
}
Tensor VariableType::threshold(const Tensor & input, Scalar threshold, Scalar value, bool inplace) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<ThresholdBackward>();
- if (flags.is_executable) {
+ if (inplace) check_inplace(input);
+ std::shared_ptr<ThresholdBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ThresholdBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->threshold = threshold;
grad_fn->value = value;
}
auto ret = as_variable(baseType->threshold_forward(input_, threshold, value, inplace));
if (inplace) increment_version(input);
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
grad_fn->input_ = SavedVariable(input, grad_fn.get());
}
if (jit::tracer::isTracing({ input })) {
@@ -4989,16 +5863,20 @@
Tensor VariableType::threshold_backward(const Tensor & grad_output, const Tensor & input, Scalar threshold, Scalar value, bool inplace) const {
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<ThresholdBackwardBackward>();
- if (flags.is_executable) {
+ if (inplace) check_inplace(input);
+ std::shared_ptr<ThresholdBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<ThresholdBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->threshold = threshold;
grad_fn->value = value;
}
auto ret = as_variable(baseType->threshold_backward(grad_output_, input_, threshold, value, inplace));
if (inplace) increment_version(input);
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
grad_fn->input_ = SavedVariable(input, grad_fn.get());
}
if (jit::tracer::isTracing({ grad_output, input })) {
@@ -5011,14 +5889,17 @@
}
std::tuple<Tensor,Tensor> VariableType::adaptive_max_pool2d(const Tensor & input, IntList output_size) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<AdaptiveMaxPool2DBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AdaptiveMaxPool2DBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AdaptiveMaxPool2DBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
}
auto ret = as_variable(baseType->adaptive_max_pool2d_forward(input_, output_size));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(std::get<0>(ret), flags, grad_fn);
+ if (grad_fn) {
auto& indices = std::get<1>(ret);
grad_fn->indices_ = SavedVariable(indices, grad_fn.get());
}
@@ -5032,21 +5913,29 @@
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
auto& indices_ = unpack_long(indices, "indices", 2);
- auto flags = Function::flags({ grad_output, input, indices });
- auto grad_fn = std::make_shared<Error>("adaptive_max_pool2d_backward is not differentiable");
+ std::shared_ptr<Error> grad_fn;
+ auto flags = compute_flags({ grad_output, input, indices });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Error>("the derivative for adaptive_max_pool2d_backward is not implemented");
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, indices });
+ }
auto ret = as_variable(baseType->adaptive_max_pool2d_backward(grad_output_, input_, indices_));
- wrap_output(ret, std::move(flags), std::move(grad_fn));
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input, indices })) {
jit::Node *n = jit::tracer::recordTrace( "adaptive_max_pool2d_backward", { grad_output, input, indices }, ret );
(void)n;
}
- return ret;
+ return Tensor(std::move(ret));
}
Tensor VariableType::avg_pool2d(const Tensor & input, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<AvgPool2DBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AvgPool2DBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AvgPool2DBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->kernel_size = kernel_size;
grad_fn->stride = stride;
@@ -5055,7 +5944,7 @@
grad_fn->count_include_pad = count_include_pad;
}
auto ret = as_variable(baseType->avg_pool2d_forward(input_, kernel_size, stride, padding, ceil_mode, count_include_pad));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ input })) {
jit::Node *n = jit::tracer::recordTrace( "avg_pool2d", { input }, ret );
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size);
@@ -5069,9 +5958,12 @@
Tensor VariableType::avg_pool2d_backward(const Tensor & grad_output, const Tensor & input, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) const {
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<AvgPool2DBackwardBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AvgPool2DBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AvgPool2DBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->kernel_size = kernel_size;
grad_fn->stride = stride;
grad_fn->padding = padding;
@@ -5080,7 +5972,7 @@
grad_fn->input_info = input;
}
auto ret = as_variable(baseType->avg_pool2d_backward(grad_output_, input_, kernel_size, stride, padding, ceil_mode, count_include_pad));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input })) {
jit::Node *n = jit::tracer::recordTrace( "avg_pool2d_backward", { grad_output, input }, ret );
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size);
@@ -5093,9 +5985,12 @@
}
Tensor VariableType::avg_pool3d(const Tensor & input, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<AvgPool3DBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AvgPool3DBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AvgPool3DBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->kernel_size = kernel_size;
grad_fn->stride = stride;
@@ -5104,7 +5999,7 @@
grad_fn->count_include_pad = count_include_pad;
}
auto ret = as_variable(baseType->avg_pool3d_forward(input_, kernel_size, stride, padding, ceil_mode, count_include_pad));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ input })) {
jit::Node *n = jit::tracer::recordTrace( "avg_pool3d", { input }, ret );
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size);
@@ -5118,9 +6013,12 @@
Tensor VariableType::avg_pool3d_backward(const Tensor & grad_output, const Tensor & input, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) const {
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<AvgPool3DBackwardBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<AvgPool3DBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<AvgPool3DBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->kernel_size = kernel_size;
grad_fn->stride = stride;
grad_fn->padding = padding;
@@ -5129,7 +6027,7 @@
grad_fn->input_info = input;
}
auto ret = as_variable(baseType->avg_pool3d_backward(grad_output_, input_, kernel_size, stride, padding, ceil_mode, count_include_pad));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input })) {
jit::Node *n = jit::tracer::recordTrace( "avg_pool3d_backward", { grad_output, input }, ret );
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size);
@@ -5142,9 +6040,12 @@
}
std::tuple<Tensor,Tensor> VariableType::max_pool2d(const Tensor & input, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<MaxPool2DBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<MaxPool2DBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MaxPool2DBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->kernel_size = kernel_size;
grad_fn->stride = stride;
@@ -5153,8 +6054,8 @@
grad_fn->ceil_mode = ceil_mode;
}
auto ret = as_variable(baseType->max_pool2d_forward(input_, kernel_size, stride, padding, dilation, ceil_mode));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(std::get<0>(ret), flags, grad_fn);
+ if (grad_fn) {
auto& indices = std::get<1>(ret);
grad_fn->indices_ = SavedVariable(indices, grad_fn.get());
}
@@ -5172,16 +6073,19 @@
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
auto& indices_ = unpack_long(indices, "indices", 7);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<MaxPool2DBackwardBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<MaxPool2DBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MaxPool2DBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->indices_ = SavedVariable(indices, nullptr);
grad_fn->input_info = input;
}
auto ret = as_variable(baseType->max_pool2d_backward(grad_output_, input_, kernel_size, stride, padding, dilation, ceil_mode, indices_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ grad_output, input })) {
- jit::Node *n = jit::tracer::recordTrace( "max_pool2d_backward", { grad_output, input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ grad_output, input, indices })) {
+ jit::Node *n = jit::tracer::recordTrace( "max_pool2d_backward", { grad_output, input, indices }, ret );
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size);
setattr(n, jit::stringToSymbol("stride"), stride);
setattr(n, jit::stringToSymbol("padding"), padding);
@@ -5192,9 +6096,12 @@
}
std::tuple<Tensor,Tensor> VariableType::max_pool3d(const Tensor & input, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<MaxPool3DBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<MaxPool3DBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MaxPool3DBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->kernel_size = kernel_size;
grad_fn->stride = stride;
@@ -5203,8 +6110,8 @@
grad_fn->ceil_mode = ceil_mode;
}
auto ret = as_variable(baseType->max_pool3d_forward(input_, kernel_size, stride, padding, dilation, ceil_mode));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(std::get<0>(ret), flags, grad_fn);
+ if (grad_fn) {
auto& indices = std::get<1>(ret);
grad_fn->indices_ = SavedVariable(indices, grad_fn.get());
}
@@ -5222,10 +6129,15 @@
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
auto& indices_ = unpack_long(indices, "indices", 7);
- auto flags = Function::flags({ grad_output, input, indices });
- auto grad_fn = std::make_shared<Error>("max_pool3d_backward is not differentiable");
+ std::shared_ptr<Error> grad_fn;
+ auto flags = compute_flags({ grad_output, input, indices });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Error>("the derivative for max_pool3d_backward is not implemented");
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, indices });
+ }
auto ret = as_variable(baseType->max_pool3d_backward(grad_output_, input_, kernel_size, stride, padding, dilation, ceil_mode, indices_));
- wrap_output(ret, std::move(flags), std::move(grad_fn));
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input, indices })) {
jit::Node *n = jit::tracer::recordTrace( "max_pool3d_backward", { grad_output, input, indices }, ret );
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size);
@@ -5234,22 +6146,25 @@
setattr(n, jit::stringToSymbol("dilation"), dilation);
setattr(n, jit::stringToSymbol("ceil_mode"), ceil_mode);
}
- return ret;
+ return Tensor(std::move(ret));
}
Tensor VariableType::max_unpool2d(const Tensor & input, const Tensor & indices, IntList output_size) const {
auto& input_ = unpack(input, "input", 0);
auto& indices_ = unpack_long(indices, "indices", 1);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<MaxUnpool2DBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<MaxUnpool2DBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MaxUnpool2DBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->indices_ = SavedVariable(indices, nullptr);
grad_fn->output_size = output_size;
}
auto ret = as_variable(baseType->max_unpool2d_forward(input_, indices_, output_size));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "max_unpool2d", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, indices })) {
+ jit::Node *n = jit::tracer::recordTrace( "max_unpool2d", { input, indices }, ret );
setattr(n, jit::stringToSymbol("output_size"), output_size);
}
return Tensor(std::move(ret));
@@ -5258,17 +6173,20 @@
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
auto& indices_ = unpack_long(indices, "indices", 2);
- auto flags = Function::flags({ grad_output, input });
- auto grad_fn = std::make_shared<MaxUnpool2DBackwardBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<MaxUnpool2DBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MaxUnpool2DBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input });
grad_fn->indices_ = SavedVariable(indices, nullptr);
grad_fn->output_size = output_size;
grad_fn->input_info = input;
}
auto ret = as_variable(baseType->max_unpool2d_backward(grad_output_, input_, indices_, output_size));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ grad_output, input })) {
- jit::Node *n = jit::tracer::recordTrace( "max_unpool2d_backward", { grad_output, input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ grad_output, input, indices })) {
+ jit::Node *n = jit::tracer::recordTrace( "max_unpool2d_backward", { grad_output, input, indices }, ret );
setattr(n, jit::stringToSymbol("output_size"), output_size);
}
return Tensor(std::move(ret));
@@ -5276,9 +6194,12 @@
Tensor VariableType::max_unpool3d(const Tensor & input, const Tensor & indices, IntList output_size, IntList stride, IntList padding) const {
auto& input_ = unpack(input, "input", 0);
auto& indices_ = unpack_long(indices, "indices", 1);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<MaxUnpool3DBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<MaxUnpool3DBackward> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<MaxUnpool3DBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
grad_fn->input_ = SavedVariable(input, nullptr);
grad_fn->indices_ = SavedVariable(indices, nullptr);
grad_fn->output_size = output_size;
@@ -5286,9 +6207,9 @@
grad_fn->padding = padding;
}
auto ret = as_variable(baseType->max_unpool3d_forward(input_, indices_, output_size, stride, padding));
- wrap_output(ret, std::move(flags), grad_fn);
- if (jit::tracer::isTracing({ input })) {
- jit::Node *n = jit::tracer::recordTrace( "max_unpool3d", { input }, ret );
+ set_flags(ret, flags, grad_fn);
+ if (jit::tracer::isTracing({ input, indices })) {
+ jit::Node *n = jit::tracer::recordTrace( "max_unpool3d", { input, indices }, ret );
setattr(n, jit::stringToSymbol("output_size"), output_size);
setattr(n, jit::stringToSymbol("stride"), stride);
setattr(n, jit::stringToSymbol("padding"), padding);
@@ -5299,25 +6220,35 @@
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
auto& indices_ = unpack_long(indices, "indices", 2);
- auto flags = Function::flags({ grad_output, input, indices });
- auto grad_fn = std::make_shared<Error>("max_unpool3d_backward is not differentiable");
+ std::shared_ptr<Error> grad_fn;
+ auto flags = compute_flags({ grad_output, input, indices });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Error>("the derivative for max_unpool3d_backward is not implemented");
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, indices });
+ }
auto ret = as_variable(baseType->max_unpool3d_backward(grad_output_, input_, indices_, output_size, stride, padding));
- wrap_output(ret, std::move(flags), std::move(grad_fn));
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input, indices })) {
jit::Node *n = jit::tracer::recordTrace( "max_unpool3d_backward", { grad_output, input, indices }, ret );
setattr(n, jit::stringToSymbol("output_size"), output_size);
setattr(n, jit::stringToSymbol("stride"), stride);
setattr(n, jit::stringToSymbol("padding"), padding);
}
- return ret;
+ return Tensor(std::move(ret));
}
Tensor VariableType::_sigmoid(const Tensor & input) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<SigmoidBackward1>();
+ std::shared_ptr<SigmoidBackward1> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SigmoidBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
+ }
auto ret = as_variable(baseType->_sigmoid_forward(input_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& output = ret;
grad_fn->output_ = SavedVariable(output, grad_fn.get());
}
@@ -5330,14 +6261,17 @@
Tensor VariableType::_sigmoid_backward(const Tensor & grad_output, const Tensor & output) const {
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& output_ = unpack(output, "output", 1);
- auto flags = Function::flags({ grad_output, output });
- auto grad_fn = std::make_shared<SigmoidBackwardBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<SigmoidBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, output });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<SigmoidBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, output });
grad_fn->output_ = SavedVariable(output, nullptr);
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr);
}
auto ret = as_variable(baseType->_sigmoid_backward(grad_output_, output_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, output })) {
jit::Node *n = jit::tracer::recordTrace( "_sigmoid_backward", { grad_output, output }, ret );
(void)n;
@@ -5346,11 +6280,16 @@
}
Tensor VariableType::_tanh(const Tensor & input) const {
auto& input_ = unpack(input, "input", 0);
- auto flags = Function::flags({ input });
- auto grad_fn = std::make_shared<TanhBackward1>();
+ std::shared_ptr<TanhBackward1> grad_fn;
+ auto flags = compute_flags({ input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TanhBackward1>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ input });
+ }
auto ret = as_variable(baseType->_tanh_forward(input_));
- wrap_output(ret, std::move(flags), grad_fn);
- if (flags.is_executable) {
+ set_flags(ret, flags, grad_fn);
+ if (grad_fn) {
auto& output = ret;
grad_fn->output_ = SavedVariable(output, grad_fn.get());
}
@@ -5363,14 +6302,17 @@
Tensor VariableType::_tanh_backward(const Tensor & grad_output, const Tensor & output) const {
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& output_ = unpack(output, "output", 1);
- auto flags = Function::flags({ grad_output, output });
- auto grad_fn = std::make_shared<TanhBackwardBackward>();
- if (flags.is_executable) {
+ std::shared_ptr<TanhBackwardBackward> grad_fn;
+ auto flags = compute_flags({ grad_output, output });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<TanhBackwardBackward>();
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, output });
grad_fn->output_ = SavedVariable(output, nullptr);
grad_fn->grad_output_ = SavedVariable(grad_output, nullptr);
}
auto ret = as_variable(baseType->_tanh_backward(grad_output_, output_));
- wrap_output(ret, std::move(flags), grad_fn);
+ set_flags(ret, flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, output })) {
jit::Node *n = jit::tracer::recordTrace( "_tanh_backward", { grad_output, output }, ret );
(void)n;
@@ -5388,17 +6330,22 @@
auto& running_var_ = unpack(running_var, "running_var", 4);
auto& save_mean_ = unpack(save_mean, "save_mean", 7);
auto& save_std_ = unpack(save_std, "save_std", 8);
- auto flags = Function::flags({ grad_output, input, weight, running_mean, running_var, save_mean, save_std });
- auto grad_fn = std::make_shared<Error>("batch_norm_backward is not differentiable");
+ std::shared_ptr<Error> grad_fn;
+ auto flags = compute_flags({ grad_output, input, weight, running_mean, running_var, save_mean, save_std });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Error>("the derivative for batch_norm_backward is not implemented");
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight, running_mean, running_var, save_mean, save_std });
+ }
auto ret = as_variable(baseType->batch_norm_backward(grad_output_, input_, weight_, running_mean_, running_var_, training, eps, save_mean_, save_std_, output_mask));
- wrap_output(ret, std::move(flags), std::move(grad_fn));
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input, weight, running_mean, running_var, save_mean, save_std })) {
jit::Node *n = jit::tracer::recordTrace( "batch_norm_backward", { grad_output, input, weight, running_mean, running_var, save_mean, save_std }, ret );
setattr(n, jit::stringToSymbol("training"), training);
setattr(n, jit::stringToSymbol("eps"), eps);
setattr(n, jit::stringToSymbol("output_mask"), output_mask);
}
- return ret;
+ return std::tuple<Tensor,Tensor,Tensor>(std::move(ret));
}
Tensor VariableType::conv_transpose2d(const Tensor & input, const Tensor & weight, IntList kernel_size, const Tensor & bias, IntList stride, IntList padding, IntList output_padding, IntList dilation) const {
throw std::runtime_error("conv_transpose2d: NYI");
@@ -5409,10 +6356,15 @@
auto& weight_ = unpack(weight, "weight", 2);
auto& columns_ = unpack(columns, "columns", 8);
auto& ones_ = unpack(ones, "ones", 9);
- auto flags = Function::flags({ grad_output, input, weight, columns, ones });
- auto grad_fn = std::make_shared<Error>("conv_transpose2d_backward is not differentiable");
+ std::shared_ptr<Error> grad_fn;
+ auto flags = compute_flags({ grad_output, input, weight, columns, ones });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Error>("the derivative for conv_transpose2d_backward is not implemented");
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight, columns, ones });
+ }
auto ret = as_variable(baseType->conv_transpose2d_backward(grad_output_, input_, weight_, kernel_size, stride, padding, output_padding, dilation, columns_, ones_, output_mask));
- wrap_output(ret, std::move(flags), std::move(grad_fn));
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input, weight, columns, ones })) {
jit::Node *n = jit::tracer::recordTrace( "conv_transpose2d_backward", { grad_output, input, weight, columns, ones }, ret );
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size);
@@ -5422,7 +6374,7 @@
setattr(n, jit::stringToSymbol("dilation"), dilation);
setattr(n, jit::stringToSymbol("output_mask"), output_mask);
}
- return ret;
+ return std::tuple<Tensor,Tensor,Tensor>(std::move(ret));
}
Tensor VariableType::conv_transpose3d(const Tensor & input, const Tensor & weight, const Tensor & bias, IntList stride, IntList padding, IntList output_padding, IntList dilation) const {
throw std::runtime_error("conv_transpose3d: NYI");
@@ -5433,10 +6385,15 @@
auto& weight_ = unpack(weight, "weight", 2);
auto& finput_ = unpack(finput, "finput", 7);
auto& fgrad_input_ = unpack(fgrad_input, "fgrad_input", 8);
- auto flags = Function::flags({ grad_output, input, weight, finput, fgrad_input });
- auto grad_fn = std::make_shared<Error>("conv_transpose3d_backward is not differentiable");
+ std::shared_ptr<Error> grad_fn;
+ auto flags = compute_flags({ grad_output, input, weight, finput, fgrad_input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Error>("the derivative for conv_transpose3d_backward is not implemented");
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight, finput, fgrad_input });
+ }
auto ret = as_variable(baseType->conv_transpose3d_backward(grad_output_, input_, weight_, stride, padding, output_padding, dilation, finput_, fgrad_input_, output_mask));
- wrap_output(ret, std::move(flags), std::move(grad_fn));
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input, weight, finput, fgrad_input })) {
jit::Node *n = jit::tracer::recordTrace( "conv_transpose3d_backward", { grad_output, input, weight, finput, fgrad_input }, ret );
setattr(n, jit::stringToSymbol("stride"), stride);
@@ -5445,7 +6402,7 @@
setattr(n, jit::stringToSymbol("dilation"), dilation);
setattr(n, jit::stringToSymbol("output_mask"), output_mask);
}
- return ret;
+ return std::tuple<Tensor,Tensor,Tensor>(std::move(ret));
}
Tensor VariableType::conv2d(const Tensor & input, const Tensor & weight, IntList kernel_size, const Tensor & bias, IntList stride, IntList padding) const {
throw std::runtime_error("conv2d: NYI");
@@ -5456,10 +6413,15 @@
auto& weight_ = unpack(weight, "weight", 2);
auto& finput_ = unpack(finput, "finput", 6);
auto& fgrad_input_ = unpack(fgrad_input, "fgrad_input", 7);
- auto flags = Function::flags({ grad_output, input, weight, finput, fgrad_input });
- auto grad_fn = std::make_shared<Error>("conv2d_backward is not differentiable");
+ std::shared_ptr<Error> grad_fn;
+ auto flags = compute_flags({ grad_output, input, weight, finput, fgrad_input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Error>("the derivative for conv2d_backward is not implemented");
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight, finput, fgrad_input });
+ }
auto ret = as_variable(baseType->conv2d_backward(grad_output_, input_, weight_, kernel_size, stride, padding, finput_, fgrad_input_, output_mask));
- wrap_output(ret, std::move(flags), std::move(grad_fn));
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input, weight, finput, fgrad_input })) {
jit::Node *n = jit::tracer::recordTrace( "conv2d_backward", { grad_output, input, weight, finput, fgrad_input }, ret );
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size);
@@ -5467,7 +6429,7 @@
setattr(n, jit::stringToSymbol("padding"), padding);
setattr(n, jit::stringToSymbol("output_mask"), output_mask);
}
- return ret;
+ return std::tuple<Tensor,Tensor,Tensor>(std::move(ret));
}
Tensor VariableType::conv_depthwise2d(const Tensor & input, const Tensor & weight, IntList kernel_size, const Tensor & bias, IntList stride, IntList padding, IntList dilation) const {
throw std::runtime_error("conv_depthwise2d: NYI");
@@ -5476,10 +6438,15 @@
auto& grad_output_ = unpack(grad_output, "grad_output", 0);
auto& input_ = unpack(input, "input", 1);
auto& weight_ = unpack(weight, "weight", 2);
- auto flags = Function::flags({ grad_output, input, weight });
- auto grad_fn = std::make_shared<Error>("conv_depthwise2d_backward is not differentiable");
+ std::shared_ptr<Error> grad_fn;
+ auto flags = compute_flags({ grad_output, input, weight });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Error>("the derivative for conv_depthwise2d_backward is not implemented");
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight });
+ }
auto ret = as_variable(baseType->conv_depthwise2d_backward(grad_output_, input_, weight_, kernel_size, stride, padding, dilation, output_mask));
- wrap_output(ret, std::move(flags), std::move(grad_fn));
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input, weight })) {
jit::Node *n = jit::tracer::recordTrace( "conv_depthwise2d_backward", { grad_output, input, weight }, ret );
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size);
@@ -5488,7 +6455,7 @@
setattr(n, jit::stringToSymbol("dilation"), dilation);
setattr(n, jit::stringToSymbol("output_mask"), output_mask);
}
- return ret;
+ return std::tuple<Tensor,Tensor>(std::move(ret));
}
Tensor VariableType::conv3d(const Tensor & input, const Tensor & weight, IntList kernel_size, const Tensor & bias, IntList stride, IntList padding) const {
throw std::runtime_error("conv3d: NYI");
@@ -5499,10 +6466,15 @@
auto& weight_ = unpack(weight, "weight", 2);
auto& finput_ = unpack(finput, "finput", 6);
auto& fgrad_input_ = unpack(fgrad_input, "fgrad_input", 7);
- auto flags = Function::flags({ grad_output, input, weight, finput, fgrad_input });
- auto grad_fn = std::make_shared<Error>("conv3d_backward is not differentiable");
+ std::shared_ptr<Error> grad_fn;
+ auto flags = compute_flags({ grad_output, input, weight, finput, fgrad_input });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Error>("the derivative for conv3d_backward is not implemented");
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight, finput, fgrad_input });
+ }
auto ret = as_variable(baseType->conv3d_backward(grad_output_, input_, weight_, kernel_size, stride, padding, finput_, fgrad_input_, output_mask));
- wrap_output(ret, std::move(flags), std::move(grad_fn));
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input, weight, finput, fgrad_input })) {
jit::Node *n = jit::tracer::recordTrace( "conv3d_backward", { grad_output, input, weight, finput, fgrad_input }, ret );
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size);
@@ -5510,7 +6482,7 @@
setattr(n, jit::stringToSymbol("padding"), padding);
setattr(n, jit::stringToSymbol("output_mask"), output_mask);
}
- return ret;
+ return std::tuple<Tensor,Tensor,Tensor>(std::move(ret));
}
Tensor VariableType::conv_dilated2d(const Tensor & input, const Tensor & weight, IntList kernel_size, const Tensor & bias, IntList stride, IntList padding, IntList dilation) const {
throw std::runtime_error("conv_dilated2d: NYI");
@@ -5521,10 +6493,15 @@
auto& weight_ = unpack(weight, "weight", 2);
auto& columns_ = unpack(columns, "columns", 7);
auto& ones_ = unpack(ones, "ones", 8);
- auto flags = Function::flags({ grad_output, input, weight, columns, ones });
- auto grad_fn = std::make_shared<Error>("conv_dilated2d_backward is not differentiable");
+ std::shared_ptr<Error> grad_fn;
+ auto flags = compute_flags({ grad_output, input, weight, columns, ones });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Error>("the derivative for conv_dilated2d_backward is not implemented");
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight, columns, ones });
+ }
auto ret = as_variable(baseType->conv_dilated2d_backward(grad_output_, input_, weight_, kernel_size, stride, padding, dilation, columns_, ones_, output_mask));
- wrap_output(ret, std::move(flags), std::move(grad_fn));
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input, weight, columns, ones })) {
jit::Node *n = jit::tracer::recordTrace( "conv_dilated2d_backward", { grad_output, input, weight, columns, ones }, ret );
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size);
@@ -5533,7 +6510,7 @@
setattr(n, jit::stringToSymbol("dilation"), dilation);
setattr(n, jit::stringToSymbol("output_mask"), output_mask);
}
- return ret;
+ return std::tuple<Tensor,Tensor,Tensor>(std::move(ret));
}
Tensor VariableType::conv_dilated3d(const Tensor & input, const Tensor & weight, IntList kernel_size, const Tensor & bias, IntList stride, IntList padding, IntList dilation) const {
throw std::runtime_error("conv_dilated3d: NYI");
@@ -5544,10 +6521,15 @@
auto& weight_ = unpack(weight, "weight", 2);
auto& columns_ = unpack(columns, "columns", 7);
auto& ones_ = unpack(ones, "ones", 8);
- auto flags = Function::flags({ grad_output, input, weight, columns, ones });
- auto grad_fn = std::make_shared<Error>("conv_dilated3d_backward is not differentiable");
+ std::shared_ptr<Error> grad_fn;
+ auto flags = compute_flags({ grad_output, input, weight, columns, ones });
+ if (flags.requires_grad) {
+ grad_fn = std::make_shared<Error>("the derivative for conv_dilated3d_backward is not implemented");
+ grad_fn->is_executable = true;
+ grad_fn->next_functions = compute_next_functions({ grad_output, input, weight, columns, ones });
+ }
auto ret = as_variable(baseType->conv_dilated3d_backward(grad_output_, input_, weight_, kernel_size, stride, padding, dilation, columns_, ones_, output_mask));
- wrap_output(ret, std::move(flags), std::move(grad_fn));
+ set_flags(std::get<0>(ret), flags, grad_fn);
if (jit::tracer::isTracing({ grad_output, input, weight, columns, ones })) {
jit::Node *n = jit::tracer::recordTrace( "conv_dilated3d_backward", { grad_output, input, weight, columns, ones }, ret );
setattr(n, jit::stringToSymbol("kernel_size"), kernel_size);
@@ -5556,7 +6538,7 @@
setattr(n, jit::stringToSymbol("dilation"), dilation);
setattr(n, jit::stringToSymbol("output_mask"), output_mask);
}
- return ret;
+ return std::tuple<Tensor,Tensor,Tensor>(std::move(ret));
}
std::vector<Tensor> VariableType::split(Tensor self, int64_t split_size, int64_t dim) const {
throw std::runtime_error("split: NYI");
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment