#!/usr/bin/env python3
"""Conversions between scalars and scalar tensors."""

import numpy as np

import tvm
import tvm.testing
from tvm import relax
from tvm.script import tir as T, relax as R


def _scale_by_0d_tensor():
    """Works"""

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), B: R.Tensor([], "float32")):
        C = A * B
        return C

    return func


def _scale_by_prim_value():
    """Doesn't work

    Fails during struct inference.  The multiplication uses
    `relax.op.multiply`, which calls `InferStructInfoBroadcast` for
    shape inference, and `InferStructInfoBroadcast` requires both
    arguments to be tensors.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), B: R.Prim("float32")):
        # Fails here during struct inference.
        C = A * B
        return C

    return func


def _scale_by_prim_value_wrapped_in_const():
    """Doesn't work

    Fails during parsing.  The argument of `R.const` must be a Python
    scalar or an NDArray, because it is stored as an NDArray in the
    IRModule.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), B: R.Prim("float32")):
        C = A * R.const(B, dtype="float32")
        return C

    return func


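# For contrast, a compile-time Python scalar does satisfy `R.const`.  This is
# a reference sketch, not part of the original gist: the scale factor `2.0` is
# an arbitrary example value, and the function is deliberately left out of the
# `generate_scaling_func` list below, since the factor is baked into the
# IRModule rather than passed at runtime.
def _scale_by_python_scalar_const():
    """Works

    `R.const` accepts a Python scalar, which is stored as a 0-d
    NDArray in the IRModule.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32")):
        C = A * R.const(2.0, dtype="float32")
        return C

    return func

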
def _scale_int64_match_cast_then_shape_to_tensor():
    """Works

    The `factor` is defined, can be converted to a tensor, and then
    used as a scaling factor.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "int64"), B: R.Prim("int64")):
        factor = T.int64()
        _ = R.match_cast(B, R.Prim(value=factor))
        dummy_shape = R.shape([factor])
        B_as_tensor = R.shape_to_tensor(dummy_shape)
        C = A * B_as_tensor
        return C

    return func


def _scale_int64_arg_value_then_shape_to_tensor():
    """Works

    The `factor` is defined, can be converted to a tensor, and then
    used as a scaling factor.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "int64"), B: R.Prim(value="factor")):
        factor = T.int64()
        dummy_shape = R.shape([factor])
        B_as_tensor = R.shape_to_tensor(dummy_shape)
        C = A * B_as_tensor
        return C

    return func


def _scale_float32_arg_value_then_shape_to_tensor():
    """Doesn't work

    The `factor` is assumed to have dtype int64, which causes an
    error when it is later defined as float32.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), B: R.Prim(value="factor")):
        # Fails here during parsing.
        factor = T.float32()
        dummy_shape = R.shape([factor])
        B_as_tensor = R.shape_to_tensor(dummy_shape)
        C = A * B_as_tensor
        return C

    return func


def _scale_float32_external_arg_value_then_shape_to_tensor():
    """Doesn't work

    To use `R.shape_to_tensor`, the PrimExpr must first be bundled
    into an `R.shape`.  This fails because `R.shape` requires each
    item to have dtype int64, which restricts the use of
    `R.shape_to_tensor` as a workaround.
    """
    factor = tvm.tir.Var("factor", "float32")

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), B: R.Prim(value=factor)):
        dummy_shape = R.shape([factor])
        B_as_tensor = R.shape_to_tensor(dummy_shape)
        C = A * B_as_tensor
        return C

    return func


generate_scaling_func = tvm.testing.parameter(
    _scale_by_0d_tensor,
    _scale_by_prim_value,
    _scale_by_prim_value_wrapped_in_const,
    _scale_int64_match_cast_then_shape_to_tensor,
    _scale_int64_arg_value_then_shape_to_tensor,
    _scale_float32_arg_value_then_shape_to_tensor,
    _scale_float32_external_arg_value_then_shape_to_tensor,
)


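# `tvm.testing.parameter` defines a pytest fixture parameterized over the
# functions listed above, so each test that accepts `generate_scaling_func`
# runs once per variant.

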
def test_scale_by_constant(generate_scaling_func):
    func = generate_scaling_func()
    dtype = func.params[0].struct_info.dtype

    A_np = np.random.uniform(low=0, high=255, size=[16, 16]).astype(dtype)
    B_np = np.random.uniform(low=0, high=255, size=[]).astype(dtype)

    mod = tvm.IRModule.from_expr(func)
    built_mod = relax.build(mod, target="llvm")
    vm = relax.vm.VirtualMachine(built_mod, tvm.cpu(0))
    built_func = vm["func"]

    A_tvm = tvm.nd.array(A_np)
    if isinstance(func.params[1].struct_info, relax.TensorStructInfo):
        B_tvm = tvm.nd.array(B_np)
    elif isinstance(func.params[1].struct_info, relax.PrimStructInfo):
        # Unwrap the 0-d array into a Python scalar.
        B_tvm = B_np = B_np[()]

    C_np = A_np * B_np
    C_relax = built_func(A_tvm, B_tvm).numpy()

    tvm.testing.assert_allclose(C_np, C_relax)


def _tril_by_prim_value():
    """Doesn't work

    Fails when converting `R.tril` to a TE expression, because there
    is no symbolic variable for `offset`.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), offset: R.Prim(dtype="int64")):
        # Fails here when converting `R.tril` to a TE expression.
        tril = R.tril(A, offset)
        return tril

    return func


def _tril_by_prim_value_with_symbolic_var():
    """Doesn't work (yet)

    Bugfix needed in
    `tvm.relax.utils._ArgsConverter._convert_te_arg_helper`: when
    encountering a `PrimStructInfo` with a known binding, it should
    return `_convert_te_arg_helper(arg.struct_info.value)` instead of
    using `arg.value`.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), relax_offset: R.Prim(value="tir_offset")):
        tril = R.tril(A, relax_offset)
        return tril

    return func


def _tril_by_prim_expr():
    """Doesn't work (yet)

    Bugfix needed in
    `tvm.relax.utils._ArgsConverter._convert_te_arg_helper`: when
    encountering a `PrimStructInfo` with a known binding, it should
    return `_convert_te_arg_helper(arg.struct_info.value)` instead of
    using `arg.value`.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), relax_offset: R.Prim(value="tir_offset")):
        tir_offset = T.int64()
        tril = R.tril(A, tir_offset)
        return tril

    return func


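# The bugfix referenced in the two docstrings above would add a check along
# these lines to `_convert_te_arg_helper` in `python/tvm/relax/utils.py`
# (quoted from the `_tril_by_*_explicit_tensor_to_shape` docstrings below):
#
#     if (
#         isinstance(arg.struct_info, PrimStructInfo)
#         and arg.struct_info.value is not None
#     ):
#         return _convert_te_arg_helper(arg.struct_info.value)

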
def _tril_by_0d_tensor():
    """Doesn't work

    The second argument of `R.tril` must be a `R.Prim` with `int64`
    dtype.  The 0-d tensor has `TensorStructInfo` instead.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), offset: R.Tensor([], "int64")):
        # Fails here in shape inference.
        tril = R.tril(A, offset)
        return tril

    return func


def _tril_by_0d_tensor_with_tensor_to_shape():
    """Doesn't work"""

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), offset: R.Tensor([], "int64")):
        # Fails here, need a 1-d tensor.
        offset_shape = R.tensor_to_shape(offset)
        tril = R.tril(A, R.prim_value(offset_shape))
        return tril

    return func


def _tril_by_1d_tensor_with_tensor_to_shape_and_prim_value():
    """Doesn't work

    The `R.prim_value` argument needs a `PrimExpr`, not a relax
    `Expr`.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), offset: R.Tensor([1], "int64")):
        offset_shape = R.tensor_to_shape(offset)
        # Fails here, need a PrimExpr.
        tril = R.tril(A, R.prim_value(offset_shape))
        return tril

    return func


def _tril_by_1d_tensor_with_tensor_to_shape():
    """Doesn't work"""

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), offset: R.Tensor([1], "int64")):
        offset_shape = R.tensor_to_shape(offset)
        # Fails here, need a R.Prim or PrimExpr, not a R.Shape.
        tril = R.tril(A, offset_shape)
        return tril

    return func


def _tril_by_1d_tensor_with_tensor_to_shape_and_index():
    """Doesn't work

    Indexing `offset_shape[0]` produces a `TupleGetItem` node, whereas
    it should access an element of the shape tuple.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), offset: R.Tensor([1], "int64")):
        offset_shape = R.tensor_to_shape(offset)
        # Fails here in shape inference, `offset_shape` is a
        # `R.Shape`, not a `R.Tuple`.
        relax_offset = R.prim_value(offset_shape[0])
        tril = R.tril(A, relax_offset)
        return tril

    return func


def _tril_by_1d_tensor_with_tensor_to_shape_and_match_cast():
    """Doesn't work

    Fails during CodegenVM, which cannot handle `R.tensor_to_shape`.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), offset: R.Tensor([1], "int64")):
        offset_shape = R.tensor_to_shape(offset)
        tir_offset = T.int64()
        _ = R.match_cast(offset_shape, R.Shape([tir_offset]))
        tril = R.tril(A, tir_offset)
        return tril

    return func


def _tril_by_1d_tensor_with_tensor_to_shape_and_match_cast_with_decompose():
    """Doesn't work

    Fails during CodegenVM, because the `offset_shape` variable is not
    defined.  Bugfix needed in `VMShapeLower`: its
    `VisitBinding_(MatchCastNode*)` re-emits the binding as-is, without
    using the mutated value.  As a result, the old value is still
    present, including any references to replaced variables.
    """
    func = _tril_by_1d_tensor_with_tensor_to_shape_and_match_cast()
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.relax.transform.DecomposeOpsForInference()(mod)
    func = mod["func"]

    # func.show()
    #
    # @R.function
    # def func(
    #     A: R.Tensor((16, 16), dtype="float32"), offset: R.Tensor((1,), dtype="int64")
    # ) -> R.Tensor((16, 16), dtype="float32"):
    #     x = T.int64()
    #     tir_offset = T.int64()
    #     gv: R.Shape(ndim=1) = R.call_pure_packed(
    #         "vm.builtin.tensor_to_shape", offset, sinfo_args=(R.Shape(ndim=1),)
    #     )
    #     y: R.Shape([x]) = R.match_cast(gv, R.Shape([x]))
    #     offset_shape: R.Shape([x]) = R.shape([x])
    #     _: R.Shape([tir_offset]) = R.match_cast(offset_shape, R.Shape([tir_offset]))
    #     tril: R.Tensor((16, 16), dtype="float32") = R.tril(A, R.prim_value(tir_offset))
    #     return tril

    return func


def _tril_by_1d_tensor_with_explicit_tensor_to_shape():
    """Works

    Explicitly write the output of `DecomposeOpsForInference`, to
    avoid a round-trip of `R.match_cast` to `R.shape` to
    `R.match_cast`, as that seems to cause an undefined variable
    in `VMShapeLower`.

    Requires a bugfix in `python/tvm/relax/utils.py`:
    `_convert_te_arg_helper` should check `PrimStructInfo` for TIR
    variables that need to be passed in.

        if (
            isinstance(arg.struct_info, PrimStructInfo)
            and arg.struct_info.value is not None
        ):
            return _convert_te_arg_helper(arg.struct_info.value)
    """

    @R.function
    def func(
        A: R.Tensor((16, 16), dtype="float32"), offset: R.Tensor((1,), dtype="int64")
    ) -> R.Tensor((16, 16), dtype="float32"):
        tir_offset = T.int64()
        offset_shape = R.call_pure_packed(
            "vm.builtin.tensor_to_shape", offset, sinfo_args=(R.Shape(ndim=1),)
        )
        _ = R.match_cast(offset_shape, R.Shape([tir_offset]))
        tril = R.tril(A, R.prim_value(tir_offset))
        return tril

    return func


def _tril_by_0d_tensor_with_explicit_tensor_to_shape():
    """Almost works

    Explicitly write the output of `DecomposeOpsForInference`, to
    avoid a round-trip of `R.match_cast` to `R.shape` to
    `R.match_cast`, as that seems to cause an undefined variable
    in `VMShapeLower`.

    Requires a bugfix in `python/tvm/relax/utils.py`:
    `_convert_te_arg_helper` should check `PrimStructInfo` for TIR
    variables that need to be passed in.

        if (
            isinstance(arg.struct_info, PrimStructInfo)
            and arg.struct_info.value is not None
        ):
            return _convert_te_arg_helper(arg.struct_info.value)
    """

    @R.function
    def func(
        A: R.Tensor((16, 16), dtype="float32"), offset: R.Tensor([], dtype="int64")
    ) -> R.Tensor((16, 16), dtype="float32"):
        tir_offset = T.int64()
        offset_1d: R.Tensor([1], dtype="int64") = R.full([1], offset)
        offset_shape = R.call_pure_packed(
            "vm.builtin.tensor_to_shape", offset_1d, sinfo_args=(R.Shape(ndim=1),)
        )
        _ = R.match_cast(offset_shape, R.Shape([tir_offset]))
        tril = R.tril(A, R.prim_value(tir_offset))
        return tril

    return func


generate_tril_func = tvm.testing.parameter(
    _tril_by_prim_value,
    _tril_by_prim_value_with_symbolic_var,
    _tril_by_prim_expr,
    _tril_by_0d_tensor,
    _tril_by_0d_tensor_with_tensor_to_shape,
    _tril_by_1d_tensor_with_tensor_to_shape_and_prim_value,
    _tril_by_1d_tensor_with_tensor_to_shape,
    _tril_by_1d_tensor_with_tensor_to_shape_and_index,
    _tril_by_1d_tensor_with_tensor_to_shape_and_match_cast,
    _tril_by_1d_tensor_with_tensor_to_shape_and_match_cast_with_decompose,
    _tril_by_1d_tensor_with_explicit_tensor_to_shape,
    _tril_by_0d_tensor_with_explicit_tensor_to_shape,
)


def test_tril(generate_tril_func):
    func = generate_tril_func()
    dtype = func.params[0].struct_info.dtype

    A_np = np.random.uniform(low=0, high=255, size=[16, 16]).astype(dtype)
    B_np = np.random.uniform(low=2, high=15, size=[]).astype("int64")[()]

    mod = tvm.IRModule.from_expr(func)
    built_mod = relax.build(mod, target="llvm")
    vm = relax.vm.VirtualMachine(built_mod, tvm.cpu(0))
    built_func = vm["func"]

    A_tvm = tvm.nd.array(A_np)

    # Package the offset to match the parameter's struct info.
    B_sinfo = func.params[1].struct_info
    if isinstance(B_sinfo, relax.TensorStructInfo) and B_sinfo.ndim == 0:
        B_tvm = tvm.nd.array(np.array(B_np))
    elif isinstance(B_sinfo, relax.TensorStructInfo) and B_sinfo.ndim == 1:
        B_tvm = tvm.nd.array(np.array([B_np]))
    elif isinstance(B_sinfo, relax.PrimStructInfo):
        B_tvm = B_np

    C_np = np.tril(A_np, B_np)
    C_relax = built_func(A_tvm, B_tvm).numpy()

    tvm.testing.assert_allclose(C_np, C_relax)


if __name__ == "__main__":
    tvm.testing.main()