This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.nested._internal.nested_tensor import jagged_from_list | |
a = torch.randn(2, 7, 256, requires_grad=True, dtype=torch.float32) | |
b = torch.randn(3, 7, 256, requires_grad=True, dtype=torch.float32) | |
c = torch.randn(4, 7, 256, requires_grad=True, dtype=torch.float32) | |
d = torch.randn(5, 7, 256, requires_grad=True, dtype=torch.float32) | |
nt1 = jagged_from_list([a, b, c, d], None)[0] | |
nt2 = jagged_from_list([a, b, c, d], None)[0] | |
nt1_view = nt1.select(2, 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class T(torch.Tensor): | |
def __new__(cls, elem): | |
return torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype) | |
def __init__(self, elem): | |
self.elem = elem | |
@classmethod |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Technically, even in the "easy case" of t._base.requires_grad == t.requires_grad,
# I need to perform two views to recreate that view authentically. Why? Because
# there are actually two things I need to recreate: (1) the autograd
# graph relationship and (2) the view relationship.
# The reason we don't handle this today is that this autograd connectivity information
# is not accessible during tracing and hence not relevant to compile, in part because Dynamo
# doesn't support grad_fn access.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.nested._internal.nested_tensor import NestedTensor, jagged_from_list | |
from torch.profiler import profile, record_function, ProfilerActivity | |
device="cuda:5" | |
for nb_unit in (10, 1, 2, 5, 20): | |
lin = torch.nn.functional.linear | |
def sin(x): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.library import Library | |
from typing import List | |
from torch.utils._pytree import tree_map | |
import torch.nn.functional as F | |
import functools | |
import numpy as np | |
# NestedTensor __torch_dispatch__ wrapper tensor subclass POC |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.library import Library | |
from typing import List | |
from torch.utils._pytree import tree_map | |
import functools | |
# NestedTensor __torch_dispatch__ wrapper tensor subclass POC | |
# | |
# 1) The __torch_dispatch__ handles pointwise ops entirely in python |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.library import Library | |
test_ns = "abc" | |
lib = Library(test_ns, "FRAGMENT") | |
lib.define("foo(Tensor(a!) a, Tensor(b!) b) -> (Tensor(a!), Tensor(b!))") | |
def get_op(name): | |
return getattr(getattr(torch.ops, test_ns), name).default | |
op = get_op("foo") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions | |
import pprint | |
m = torch.nn.Linear(10, 10) | |
def fn(x): | |
return m(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(pytorch1) jw3468-mbp:pytorch1 jw3468$ python test/dynamo/test_modules.py -k TestTemp -v | |
test_dynamo_inline_module_nn_AdaptiveAvgPool1d_cpu_float32 (__main__.TestTempCPU) ... | |
opcode name target args kwargs | |
----------- --------- --------- ------------ -------- | |
placeholder l_args_0_ L_args_0_ () {} | |
call_module m m (l_args_0_,) {} | |
output output output ((m,),) {} | |
stats [('calls_captured', 1), ('unique_graphs', 1)] | |
ok |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
test_dynamo_inline_module_nn_AdaptiveAvgPool1d_cpu_float32 (__main__.TestTempCPU) ... | |
opcode name target args kwargs | |
------------- ------------------- ------------------------------------------------------------------- ------------------------- -------- | |
placeholder l_args_0_ L_args_0_ () {} | |
call_function adaptive_avg_pool1d <built-in method adaptive_avg_pool1d of type object at 0x103ccc7e8> (l_args_0_, 3) {} | |
output output output ((adaptive_avg_pool1d,),) {} | |
inline_call [] | |
stats [('calls_captured', 1), ('unique_graphs', 1)] | |
ok |
NewerOlder