Skip to content

Instantly share code, notes, and snippets.

@anj-s
Created September 23, 2021 00:42
Show Gist options
  • Save anj-s/cc3d65e168e51f2affec813a909de5c0 to your computer and use it in GitHub Desktop.
SsdParameter - SsdTensorHandle is a property
class SsdParameter(torch.nn.Parameter):
    """A ``torch.nn.Parameter`` that can carry another tensor object as an attribute.

    If ``data`` is a plain ``torch.Tensor`` (or ``None``), the parameter is a
    direct subclass instance sharing that tensor's data. Any other tensor type
    (e.g. an SsdTensorHandle) yields an outer parameter that is a ``meta``
    tensor with no storage. The original ``data`` object is kept on the
    ``wrapped`` attribute, and ``__torch_dispatch__`` forwards ops to it.
    """

    @staticmethod
    def __new__(
        cls,
        data: "Optional[torch.Tensor]",
        shape: "Tuple[int, ...]",
        dtype: "torch.dtype",
        requires_grad: bool = False,
    ) -> "SsdParameter":
        """Create an ``SsdParameter`` around ``data``.

        Args:
            data: Tensor to wrap. ``None`` is replaced by an empty tensor.
            shape: Accepted for interface compatibility; not used here.
            dtype: Accepted for interface compatibility; not used here.
            requires_grad: ``requires_grad`` flag of the returned parameter.

        Returns:
            The new ``SsdParameter``. For non-plain ``data``, it is a ``meta``
            tensor whose ``wrapped`` attribute is ``data`` itself.
        """
        if data is None:
            data = torch.tensor([])
        if type(data) is torch.Tensor:
            # NOTE(review): this instance has no ``wrapped`` attribute, so
            # ``__torch_dispatch__`` passes it through unchanged. Re-invoking
            # ``func`` on it may re-enter dispatch; confirm how plain-backed
            # instances are meant to run ops.
            return torch.Tensor._make_subclass(cls, data, requires_grad)
        # The outer parameter is only a meta tensor, so it holds no memory
        # (meta tensors are the usual base for wrapper subclasses)...
        p = torch.Tensor._make_subclass(cls, data.to("meta"), requires_grad)
        # ...while the real tensor object is held as an attribute. It is
        # stored as-is: copying into fresh (uninitialized) memory would lose
        # its values.
        p.wrapped = data
        return p

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` with every wrapper argument replaced by its ``wrapped`` tensor.

        Arguments without a ``wrapped`` attribute (plain tensors, Python
        scalars, ...) are passed through unchanged. The result is returned
        as produced by ``func`` and is not re-wrapped.
        """
        from torch.utils._pytree import tree_map

        def unwrap(e):
            # Non-wrapper leaves must pass through untouched; reading
            # ``e.wrapped`` on them would raise AttributeError.
            return e.wrapped if hasattr(e, "wrapped") else e

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment