Created
February 23, 2023 20:52
-
-
Save bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa to your computer and use it in GitHub Desktop.
how to teach functionalization about a custom, mutable operator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// In my_cpp_file_that_registers_ops_to_the_dispatcher.cpp | |
#include <ATen/FunctionalTensorWrapper.h> | |
#include <torch/library.h> | |
// CPU kernel for the out-of-place op "foo": produces a fresh tensor
// holding x * 2 and leaves the input untouched.
at::Tensor foo_impl(const at::Tensor& x) {
  auto doubled = x.mul(2);
  return doubled;
}
// CPU kernel for the in-place op "foo_": doubles x in place and, per the
// ATen in-place convention, returns a reference to the mutated input.
at::Tensor& foo__impl(at::Tensor& x) {
  x.mul_(2);
  return x;
}
// Functionalization glue for foo_: teaches the Functionalize dispatch key
// to rewrite a call to the mutable op foo_() into a call to its
// out-of-place counterpart foo(), then records the mutation on the
// functional tensor wrapper so later reads observe the update.
// Long term PyTorch would like to generate this boilerplate automatically;
// until then, every mutable custom op still needs an out-of-place twin.
at::Tensor& foo__functionalization_glue(at::Tensor& x) {
  // Under the Functionalize key, every tensor input is expected to be a
  // functional tensor wrapper.
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(x));
  // Flush any pending updates, then pull out the plain inner tensor.
  at::functionalization::impl::sync(x);
  auto inner = at::functionalization::impl::from_functional_tensor(x);
  // Look up (once, cached in a function-local static) the dispatcher handle
  // for the out-of-place op. Arguments: "namespace::op_name", overload name
  // (empty here), then the op's C++ schema.
  static auto foo_handle = c10::Dispatcher::singleton()
      .findSchemaOrThrow("custom_namespace::foo", "")
      .typed<at::Tensor(const at::Tensor&)>();
  // Redispatch to foo() — the user called foo_(), we call foo() — with
  // functionalization skipped so this glue is not re-entered.
  at::Tensor result;
  {
    at::AutoDispatchSkipFunctionalize guard;
    result = foo_handle.call(inner);
  }
  // Finally, report the mutation back to functionalization and propagate it.
  at::functionalization::impl::replace_(x, result);
  at::functionalization::impl::commit_update(x);
  at::functionalization::impl::sync(x);
  return x;
}
// Declare both schemas in the custom namespace. "foo" is pure; the
// "(a!)" annotation on "foo_" marks its argument as aliased and mutated.
TORCH_LIBRARY(custom_namespace, m) {
  m.def("foo(Tensor x) -> Tensor");
  m.def("foo_(Tensor(a!) x) -> Tensor(a!)");
}
// Bind the CPU kernels for both ops.
TORCH_LIBRARY_IMPL(custom_namespace, CPU, m) {
  m.impl("foo", foo_impl);
  m.impl("foo_", foo__impl);
}
// Route foo_ through the glue whenever the Functionalize key is active;
// foo needs no entry here because it is already out-of-place.
TORCH_LIBRARY_IMPL(custom_namespace, Functionalize, m) {
  m.impl("foo_", foo__functionalization_glue);
}
# Example: invoke the custom mutable op under functionalization and verify
# that the traced graph records the out-of-place foo() instead of foo_().
import torch
from functorch import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

# Handle to the custom in-place op registered from C++.
inplace_foo = torch.ops.custom_namespace.foo_


def f(x):
    # Mutate an intermediate so functionalization has a mutation to rewrite.
    doubled = x.mul(2)
    return inplace_foo(doubled)


example_input = torch.ones(2)
traced = make_fx(functionalize(f))(example_input)
# foo() will show up in the graph instead of foo_()
print(traced.code)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, @bdhirsh.
I've defined a custom in-place op, following your guidance, and I did generate FX graphs, but there seem to be some problems with the results, as follows. Could you please give me some help? Thank you very much.
https://discuss.pytorch.org/t/how-to-define-an-in-place-custom-op-in-dynamo/188814?u=bigat_w