@bdhirsh
Created February 23, 2023 20:52
how to teach functionalization about a custom, mutable operator
// In my_cpp_file_that_registers_ops_to_the_dispatcher.cpp
#include <ATen/FunctionalTensorWrapper.h>
#include <torch/library.h>

// The implementation of my out-of-place op, "foo"
at::Tensor foo_impl(const at::Tensor& x) {
  return x.mul(2);
}

// The implementation of my in-place op, "foo_"
at::Tensor& foo__impl(at::Tensor& x) {
  return x.mul_(2);
}

// The boilerplate functionalization logic, which teaches functionalization
// how to map foo_() calls into foo() calls.
// Long term, we'd like to not require users to write this logic.
// HOWEVER, if you have a custom op that is mutable,
// you will still need to write an out-of-place version of that op!
at::Tensor& foo__functionalization_glue(at::Tensor& x) {
  // We expect all tensor inputs to our op to be "functional tensors"
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(x));
  // First, sync and unwrap any functional tensors
  at::functionalization::impl::sync(x);
  auto x_ = at::functionalization::impl::from_functional_tensor(x);
  // Grab the dispatcher entry corresponding to the out-of-place op, "foo"
  static auto op_handle = c10::Dispatcher::singleton()
      // specify namespace::op_name, op_overload_name
      .findSchemaOrThrow("custom_namespace::foo", "")
      // Specify the C++ schema of the out-of-place op.
      .typed<at::Tensor(const at::Tensor&)>();
  // Next, redispatch to the out-of-place op, foo()
  // (the user called foo_(), but we call foo())
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = op_handle.call(x_);
  }
  // Finally, tell functionalization about this mutation.
  at::functionalization::impl::replace_(x, tmp_output);
  at::functionalization::impl::commit_update(x);
  at::functionalization::impl::sync(x);
  return x;
}

TORCH_LIBRARY(custom_namespace, m) {
  m.def("foo(Tensor x) -> Tensor");
  m.def("foo_(Tensor(a!) x) -> Tensor(a!)");
}

TORCH_LIBRARY_IMPL(custom_namespace, CPU, m) {
  m.impl("foo", foo_impl);
  m.impl("foo_", foo__impl);
}

TORCH_LIBRARY_IMPL(custom_namespace, Functionalize, m) {
  m.impl("foo_", foo__functionalization_glue);
}
# Example Python code where we call the custom (mutable) op, run functionalization,
# and see that we properly traced the out-of-place version of the custom op
import torch
from functorch import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

a = torch.ones(2)

# This is a custom op that mutates its input
foo_ = torch.ops.custom_namespace.foo_

def f(x):
    y = x.mul(2)
    return foo_(y)

out = make_fx(functionalize(f))(a)
# foo() will show up in the graph instead of foo_()
print(out.code)
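
For reference, the graph printed by print(out.code) should contain a call to the out-of-place op rather than the in-place one. A rough sketch of what to expect (exact variable names and overload names may differ between PyTorch versions):

# def forward(self, x_1):
#     mul = torch.ops.aten.mul.Tensor(x_1, 2)
#     foo = torch.ops.custom_namespace.foo.default(mul)
#     return foo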
@wbigat

wbigat commented Sep 23, 2023

Hi @bdhirsh,
I've defined a custom in-place op following your guidance, and I did generate FX graphs, but there seem to be some problems with the results, as described below. Could you please give me some help? Thank you very much.
https://discuss.pytorch.org/t/how-to-define-an-in-place-custom-op-in-dynamo/188814?u=bigat_w
