This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import functools | |
from torch.utils._python_dispatch import TorchDispatchMode | |
import torch.utils._pytree as pytree | |
from torch.utils.weak import WeakTensorKeyDictionary | |
class RecomputableTensor(torch.Tensor): | |
@staticmethod | |
def __new__(cls, t, func, args): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.nested._internal.nested_tensor import jagged_from_list | |
a = torch.randn(2, 7, 256, requires_grad=True, dtype=torch.float32) | |
b = torch.randn(3, 7, 256, requires_grad=True, dtype=torch.float32) | |
c = torch.randn(4, 7, 256, requires_grad=True, dtype=torch.float32) | |
d = torch.randn(5, 7, 256, requires_grad=True, dtype=torch.float32) | |
nt1 = jagged_from_list([a, b, c, d], None)[0] | |
nt2 = jagged_from_list([a, b, c, d], None)[0] | |
nt1_view = nt1.select(2, 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class T(torch.Tensor): | |
def __new__(cls, elem): | |
return torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype) | |
def __init__(self, elem): | |
self.elem = elem | |
@classmethod |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Technically even in the "easy case" of t._base.requires_grad == t.requires_grad | |
# I need to perform two views to recreate that view authentically. why? | |
# There are actually two things I need to recreate, (1) the autograd | |
# graph relationship and (2) the view relationship. | |
# The reason we don't handle this today is because this autograd connectivity information | |
# is not accessible during tracing and hence not relevant to compile in part because dynam | |
# doesn't support grad_fn access. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.nested._internal.nested_tensor import NestedTensor, jagged_from_list | |
from torch.profiler import profile, record_function, ProfilerActivity | |
device="cuda:5" | |
for nb_unit in (10, 1, 2, 5, 20): | |
lin = torch.nn.functional.linear | |
def sin(x): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.library import Library | |
from typing import List | |
from torch.utils._pytree import tree_map | |
import torch.nn.functional as F | |
import functools | |
import numpy as np | |
# NestedTensor __torch_dispatch__ wrapper tensor subclass POC |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.library import Library | |
from typing import List | |
from torch.utils._pytree import tree_map | |
import functools | |
# NestedTensor __torch_dispatch__ wrapper tensor subclass POC | |
# | |
# 1) The __torch_dispatch__ handles pointwise ops entirely in python |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.library import Library | |
test_ns = "abc" | |
lib = Library(test_ns, "FRAGMENT") | |
lib.define("foo(Tensor(a!) a, Tensor(b!) b) -> (Tensor(a!), Tensor(b!))") | |
def get_op(name): | |
return getattr(getattr(torch.ops, test_ns), name).default | |
op = get_op("foo") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions | |
import pprint | |
m = torch.nn.Linear(10, 10) | |
def fn(x): | |
return m(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(pytorch1) jw3468-mbp:pytorch1 jw3468$ python test/dynamo/test_modules.py -k TestTemp -v | |
test_dynamo_inline_module_nn_AdaptiveAvgPool1d_cpu_float32 (__main__.TestTempCPU) ... | |
opcode name target args kwargs | |
----------- --------- --------- ------------ -------- | |
placeholder l_args_0_ L_args_0_ () {} | |
call_module m m (l_args_0_,) {} | |
output output output ((m,),) {} | |
stats [('calls_captured', 1), ('unique_graphs', 1)] | |
ok |
NewerOlder