import torch
from torch.nested._internal.nested_tensor import jagged_from_list

# Build two jagged NestedTensors backed by the same list of dense
# components (ragged in the first dim of each component).
a = torch.randn(2, 7, 256, requires_grad=True, dtype=torch.float32)
b = torch.randn(3, 7, 256, requires_grad=True, dtype=torch.float32)
c = torch.randn(4, 7, 256, requires_grad=True, dtype=torch.float32)
d = torch.randn(5, 7, 256, requires_grad=True, dtype=torch.float32)
nt1 = jagged_from_list([a, b, c, d], None)[0]
nt2 = jagged_from_list([a, b, c, d], None)[0]

# A view of nt1: select index 1 along dim 2.
nt1_view = nt1.select(2, 1)
@soulitzer
soulitzer / inference_mode_propagation.py
Created November 3, 2023 00:51
Edge case if we try to patch inference-ness in ADInplaceOrView
import torch

class T(torch.Tensor):
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype)

    def __init__(self, elem):
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Body not shown in the preview.
        ...
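A minimal illustration (not the gist's own repro) of the inference-ness propagation the description refers to: a view taken outside inference mode still inherits inference-ness from its base, which is exactly the bookkeeping ADInplaceOrView would have to get right.

import torch

with torch.inference_mode():
    base = torch.ones(3)  # base is an inference tensor

# Outside inference mode, a view of an inference tensor is still an
# inference tensor, so inference-ness flows through view creation.
view = base.view(3)
print(base.is_inference())  # True
print(view.is_inference())  # True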
# Technically, even in the "easy case" of t._base.requires_grad == t.requires_grad,
# I need to perform two views to recreate that view authentically. Why?
# There are actually two things I need to recreate: (1) the autograd
# graph relationship and (2) the view relationship.
# The reason we don't handle this today is that this autograd connectivity
# information is not accessible during tracing, and hence not relevant to
# compile, in part because dynamo doesn't support grad_fn access.
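A small sketch (illustrative, not from the gist) of the two relationships the comment distinguishes:

import torch

base = torch.randn(4, requires_grad=True)
t = base.view(2, 2)

# (1) The autograd graph relationship: t hangs off base's graph
#     through a ViewBackward node.
print(type(t.grad_fn).__name__)         # ViewBackward0
# (2) The view relationship: t aliases base's storage and records base
#     as its _base.
print(t._base is base)                  # True
print(t.data_ptr() == base.data_ptr())  # True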
import torch
from torch.nested._internal.nested_tensor import NestedTensor, jagged_from_list
from torch.profiler import profile, record_function, ProfilerActivity

device = "cuda:5"

for nb_unit in (10, 1, 2, 5, 20):
    lin = torch.nn.functional.linear

    def sin(x):
        # Body truncated in the preview; presumably the pointwise op itself.
        return torch.sin(x)
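A hedged sketch of how such a benchmark is typically wrapped with torch.profiler; the op and sizes below are stand-ins rather than the gist's, and a CUDA device is assumed.

import torch
from torch.profiler import profile, record_function, ProfilerActivity

x = torch.randn(128, 256, device="cuda")
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with record_function("unit"):
        for _ in range(10):
            x = torch.sin(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))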
@soulitzer
soulitzer / nested_tensor.py
Last active July 19, 2023 15:48
NestedTensor __torch_dispatch__ wrapper tensor subclass POC with automatic dispatching to custom kernel
import torch
from torch.library import Library
from typing import List
from torch.utils._pytree import tree_map
import torch.nn.functional as F
import functools
import numpy as np
# NestedTensor __torch_dispatch__ wrapper tensor subclass POC
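The "automatic dispatching to custom kernel" piece presumably goes through torch.library; a self-contained sketch of that mechanism, with a made-up namespace, op, and dense stand-in kernel:

import torch
from torch.library import Library

# Define a custom op in a throwaway namespace and register a CPU "kernel".
lib = Library("nt_poc", "DEF")
lib.define("scaled_add(Tensor a, Tensor b, float alpha) -> Tensor")

def scaled_add_cpu(a, b, alpha):
    # Stand-in kernel: a plain dense implementation.
    return a + alpha * b

lib.impl("scaled_add", scaled_add_cpu, "CPU")

out = torch.ops.nt_poc.scaled_add(torch.ones(3), torch.ones(3), 2.0)
print(out)  # tensor([3., 3., 3.])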
@soulitzer
soulitzer / nested_tensor.py
Created July 17, 2023 15:24
NestedTensor python torch dispatch wrapper subclass
import torch
from torch.library import Library
from typing import List
from torch.utils._pytree import tree_map
import functools
# NestedTensor __torch_dispatch__ wrapper tensor subclass POC
#
# 1) The __torch_dispatch__ handles pointwise ops entirely in python
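A minimal sketch of point (1), a simplified stand-in for the gist's class: __torch_dispatch__ handles pointwise ops by unwrapping, applying the aten op per component, and rewrapping.

import torch
from torch.utils._pytree import tree_map

class POCNestedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elems):
        # The outer shape is a placeholder; the real data lives in `elems`.
        return torch.Tensor._make_wrapper_subclass(
            cls, (len(elems),), dtype=elems[0].dtype, device=elems[0].device
        )

    def __init__(self, elems):
        self.elems = elems

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        nt = next(a for a in args if isinstance(a, POCNestedTensor))
        outs = []
        for i in range(len(nt.elems)):
            # Unwrap to the i-th component, run the op, collect the result.
            unwrap = lambda t: t.elems[i] if isinstance(t, POCNestedTensor) else t
            outs.append(func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
        return POCNestedTensor(outs)

nt = POCNestedTensor([torch.randn(2, 3), torch.randn(4, 3)])
print(torch.sin(nt).elems[0].shape)  # torch.Size([2, 3])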
@soulitzer
soulitzer / test.py
Created July 6, 2023 23:03
output_nr issue when requires_grad=False
import torch
from torch.library import Library

test_ns = "abc"
lib = Library(test_ns, "FRAGMENT")
lib.define("foo(Tensor(a!) a, Tensor(b!) b) -> (Tensor(a!), Tensor(b!))")

def get_op(name):
    return getattr(getattr(torch.ops, test_ns), name).default

op = get_op("foo")
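For context, output_nr records a tensor's position among its grad_fn's outputs. A small illustration of the attribute with a two-output autograd Function (not the gist's repro, which goes through the mutable custom op above):

import torch

class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, g0, g1):
        return g0 + g1

x = torch.randn(3, requires_grad=True)
a, b = TwoOutputs.apply(x)
print(a.output_nr, b.output_nr)  # 0 1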
@soulitzer
soulitzer / get_source_partition.py
Created June 27, 2023 15:33
get_source_partitions produces different results in export aten_graph=True
import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
import pprint

m = torch.nn.Linear(10, 10)

def fn(x):
    return m(x)
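A hedged sketch of the comparison the title describes, using torch._dynamo.export as it existed around this time (the exact signature may differ across versions):

import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

m = torch.nn.Linear(10, 10)

def fn(x):
    return m(x)

x = torch.randn(2, 10)
gm_torch, _ = torch._dynamo.export(fn, x)
gm_aten, _ = torch._dynamo.export(fn, x, aten_graph=True)

# The same query over both graphs is what surfaces the discrepancy.
print(get_source_partitions(gm_torch.graph, [torch.nn.Linear]))
print(get_source_partitions(gm_aten.graph, [torch.nn.Linear]))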
@soulitzer
soulitzer / gist:4da95730e40d814aa0c64cdff5c48571
Last active June 19, 2023 20:17
fx graph before inlining modules
(pytorch1) jw3468-mbp:pytorch1 jw3468$ python test/dynamo/test_modules.py -k TestTemp -v
test_dynamo_inline_module_nn_AdaptiveAvgPool1d_cpu_float32 (__main__.TestTempCPU) ...
opcode       name       target     args          kwargs
-----------  ---------  ---------  ------------  --------
placeholder  l_args_0_  L_args_0_  ()            {}
call_module  m          m          (l_args_0_,)  {}
output       output     output     ((m,),)       {}
stats [('calls_captured', 1), ('unique_graphs', 1)]
ok
@soulitzer
soulitzer / gist:c747d3e9cb1e241f6c7b9b57c0f84a9b
Created June 19, 2023 20:12
fx graph after inlining module call
test_dynamo_inline_module_nn_AdaptiveAvgPool1d_cpu_float32 (__main__.TestTempCPU) ...
opcode         name                 target                                                                args                       kwargs
-------------  -------------------  --------------------------------------------------------------------  -------------------------  --------
placeholder    l_args_0_            L_args_0_                                                             ()                         {}
call_function  adaptive_avg_pool1d  <built-in method adaptive_avg_pool1d of type object at 0x103ccc7e8>   (l_args_0_, 3)             {}
output         output               output                                                                ((adaptive_avg_pool1d,),)  {}
inline_call []
stats [('calls_captured', 1), ('unique_graphs', 1)]
ok
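For reference, tables in this shape are what torch.fx's tabular printer emits; any GraphModule can be dumped the same way (a generic example, not the test above):

import torch

def fn(x):
    return torch.nn.functional.adaptive_avg_pool1d(x, 3)

gm = torch.fx.symbolic_trace(fn)
gm.graph.print_tabular()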