
@ailzhang
Last active June 23, 2022 03:07

Functionalization

Link to prototype branch: https://github.com/ailzhang/pytorch/commit/83f647e9f14a89e61378c8e834aec4dfbcb74a00

Quick summary:

This prototype focuses only on getting rid of aliasing in view ops. After the Func2 kernel, you can safely assume that you won't see any view ops, only view_op_copy ops that return completely new storage.

You can build this branch and check some examples by running python test/base_to_view.py and python test/view_to_base.py. Note, however, that this branch is only a proof of concept: it comes with a lot of hacks and requires some careful design work, described in the follow-ups section below.

What does the prototype support?

  1. It takes one view op, at::view, and adds at::view_copy, the "non-aliasing" version of view.

  2. It adds a new dispatch key, Func2 (sorry for the bad naming...), placed after the Autograd keys and before ADInplaceOrView, with the following kernels (in TensorShape.cpp and VariableFallbackKernel.cpp):

    1. The fallback kernel for Func2 materializes all tensor arguments before redispatching to lower dispatch keys. Materializing here means calling tensor.sync_() so that the tensor's value is up to date before it reaches other ops. Note that the prototype hacks the CPU backend for testing; since CPU runs in eager execution mode, tensor.sync_() eagerly updates the tensor value as well. In lazy execution mode (e.g. the XLA backend), tensor.sync_() can simply be current_ir.set_(new_ir) without executing the kernel.
    2. The kernel for view ops at the Func2 dispatch key handles the view bookkeeping (the Alias and ViewMeta described under Implementation details) and redispatches to the non-aliasing version of the view op.
    3. The kernel for inplace ops at the Func2 dispatch key redispatches to lower keys and also queues the inplace update on self's alias_ (which contains the base tensor information).
  3. Miscellaneous helpers to make the following use cases work:

    1. An inplace update on a view tensor propagates back to the base tensor (test/view_to_base.py):
    
    ```python
    import torch
    a = torch.rand(2, 4)
    b = a.view(8)
    c = b.view(4, 2)
    const = torch.ones_like(c)
    print('a ', a)
    print('b ', b)
    print('c ', c)
    c.add_(const)
    print('inplace update view tensor c also updates its aliases')
    print('a ', a) # updated
    print('b ', b) # updated
    print('c ', c) # updated
    ```
    
    2. An inplace update on the base tensor propagates to its view tensors (test/base_to_view.py):
    
    ```python
    import torch
    a = torch.rand(2, 4)
    b = a.view(8)
    c = b.view(4, 2)
    const = torch.ones_like(a)
    print('a ', a)
    print('b ', b)
    print('c ', c)
    a.add_(const)
    print('inplace update base tensor a also updates its view tensors')
    print('a ', a) # updated
    print('b ', b) # updated
    print('c ', c) # updated
    ```
    • It also supports chaining view ops, but since the prototype only implements view and add_, the only chains you can build are chains of view (covered in the tests above).
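The kernel roles above can be sketched with a pure-Python toy model. This is not the prototype's C++ code, only an illustration under stated assumptions: ToyTensor, Alias, sync, view_copy, func2_view, func2_add_ and func2_fallback are hypothetical names, a tensor is a flat Python list plus a shape, and view (a reshape of contiguous data) is the only view op, so regenerating a view from its base only reinterprets the flat data. It also assumes the inplace kernel materializes its inputs before updating them. Every tensor owns private storage, matching the claim that only copies remain after Func2, yet the view_to_base scenario still behaves as in the test script.

```python
from math import prod


class Alias:
    """Shared by a base tensor and all of its views (hypothetical stand-in)."""

    def __init__(self, base_data):
        self.base = list(base_data)  # flat contents of the base tensor
        self.pending = []            # queued inplace updates (flat values)
        self.generation = 0          # bumped every time an update is queued


class ToyTensor:
    def __init__(self, data, shape, alias=None):
        self.data = list(data)       # private storage, never shared with another tensor
        self.shape = tuple(shape)
        self.alias = alias
        self.generation = alias.generation if alias is not None else 0


def sync(t):
    """Materialize t: apply queued updates to the base, then regenerate t from it."""
    a = t.alias
    if a is None or t.generation == a.generation:
        return
    for update in a.pending:
        # view only reshapes contiguous data, so every alias has the base's flat
        # order and an update maps onto the base element by element.
        a.base = list(update)
    a.pending.clear()
    t.data = list(a.base)            # replaying view(...) only changes the shape
    t.generation = a.generation


def view_copy(t, shape):
    """Non-aliasing version of view: same values, brand-new storage."""
    assert prod(shape) == prod(t.shape)
    return ToyTensor(t.data, shape)


def func2_view(t, shape):
    sync(t)
    if t.alias is None:              # first view: t's data becomes the shared base
        t.alias = Alias(t.data)
        t.generation = t.alias.generation
    out = view_copy(t, shape)        # redispatch to the non-aliasing op
    out.alias, out.generation = t.alias, t.alias.generation
    return out


def func2_add_(t, other):
    sync(t)                          # assumption: inputs are materialized first
    sync(other)
    t.data = [x + y for x, y in zip(t.data, other.data)]  # the redispatched add_
    if t.alias is not None:          # queue the update for the other aliases
        t.alias.pending.append(list(t.data))
        t.alias.generation += 1
        t.generation = t.alias.generation
    return t


def func2_fallback(op, *tensors):
    for t in tensors:                # materialize every input, then run the op
        sync(t)
    return op(*tensors)


# Mirrors test/view_to_base.py: an inplace update on view c reaches a and b.
a = ToyTensor([0.0] * 8, (2, 4))
b = func2_view(a, (8,))
c = func2_view(b, (4, 2))
func2_add_(c, ToyTensor([1.0] * 8, (4, 2)))
print(func2_fallback(lambda t: t.data, a))  # [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
```

In this toy, reads pay for a generation check instead of views sharing memory, which is the counterpart of the fallback's materialization step.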

What are the potential follow-ups?

  1. I only implemented view as an example of a view op and add_ as an example of an inplace op, just enough to make my test cases work. Many other view ops still need to be implemented, and a similar kernel needs to be codegen'd at the Func2 key for every inplace op. Also, of course, I don't expect this prototype to pass the whole test suite; it might blow up here and there. :P
  2. I hacked this prototype together to demonstrate that we can get rid of aliasing using kernels in the dispatcher, but completely ignored proper data structure design, e.g.:
    1. The alias_ structure shouldn't live in the Tensor class; it might fit better in ViewInfo, which is allocated on the heap.
    2. I followed XLA's implementation and used a switch statement plus a lot of side info (source/dst sizes) saved in the ViewMeta struct. It might be cleaner to wrap everything in a std::function instead: rather than carrying a vector of ViewMeta, each tensor would carry a single std::function that contains the chain of view ops.
    3. For ease of testing, I hacked the Func2 key into CPU tensors (in practice, CPU tensors support aliasing, so Func2 should be enabled per backend).
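The two ViewMeta representations in item 2.2 can be compared with a small Python sketch, where Python closures stand in for std::function. The ViewMeta dataclass, its "kNoOp"/"kReshape" kind strings and the (shape, flat data) tensor tuple are hypothetical simplifications: the switch-style version walks a list of records carrying side info, while the closure version folds the whole chain into one callable.

```python
from dataclasses import dataclass
from math import prod
from typing import Callable, List, Tuple

Shape = Tuple[int, ...]
Tensor = Tuple[Shape, List[float]]  # toy stand-in: (shape, flat contiguous data)


@dataclass
class ViewMeta:
    kind: str          # hypothetical kinds: "kNoOp" or "kReshape"
    source_size: Shape
    dest_size: Shape


def replay_switch(base: Tensor, metas: List[ViewMeta]) -> Tensor:
    """Vector-of-ViewMeta style: walk the records and branch on each kind."""
    shape, data = base
    for m in metas:
        if m.kind == "kNoOp":
            continue
        if m.kind == "kReshape":
            assert shape == m.source_size  # side info, used here as a check
            shape = m.dest_size
        else:
            raise ValueError(f"unsupported view kind: {m.kind}")
    return shape, list(data)


def reshape_fn(dest_size: Shape) -> Callable[[Tensor], Tensor]:
    def fn(t: Tensor) -> Tensor:
        assert prod(t[0]) == prod(dest_size)
        return dest_size, list(t[1])
    return fn


def compose(first, second):
    """std::function style: fold two view functions into one callable."""
    return lambda t: second(first(t))


base: Tensor = ((2, 4), [float(i) for i in range(8)])
metas = [ViewMeta("kNoOp", (2, 4), (2, 4)),
         ViewMeta("kReshape", (2, 4), (8,)),
         ViewMeta("kReshape", (8,), (4, 2))]
chain = compose(reshape_fn((8,)), reshape_fn((4, 2)))  # one callable, whole chain
assert replay_switch(base, metas) == chain(base)
```

In this sketch the closure version needs neither the switch nor stored source sizes. The catch is that a composed callable is opaque, so anything that must inspect or invert the chain (for example, scattering an inplace update on a view back into its base) would need its own captured function too.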

Implementation details:

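Fact: the output of a view op is always contained in its input: every element of the output refers to an element of the input tensor, so the output never has more distinct elements than the input.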
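A small worked check of this fact using strided indexing: each element of a strided tensor is read from storage[offset + sum(index_k * stride_k)], so we can compare the set of storage offsets a view touches with its input's. The helper touched_offsets is hypothetical; the geometries are those of a contiguous (2, 4) tensor, its view(8), its column [:, 1], and an expand-style broadcast with stride 0.

```python
from itertools import product


def touched_offsets(shape, strides, offset=0):
    """Storage offsets read by a strided tensor with the given geometry."""
    return {offset + sum(i * s for i, s in zip(idx, strides))
            for idx in product(*(range(n) for n in shape))}


base = touched_offsets((2, 4), (4, 1))        # contiguous (2, 4) tensor
flat = touched_offsets((8,), (1,))            # base.view(8)
col = touched_offsets((2,), (4,), offset=1)   # base[:, 1]
row = touched_offsets((1, 4), (4, 1))         # a contiguous (1, 4) tensor
bcast = touched_offsets((3, 4), (0, 1))       # row.expand(3, 4): stride 0 repeats it

assert flat <= base and col <= base
assert bcast <= row  # 12 elements, but only 4 distinct storage offsets
```

The broadcast case shows why "distinct elements" is the right measure: its shape is larger than its input's, yet it still reads nothing outside the input.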

Because every view can therefore be recomputed from the tensor it was taken from, the first time a view is taken of an input tensor, we turn the input tensor into a view tensor with ViewMeta=kNoOp and save the input tensor in an Alias struct, which is shared by ALL the aliases of the input tensor.

For example, say we have two view chains: tensor → view1 → view2 and tensor → view3 → view4. Here tensor, view1, view2, view3 and view4 all point to the same Alias object, which contains tensor. In addition, view1 through view4 each carry ViewMeta information describing how to get their content from tensor.
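The sharing in this example can be sketched with plain Python objects. Alias and AliasedTensor are hypothetical names, and ViewMeta is simplified to a target shape because the only view op assumed here is a reshape of contiguous data: every tensor in both chains holds a reference to one Alias, and each view keeps the chain needed to recompute its content from the base.

```python
from math import prod


class Alias:
    """Hypothetical: holds the original tensor; shared by all of its aliases."""

    def __init__(self, base_shape, base_data):
        self.base_shape = tuple(base_shape)
        self.base_data = list(base_data)


class AliasedTensor:
    def __init__(self, alias, view_metas):
        self.alias = alias                  # the same object for base and views
        self.view_metas = list(view_metas)  # [] plays the role of ViewMeta=kNoOp

    def view(self, shape):
        return AliasedTensor(self.alias, self.view_metas + [tuple(shape)])

    def content(self):
        """Recompute (shape, data) from the base by replaying the ViewMeta chain."""
        shape, data = self.alias.base_shape, list(self.alias.base_data)
        for new_shape in self.view_metas:
            assert prod(new_shape) == prod(shape)
            shape = new_shape               # a reshape keeps the flat order
        return shape, data


tensor = AliasedTensor(Alias((2, 4), range(8)), [])  # the base, as a kNoOp view
view1 = tensor.view((8,))
view2 = view1.view((4, 2))
view3 = tensor.view((4, 2))
view4 = view3.view((2, 2, 2))
assert all(v.alias is tensor.alias for v in (view1, view2, view3, view4))

tensor.alias.base_data[0] = 100                      # change the shared base once
print(view2.content()[1][0], view4.content()[1][0])  # visible from both chains: 100 100
```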

When an inplace update happens (on the base or on a view tensor), we don't immediately propagate it to all aliases; instead, we queue the update on the Alias the tensor points to.

When we hit a kernel at Func2 for an op that is neither a view op nor an inplace op, we use the fallback kernel, which first materializes the tensor if its Alias has pending updates and then proceeds. (Again, this is done eagerly on the CPU backend but can be lazy on the XLA backend.)

When we materialize a tensor, we first apply all the queued updates in the Alias object to update the base tensor, and then reapply the current tensor's ViewMeta info to get its updated content from the base tensor.
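The queue-then-materialize steps can be sketched as follows. To make the replay non-trivial, this hypothetical toy uses a 1-D flat slice, encoded as ("slice", start, length), as its only view op instead of view; Alias, queue_update, materialize, apply_chain and scatter_back are illustrative names. Each queued update records the ViewMeta chain of the tensor it was made on, so applying it means scattering the new values back into the base; materializing a tensor then replays its own chain on the updated base.

```python
def apply_chain(data, metas):
    """Replay a ViewMeta chain of ("slice", start, length) records on flat data."""
    for _kind, start, length in metas:
        data = data[start:start + length]
    return list(data)


def scatter_back(base, metas, values):
    """Write the new content of the view described by metas into base."""
    start = sum(s for _kind, s, _length in metas)  # nested 1-D slices add offsets
    base[start:start + len(values)] = values


class Alias:
    def __init__(self, base):
        self.base = list(base)
        self.pending = []   # (view_metas, new_values) updates not yet applied

    def queue_update(self, metas, new_values):
        self.pending.append((list(metas), list(new_values)))

    def materialize(self, metas):
        # Step 1: apply every queued update to the base tensor, in order.
        for upd_metas, values in self.pending:
            scatter_back(self.base, upd_metas, values)
        self.pending.clear()
        # Step 2: replay this tensor's ViewMeta chain on the updated base.
        return apply_chain(self.base, metas)


alias = Alias([0] * 6)
v1 = [("slice", 1, 4)]              # base[1:5]
v2 = v1 + [("slice", 2, 2)]         # base[1:5][2:4], i.e. base[3:5]
v3 = [("slice", 0, 1)]              # base[0:1]
alias.queue_update(v2, [7, 7])      # inplace updates are only queued
alias.queue_update(v3, [9])
print(alias.materialize(v1))        # a read materializes: [0, 0, 7, 7]
print(alias.materialize([]))        # the base itself: [9, 0, 0, 7, 7, 0]
```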

How about mutation removal?

Func2 doesn't handle mutation removal yet. In other words, Func2 is only part of the functionalization pass we want, and inplace ops are still expected after the Func2 kernel. But once Func2 has run, mutation removal becomes much easier: it can be done as a separate job by adding another key after Func2.

A simple solution, proposed by alband, is to keep just a special set_ (already implemented as a special kernel in the prototype above) and transform all inplace ops into their functional versions:

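```python
import torch

X = torch.zeros(3)
Y = X                  # Y and X refer to the same Tensor object
# Original program:    X.add_(2)
# Transformed program: out-of-place add, then set_ the result into X
X.set_(X + 2)
print(Y)               # Y is X, so it also sees the update: tensor([2., 2., 2.])
```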
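The transformation can also be sketched as a rewrite over a toy op list. The tuple IR, the op names and remove_mutations are hypothetical, and the prototype's special set_ kernel is not modeled: each inplace op target.op_(args) becomes an out-of-place op writing a fresh temporary, followed by target.set_(tmp), so set_ is the only mutation left.

```python
def remove_mutations(program):
    """Rewrite inplace ops (names ending in "_") into a functional op + set_."""
    out, counter = [], 0
    for op, target, *args in program:
        if op.endswith("_") and op != "set_":
            tmp = f"tmp{counter}"
            counter += 1
            out.append((op[:-1], tmp, target, *args))  # tmp = op(target, *args)
            out.append(("set_", target, tmp))           # target.set_(tmp)
        else:
            out.append((op, target, *args))             # already functional
    return out


program = [("add_", "X", 2), ("mul_", "X", 3)]
for instr in remove_mutations(program):
    print(instr)
# ('add', 'tmp0', 'X', 2)
# ('set_', 'X', 'tmp0')
# ('mul', 'tmp1', 'X', 3)
# ('set_', 'X', 'tmp1')
```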