Conversions between scalars and scalar tensors
#!/usr/bin/env python3

import numpy as np

import tvm
import tvm.testing
from tvm import relax
from tvm.script import tir as T, relax as R


def _scale_by_0d_tensor():
    """Works"""

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), B: R.Tensor([], "float32")):
        C = A * B
        return C

    return func
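
# Note: the 0-d tensor argument above is constructed at runtime by
# wrapping a 0-d NumPy array with `tvm.nd.array` (see
# `test_scale_by_constant` below).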

def _scale_by_prim_value():
    """Doesn't work

    Fails during struct inference. The multiplication uses
    `relax.op.multiply`, which calls `InferStructInfoBroadcast` for
    shape inference. `InferStructInfoBroadcast` requires both
    arguments to be tensors.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), B: R.Prim("float32")):
        # Fails here during InferStructInfo.
        C = A * B
        return C

    return func

def _scale_by_prim_value_wrapped_in_const():
    """Doesn't work

    Fails during parsing. The argument of `R.const` must be a Python
    scalar or an NDArray, because it is stored as an NDArray in the
    IRModule.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), B: R.Prim("float32")):
        C = A * R.const(B, dtype="float32")
        return C

    return func

def _scale_int64_match_cast_then_shape_to_tensor():
    """Works

    The `factor` is defined by `R.match_cast`, can be converted to a
    tensor, and then used as a scaling factor.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "int64"), B: R.Prim("int64")):
        factor = T.int64()
        _ = R.match_cast(B, R.Prim(value=factor))
        dummy_shape = R.shape([factor])
        B_as_tensor = R.shape_to_tensor(dummy_shape)
        C = A * B_as_tensor
        return C

    return func

def _scale_int64_arg_value_then_shape_to_tensor():
    """Works

    The `factor` is defined by the `R.Prim(value="factor")` annotation,
    can be converted to a tensor, and then used as a scaling factor.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "int64"), B: R.Prim(value="factor")):
        factor = T.int64()
        dummy_shape = R.shape([factor])
        B_as_tensor = R.shape_to_tensor(dummy_shape)
        C = A * B_as_tensor
        return C

    return func

def _scale_float32_arg_value_then_shape_to_tensor():
    """Doesn't work

    The `factor` is assumed to have a dtype of int64, which causes an
    error when it is later defined as float32.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), B: R.Prim(value="factor")):
        # Fails here during parsing.
        factor = T.float32()
        dummy_shape = R.shape([factor])
        B_as_tensor = R.shape_to_tensor(dummy_shape)
        C = A * B_as_tensor
        return C

    return func

def _scale_float32_external_arg_value_then_shape_to_tensor():
    """Doesn't work

    In order to use `R.shape_to_tensor`, the PrimExpr must first be
    bundled into an `R.shape`. This fails, because `R.shape` requires
    each item to have a dtype of int64.

    This restricts the use of `R.shape_to_tensor` as a workaround.
    """
    factor = tvm.tir.Var("factor", "float32")

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), B: R.Prim(value=factor)):
        dummy_shape = R.shape([factor])
        B_as_tensor = R.shape_to_tensor(dummy_shape)
        C = A * B_as_tensor
        return C

    return func
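
# A minimal illustration (hypothetical, not part of the test suite,
# and assuming the int64 check is enforced at construction time) of
# the restriction described above:
#
#     f32_var = tvm.tir.Var("x", "float32")
#     relax.ShapeExpr([f32_var])  # raises: shape values must be int64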

generate_scaling_func = tvm.testing.parameter(
    _scale_by_0d_tensor,
    _scale_by_prim_value,
    _scale_by_prim_value_wrapped_in_const,
    _scale_int64_match_cast_then_shape_to_tensor,
    _scale_int64_arg_value_then_shape_to_tensor,
    _scale_float32_arg_value_then_shape_to_tensor,
    _scale_float32_external_arg_value_then_shape_to_tensor,
)
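
# `tvm.testing.parameter` defines a pytest fixture, so the test below
# runs once for each generator function listed above.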

def test_scale_by_constant(generate_scaling_func):
    func = generate_scaling_func()
    dtype = func.params[0].struct_info.dtype

    A_np = np.random.uniform(low=0, high=255, size=[16, 16]).astype(dtype)
    B_np = np.random.uniform(low=0, high=255, size=[]).astype(dtype)

    mod = tvm.IRModule.from_expr(func)
    built_mod = relax.build(mod, target="llvm")
    vm = relax.vm.VirtualMachine(built_mod, tvm.cpu(0))
    built_func = vm["func"]

    A_tvm = tvm.nd.array(A_np)
    if isinstance(func.params[1].struct_info, relax.TensorStructInfo):
        B_tvm = tvm.nd.array(B_np)
    elif isinstance(func.params[1].struct_info, relax.PrimStructInfo):
        # Unwrap the 0-d array into a NumPy scalar, for both the VM
        # call and the reference computation.
        B_tvm = B_np = B_np[()]

    C_np = A_np * B_np
    C_relax = built_func(A_tvm, B_tvm).numpy()

    tvm.testing.assert_allclose(C_np, C_relax)

def _tril_by_prim_value():
    """Doesn't work

    Fails when converting `R.tril` to a TE expression, because there
    is no symbolic variable for `offset`.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), offset: R.Prim(dtype="int64")):
        # Fails here when converting `R.tril` to a TE expression.
        tril = R.tril(A, offset)
        return tril

    return func

def _tril_by_prim_value_with_symbolic_var():
    """Doesn't work (yet)

    Bugfix needed in
    `tvm.relax.utils._ArgsConverter._convert_te_arg_helper`: when
    encountering a `PrimStructInfo` with a known binding, it should
    return `_convert_te_arg_helper(arg.struct_info.value)` instead of
    `arg.value`.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), relax_offset: R.Prim(value="tir_offset")):
        tril = R.tril(A, relax_offset)
        return tril

    return func
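
# A sketch of the proposed `_convert_te_arg_helper` fix (quoted again
# in the docstrings of the working variants further below):
#
#     if (
#         isinstance(arg.struct_info, PrimStructInfo)
#         and arg.struct_info.value is not None
#     ):
#         return _convert_te_arg_helper(arg.struct_info.value)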

def _tril_by_prim_expr():
    """Doesn't work (yet)

    Bugfix needed in
    `tvm.relax.utils._ArgsConverter._convert_te_arg_helper`: when
    encountering a `PrimStructInfo` with a known binding, it should
    return `_convert_te_arg_helper(arg.struct_info.value)` instead of
    `arg.value`.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), relax_offset: R.Prim(value="tir_offset")):
        tir_offset = T.int64()
        tril = R.tril(A, tir_offset)
        return tril

    return func

def _tril_by_0d_tensor():
    """Doesn't work

    The second argument of `R.tril` must be an `R.Prim` with `int64`
    dtype. The 0-d tensor has `TensorStructInfo` instead.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), offset: R.Tensor([], "int64")):
        # Fails here in shape inference.
        tril = R.tril(A, offset)
        return tril

    return func

def _tril_by_0d_tensor_with_tensor_to_shape():
    """Doesn't work

    `R.tensor_to_shape` requires a 1-d tensor, but `offset` is 0-d.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), offset: R.Tensor([], "int64")):
        # Fails here, needs a 1-d tensor.
        offset_shape = R.tensor_to_shape(offset)
        tril = R.tril(A, R.prim_value(offset_shape))
        return tril

    return func

def _tril_by_1d_tensor_with_tensor_to_shape_and_prim_value():
    """Doesn't work

    The `R.prim_value` argument needs a `PrimExpr`, not a relax
    `Expr`.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), offset: R.Tensor([1], "int64")):
        offset_shape = R.tensor_to_shape(offset)
        # Fails here, needs a PrimExpr.
        tril = R.tril(A, R.prim_value(offset_shape))
        return tril

    return func

def _tril_by_1d_tensor_with_tensor_to_shape():
    """Doesn't work

    `R.tril` needs an `R.Prim` or a `PrimExpr`, not an `R.Shape`.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), offset: R.Tensor([1], "int64")):
        offset_shape = R.tensor_to_shape(offset)
        # Fails here, needs an R.Prim or PrimExpr, not an R.Shape.
        tril = R.tril(A, offset_shape)
        return tril

    return func

def _tril_by_1d_tensor_with_tensor_to_shape_and_index():
    """Doesn't work

    Indexing `offset_shape[0]` produces a `TupleGetItem` node, when it
    should instead access an element of the shape tuple.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), offset: R.Tensor([1], "int64")):
        offset_shape = R.tensor_to_shape(offset)
        # Fails here in shape inference: `offset_shape` is an
        # `R.Shape`, not an `R.Tuple`.
        relax_offset = R.prim_value(offset_shape[0])
        tril = R.tril(A, relax_offset)
        return tril

    return func

def _tril_by_1d_tensor_with_tensor_to_shape_and_match_cast():
    """Doesn't work

    Fails during CodegenVM, which cannot handle `R.tensor_to_shape`.
    """

    @R.function
    def func(A: R.Tensor([16, 16], "float32"), offset: R.Tensor([1], "int64")):
        offset_shape = R.tensor_to_shape(offset)
        tir_offset = T.int64()
        _ = R.match_cast(offset_shape, R.Shape([tir_offset]))
        tril = R.tril(A, tir_offset)
        return tril

    return func
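
# Note: `tvm.relax.transform.DecomposeOpsForInference`, applied in the
# next variant, lowers `R.tensor_to_shape` into a call to
# `vm.builtin.tensor_to_shape` followed by an `R.match_cast` (see the
# commented `func.show()` output below).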

def _tril_by_1d_tensor_with_tensor_to_shape_and_match_cast_with_decompose():
    """Doesn't work

    Fails during CodegenVM, because the `offset_shape` variable is not
    defined.

    Bugfix needed in `VMShapeLower`. In
    `VisitBinding_(MatchCastNode*)`, it re-emits the binding as-is,
    without using the mutated value. As a result, the old value is
    still present, including any references to replaced variables.
    """
    func = _tril_by_1d_tensor_with_tensor_to_shape_and_match_cast()
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.relax.transform.DecomposeOpsForInference()(mod)
    func = mod["func"]

    # func.show()
    #
    # @R.function
    # def func(
    #     A: R.Tensor((16, 16), dtype="float32"), offset: R.Tensor((1,), dtype="int64")
    # ) -> R.Tensor((16, 16), dtype="float32"):
    #     x = T.int64()
    #     tir_offset = T.int64()
    #     gv: R.Shape(ndim=1) = R.call_pure_packed(
    #         "vm.builtin.tensor_to_shape", offset, sinfo_args=(R.Shape(ndim=1),)
    #     )
    #     y: R.Shape([x]) = R.match_cast(gv, R.Shape([x]))
    #     offset_shape: R.Shape([x]) = R.shape([x])
    #     _: R.Shape([tir_offset]) = R.match_cast(offset_shape, R.Shape([tir_offset]))
    #     tril: R.Tensor((16, 16), dtype="float32") = R.tril(A, R.prim_value(tir_offset))
    #     return tril

    return func

def _tril_by_1d_tensor_with_explicit_tensor_to_shape():
    """Works

    Explicitly write the output of `DecomposeOpsForInference`, to
    avoid a round-trip of `R.match_cast` to `R.shape` to
    `R.match_cast`, as that seems to cause an undefined variable in
    `VMShapeLower`.

    Requires a bugfix in `python/tvm/relax/utils.py`:
    `_convert_te_arg_helper` should check `PrimStructInfo` for TIR
    variables that need to be passed in.

        if (
            isinstance(arg.struct_info, PrimStructInfo)
            and arg.struct_info.value is not None
        ):
            return _convert_te_arg_helper(arg.struct_info.value)
    """

    @R.function
    def func(
        A: R.Tensor((16, 16), dtype="float32"), offset: R.Tensor((1,), dtype="int64")
    ) -> R.Tensor((16, 16), dtype="float32"):
        tir_offset = T.int64()
        offset_shape = R.call_pure_packed(
            "vm.builtin.tensor_to_shape", offset, sinfo_args=(R.Shape(ndim=1),)
        )
        _ = R.match_cast(offset_shape, R.Shape([tir_offset]))
        tril = R.tril(A, R.prim_value(tir_offset))
        return tril

    return func

def _tril_by_0d_tensor_with_explicit_tensor_to_shape():
    """Almost works

    Explicitly write the output of `DecomposeOpsForInference`, to
    avoid a round-trip of `R.match_cast` to `R.shape` to
    `R.match_cast`, as that seems to cause an undefined variable in
    `VMShapeLower`.

    Requires a bugfix in `python/tvm/relax/utils.py`:
    `_convert_te_arg_helper` should check `PrimStructInfo` for TIR
    variables that need to be passed in.

        if (
            isinstance(arg.struct_info, PrimStructInfo)
            and arg.struct_info.value is not None
        ):
            return _convert_te_arg_helper(arg.struct_info.value)
    """

    @R.function
    def func(
        A: R.Tensor((16, 16), dtype="float32"), offset: R.Tensor([], dtype="int64")
    ) -> R.Tensor((16, 16), dtype="float32"):
        tir_offset = T.int64()
        # Promote the 0-d tensor to a 1-d tensor, so that it can be
        # passed to `vm.builtin.tensor_to_shape`.
        offset_1d: R.Tensor([1], dtype="int64") = R.full([1], offset)
        offset_shape = R.call_pure_packed(
            "vm.builtin.tensor_to_shape", offset_1d, sinfo_args=(R.Shape(ndim=1),)
        )
        _ = R.match_cast(offset_shape, R.Shape([tir_offset]))
        tril = R.tril(A, R.prim_value(tir_offset))
        return tril

    return func

generate_tril_func = tvm.testing.parameter(
    _tril_by_prim_value,
    _tril_by_prim_value_with_symbolic_var,
    _tril_by_prim_expr,
    _tril_by_0d_tensor,
    _tril_by_0d_tensor_with_tensor_to_shape,
    _tril_by_1d_tensor_with_tensor_to_shape_and_prim_value,
    _tril_by_1d_tensor_with_tensor_to_shape,
    _tril_by_1d_tensor_with_tensor_to_shape_and_index,
    _tril_by_1d_tensor_with_tensor_to_shape_and_match_cast,
    _tril_by_1d_tensor_with_tensor_to_shape_and_match_cast_with_decompose,
    _tril_by_1d_tensor_with_explicit_tensor_to_shape,
    _tril_by_0d_tensor_with_explicit_tensor_to_shape,
)

def test_tril(generate_tril_func):
    func = generate_tril_func()
    dtype = func.params[0].struct_info.dtype

    A_np = np.random.uniform(low=0, high=255, size=[16, 16]).astype(dtype)
    B_np = np.random.uniform(low=2, high=15, size=[]).astype("int64")[()]

    mod = tvm.IRModule.from_expr(func)
    built_mod = relax.build(mod, target="llvm")
    vm = relax.vm.VirtualMachine(built_mod, tvm.cpu(0))
    built_func = vm["func"]

    A_tvm = tvm.nd.array(A_np)
    B_sinfo = func.params[1].struct_info
    if isinstance(B_sinfo, relax.TensorStructInfo) and B_sinfo.ndim == 0:
        B_tvm = tvm.nd.array(np.array(B_np))
    elif isinstance(B_sinfo, relax.TensorStructInfo) and B_sinfo.ndim == 1:
        B_tvm = tvm.nd.array(np.array([B_np]))
    elif isinstance(B_sinfo, relax.PrimStructInfo):
        B_tvm = B_np

    C_np = np.tril(A_np, B_np)
    C_relax = built_func(A_tvm, B_tvm).numpy()

    tvm.testing.assert_allclose(C_np, C_relax)

if __name__ == "__main__":
    tvm.testing.main()