Created September 23, 2021 00:42
-
-
Save anj-s/cc3d65e168e51f2affec813a909de5c0 to your computer and use it in GitHub Desktop.
SsdParameter - SsdTensorHandle is a property
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SsdParameter(torch.nn.Parameter):
    """``nn.Parameter`` subclass whose real storage may live in a ``wrapped`` attribute.

    ``None`` or a plain ``torch.Tensor`` is turned directly into an
    ``SsdParameter``. Any other tensor type produces a meta-tensor wrapper
    whose ``wrapped`` attribute holds an ``SsdTensorHandle``.
    ``__torch_dispatch__`` replaces every argument that has a ``wrapped``
    attribute with that attribute before running the op.
    """

    @staticmethod
    def __new__(
        cls: "type[SsdParameter]",
        data: torch.Tensor,
        shape: Tuple[int, ...],
        dtype: torch.dtype,
        requires_grad: bool = False,
    ) -> "SsdParameter":
        """Create an ``SsdParameter`` from ``data``.

        Args:
            data: Source tensor, or ``None`` for an empty parameter.
            shape: Currently unused by this constructor.
            dtype: Currently unused by this constructor.
            requires_grad: ``requires_grad`` flag of the returned parameter.

        Returns:
            A new ``SsdParameter`` instance.
        """
        if data is None:
            data = torch.tensor([])
            return torch.Tensor._make_subclass(cls, data, requires_grad)
        # Exact class check: tensor subclasses must take the wrapping path
        # below. Comparing the class object (not its __name__) avoids matching
        # unrelated classes that merely happen to be called "Tensor".
        if type(data) is torch.Tensor:
            return torch.Tensor._make_subclass(cls, data, requires_grad)
        # The wrapping tensor is a meta tensor, so it doesn't hold any memory
        # (meta tensors are generally the preferred base for a subclass)...
        p = torch.Tensor._make_subclass(cls, data.to("meta"), requires_grad)
        # ...and the real tensor is held as an attribute. It is built as an
        # SsdTensorHandle (not ``cls``); otherwise the inner tensor would be
        # another SsdParameter without a ``wrapped`` attribute.
        # NOTE(review): torch.empty allocates uninitialized memory, so the
        # values of ``data`` are not copied into the handle — confirm intended.
        p.wrapped = torch.Tensor._make_subclass(
            SsdTensorHandle,
            torch.empty(data.shape, dtype=data.dtype),
            data.requires_grad,
        )
        return p

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Call ``func`` with each wrapper argument replaced by its ``wrapped`` value.

        Args:
            func: The operator being dispatched.
            types: Subclass types involved in the call (unused here).
            args: Positional arguments; nested containers are traversed.
            kwargs: Keyword arguments, or ``None`` for none.

        Returns:
            Whatever ``func`` returns for the unwrapped arguments.
        """

        def unwrap(e):
            # Wrappers are swapped for the tensor they hold; every other leaf
            # (plain tensors, scalars, ...) must pass through unchanged.
            return e.wrapped if hasattr(e, "wrapped") else e

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.