Link to prototype branch: https://github.com/ailzhang/pytorch/commit/83f647e9f14a89e61378c8e834aec4dfbcb74a00
This prototype only focuses on getting rid of aliasing in view ops. After the `Func2` kernel you can safely assume that you won't see any view ops, only `view_op_copy` ops that return completely new storage. You can build this branch and check some examples by running `python test/base_to_view.py` and `python test/view_to_base.py`. But this branch is only a proof of concept: it comes with a lot of hacks and requires the careful design work described in the sections below.
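To make that contract concrete, here is a pure-Python sketch of the difference between an aliasing view op and its copying variant. Everything below (`Tensor`, `view`, `view_copy`, the list-as-storage model) is a hypothetical stand-in, not the actual ATen implementation:

```python
# Hypothetical sketch: storage is modeled as a shared Python list.
class Tensor:
    def __init__(self, storage, shape):
        self.storage = storage  # mutable buffer, possibly shared
        self.shape = shape

def view(t, new_shape):
    # Aliasing view: the output shares the input's storage.
    return Tensor(t.storage, new_shape)

def view_copy(t, new_shape):
    # Non-aliasing variant: the output gets completely new storage.
    return Tensor(list(t.storage), new_shape)

a = Tensor([1, 2, 3, 4], (2, 2))
b = view(a, (4,))
c = view_copy(a, (4,))

a.storage[0] = 100
print(b.storage[0])  # 100 -> b aliases a's storage
print(c.storage[0])  # 1   -> c owns fresh storage
```

After the `Func2` pass, only the `view_copy`-style ops remain, so later passes never have to reason about shared storage.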
- It takes one view op, `at::view`, and adds `at::view_copy`, which is the "non-alias" version of `view`.
- It added a new dispatch key, `Func2` (sorry for the bad naming...), after the `Autograd` keys and before `ADInplaceOrView`, which comes with the following kernels (in TensorShape.cpp and VariableFallbackKernel.cpp):
  - The fallback kernel for `Func2` materializes all the tensors before redispatching to lower dispatch keys. Here "materialize" means calling `tensor.sync_()` so that the value is updated before it reaches other ops. Note that in the prototype we hacked the CPU backend for testing; since CPU is an eager execution mode, `tensor.sync_()` eagerly updates the tensor value as well. In a lazy execution mode (e.g. the XLA backend) this `tensor.sync_()` can simply be `current_ir.set_(new_ir)` without executing the kernel.
  - The kernel for view ops at the `Func2` dispatch key should handle the view bookkeeping and redispatch to the non-alias version of the view op.
  - The kernel for inplace ops at the `Func2` dispatch key should redispatch to lower keys, and also queue the inplace update on `self`'s `alias_` (which contains the base tensor information).
- Misc helpers to make the following use cases work:

  1. An inplace update on a view tensor propagates back to the base tensor (test/view_to_base.py):

     ```python
     import torch
     a = torch.rand(2, 4)
     b = a.view(8)
     c = b.view(4, 2)
     const = torch.ones_like(c)
     print('a ', a)
     print('b ', b)
     print('c ', c)
     c.add_(const)
     print('inplace update view tensor c also updates its aliases')
     print('a ', a)  # updated
     print('b ', b)  # updated
     print('c ', c)  # updated
     ```

  2. An inplace update on the base tensor propagates to its view tensors (test/base_to_view.py):

     ```python
     import torch
     a = torch.rand(2, 4)
     b = a.view(8)
     c = b.view(4, 2)
     const = torch.ones_like(a)
     print('a ', a)
     print('b ', b)
     print('c ', c)
     a.add_(const)
     print('inplace update base tensor a also updates its view tensors')
     print('a ', a)  # updated
     print('b ', b)  # updated
     print('c ', c)  # updated
     ```
- It supports chaining view ops as well, but since we only support `view`/`add_` in this prototype, you can only chain `view`s here (covered in the tests above).
- I only implemented `view` as an example of view ops and `add_` as an example of inplace ops, to make sure it works for my test cases. There are many other view ops we need to implement, and we need to codegen a similar kernel for all inplace ops at the `Func2` key. Also, of course, I don't expect this prototype to pass the whole test suite; it might blow up here and there. :P
- I hacked this prototype to demonstrate that we can get rid of aliasing using kernels in the dispatcher, but completely ignored proper data structure design, e.g.:
  - The `alias_` structure shouldn't live in the Tensor class; it might fit better in `ViewInfo`, which is allocated on the heap.
  - I followed XLA's implementation and used a switch case plus a lot of side info (source/dst sizes) saved in a `ViewMeta` struct, but it might be cleaner to wrap everything in a `std::function` instead, which could save us from carrying a vector of `ViewMeta`s: just a single `std::function` that contains the chain of view ops.
  - For easy testing, I hacked the `Func2` key onto CPU tensors (in practice CPU tensors support aliasing, so `Func2` should be enabled as a per-backend choice).
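The three kernel roles described above, fallback (materialize, then redispatch), view (redispatch to the copy variant), and inplace (queue the update), can be mimicked in a pure-Python sketch. Everything here (`sync_`, `pending_updates`, the `func2_*` functions) is a hypothetical stand-in for the real dispatcher machinery, not the prototype's actual API:

```python
# Sketch of the three Func2 kernel roles, modeled as plain functions.
class Tensor:
    def __init__(self, data):
        self.data = data
        self.pending_updates = []  # inplace updates queued on this tensor

    def sync_(self):
        # "Materialize": apply queued updates before any other op sees us.
        for update in self.pending_updates:
            self.data = update(self.data)
        self.pending_updates.clear()

def func2_fallback(op, *tensors):
    # Fallback kernel: materialize all inputs, then "redispatch" below.
    for t in tensors:
        t.sync_()
    return op(*tensors)

def func2_view_copy(t):
    # View kernel: sync, then redispatch to the non-alias copy variant.
    t.sync_()
    return Tensor(list(t.data))  # fresh storage, no aliasing

def func2_add_(t, value):
    # Inplace kernel: queue the update instead of applying it eagerly.
    t.pending_updates.append(lambda d: [x + value for x in d])
    return t

a = Tensor([1, 2, 3])
func2_add_(a, 10)                                # queued, not applied yet
total = func2_fallback(lambda t: sum(t.data), a)
print(total)  # 36 -> the queued +10 was applied during materialization
```

On CPU this materialization is eager, as in the sketch; on a lazy backend like XLA, `sync_` would only swap IR nodes instead of computing values.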
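The `std::function` alternative can be sketched in Python as composing the view chain into a single callable, rather than carrying a list of per-step `ViewMeta` records. The step functions below are hypothetical examples, not real view ops:

```python
def compose(*steps):
    # Collapse a chain of view steps into one function: base -> view content.
    def chain(data):
        for step in steps:
            data = step(data)
        return data
    return chain

# Two hypothetical view steps operating on a flat storage list:
first_half = lambda d: d[: len(d) // 2]  # e.g. a narrowing view
flipped = lambda d: d[::-1]              # e.g. a flip view

# Instead of storing [ViewMeta(narrow), ViewMeta(flip)], store one callable:
view_chain = compose(first_half, flipped)

base = [1, 2, 3, 4, 5, 6]
print(view_chain(base))  # [3, 2, 1]
```

The trade-off: a single closure is simpler to carry around, but an explicit `ViewMeta` vector stays inspectable (e.g. for debugging or for translating view-level writes back into base coordinates).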
Fact: the output tensor shape of a view op is always ≤ the input tensor shape. So the first time a view happens on an input tensor, we turn the input tensor into a view tensor with `ViewMeta=kNoOp`, and save the input tensor in an `Alias` struct that will be shared by ALL the aliases of the input tensor.

For example, say we have two view chains: tensor → view1 → view2 and tensor → view3 → view4. Here tensor/view1/view2/view3/view4 all point to the same `Alias` object, which contains tensor. Each of view1/view2/view3/view4 also carries `ViewMeta` information about how to derive its content from tensor.
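The sharing structure for those two chains might look like this in pure Python (hypothetical classes; the real `Alias`/`ViewMeta` are C++ structs, and `ViewMeta` is modeled here as a plain function from base content to view content):

```python
class Alias:
    def __init__(self, base):
        self.base = base  # the original tensor's content

class ViewTensor:
    def __init__(self, alias, view_meta):
        self.alias = alias          # shared by every alias of the base
        self.view_meta = view_meta  # how to derive content from the base

alias = Alias([1, 2, 3, 4])

# The first view turns the input into a view tensor with ViewMeta = kNoOp:
tensor = ViewTensor(alias, lambda b: b)
view1 = ViewTensor(alias, lambda b: b[:2])
view2 = ViewTensor(alias, lambda b: b[:2][::-1])
view3 = ViewTensor(alias, lambda b: b[2:])
view4 = ViewTensor(alias, lambda b: b[2:][::-1])

# All five tensors point to the same Alias object:
print(all(t.alias is alias for t in (tensor, view1, view2, view3, view4)))
print(view2.view_meta(alias.base))  # [2, 1]
```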
When an inplace update happens (on a base or view tensor), we don't immediately propagate the update to all of its aliases; instead we queue the update on the `Alias` it points to.

When we hit a kernel at `Func2` for a non-view, non-inplace op, we use the fallback kernel, which materializes each input tensor whose `Alias` has pending updates before we proceed. (Again, this is done eagerly in the CPU backend but can be lazy in the XLA backend.)

When we materialize a tensor, we first apply all the updates queued in the `Alias` object to bring the base tensor up to date, and then we reapply the tensor's `ViewMeta` info to get its updated content from the base tensor.
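Putting the queueing and materialization steps together, a pure-Python sketch might look like this (hypothetical names; a real kernel must also translate a view-level write back into base coordinates, which is elided here by writing the update directly against the base):

```python
class Alias:
    def __init__(self, base):
        self.base = base
        self.pending = []  # queued inplace updates: base -> new base

class ViewTensor:
    def __init__(self, alias, view_meta):
        self.alias = alias
        self.view_meta = view_meta
        self.data = view_meta(alias.base)

    def add_(self, value):
        # Queue the update on the shared Alias instead of touching
        # every alias immediately.
        self.alias.pending.append(lambda b: [x + value for x in b])

    def materialize(self):
        # 1) Apply all queued updates to the base tensor...
        for update in self.alias.pending:
            self.alias.base = update(self.alias.base)
        self.alias.pending.clear()
        # 2) ...then reapply ViewMeta to refresh this tensor's content.
        self.data = self.view_meta(self.alias.base)
        return self.data

alias = Alias([1, 2, 3, 4])
base = ViewTensor(alias, lambda b: b)      # ViewMeta = kNoOp
view = ViewTensor(alias, lambda b: b[2:])  # a slicing view

base.add_(10)              # queued on the Alias, not yet visible
print(view.data)           # [3, 4] (stale until materialized)
print(view.materialize())  # [13, 14]
```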
`Func2` doesn't handle mutation removal yet; in other words, `Func2` is only part of the functionalization pass we want, and it's expected that you still see inplace ops after the `Func2` kernel. But after `Func2` it's much easier to do mutation removal, and you can add another key after `Func2` that does mutation removal as a separate job. A simple solution proposed by alband is to keep just one special `set_` (already implemented as a special kernel in the prototype above) and transform all the other inplace ops into their functional versions:
```python
X = ...
Y = X
X.add_(2)
# transform to
X.set_(X + 2)
# Y and X point to the same Tensor, so Y is also updated.
```
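A self-contained pure-Python sketch of that rewrite (the `Tensor` class and `add` helper below are hypothetical; `set_` is the single mutation that survives the pass):

```python
class Tensor:
    def __init__(self, data):
        self.data = data

    def set_(self, other):
        # The one mutation we keep: swap this Tensor's contents in place,
        # so every reference to this Tensor object sees the update.
        self.data = other.data

def add(t, value):
    # Functional version of add_: returns a brand-new Tensor.
    return Tensor([x + value for x in t.data])

x = Tensor([1, 2, 3])
y = x              # y aliases x (the same Python object)
# x.add_(2) is rewritten to:
x.set_(add(x, 2))
print(y.data)  # [3, 4, 5] -> y and x point to the same Tensor
```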