Skip to content

Instantly share code, notes, and snippets.

@take-cheeze
Created December 14, 2022 01:06
Show Gist options
  • Save take-cheeze/8719ad2beed8fa65fffea322f7d1a1b6 to your computer and use it in GitHub Desktop.
Save take-cheeze/8719ad2beed8fa65fffea322f7d1a1b6 to your computer and use it in GitHub Desktop.
import torch
import torch.onnx.symbolic_helper as sym_hel
from typing import Any
def _maybe_get_scalar(value: torch._C.Value) -> Any:
    """Return ``value`` as a 0-dim tensor when it is a constant scalar.

    Replacement for ``torch.onnx.symbolic_helper._maybe_get_scalar``. Besides
    the 0-dim tensor constants resolved by ``sym_hel._maybe_get_const``, it
    also recognises ``prim::Constant`` nodes whose ``value`` attribute is a
    plain Python ``int`` or ``float``.

    Args:
        value: The argument a symbolic function received; usually a graph
            ``torch._C.Value``, but any other object is passed through.

    Returns:
        - the 0-dim tensor produced by ``sym_hel._maybe_get_const(value, "t")``
          when it yields one;
        - ``torch.scalar_tensor(v, dtype=torch.int64)`` for a ``prim::Constant``
          whose ``value`` attribute is an ``int`` ``v``;
        - ``torch.scalar_tensor(v, dtype=torch.double)`` for a
          ``prim::Constant`` whose ``value`` attribute is a ``float`` ``v``;
        - ``value`` itself in every other case.
    """
    value_t = sym_hel._maybe_get_const(value, "t")
    if isinstance(value_t, torch.Tensor) and value_t.shape == ():
        return value_t
    # Attributes are stored on the producing Node, not on the Value; inputs
    # that are not graph Values (e.g. plain Python numbers) have no node.
    if not isinstance(value, torch._C.Value):
        return value
    node = value.node()
    # Only fold constants: other ops may also carry an attribute named
    # "value" whose content says nothing about the op's output, and reading
    # a missing attribute would raise.
    if node.kind() != "prim::Constant" or not node.hasAttribute("value"):
        return value
    node_v = sym_hel._node_get(node, "value")
    if isinstance(node_v, int):
        return torch.scalar_tensor(node_v, dtype=torch.int64)
    if isinstance(node_v, float):
        return torch.scalar_tensor(node_v, dtype=torch.double)
    return value


# Install the replacement on the helper module. This rebinds the module
# attribute only: callers that look it up as ``sym_hel._maybe_get_scalar`` at
# call time see the new function, while a name bound earlier through
# ``from ... import _maybe_get_scalar`` keeps pointing at the original.
sym_hel._maybe_get_scalar = _maybe_get_scalar
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment