Created May 24, 2021 17:26
-
-
Save stellaraccident/3c17025ae9ed72c48f0e16bf6ec517f1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from mlir.ir import * | |
from mlir.dialects.builtin import * | |
from mlir.dialects.tosa import * | |
from mlir.passmanager import * | |
import mlir.dialects.sparse_tensor as st | |
import mlir.conversions | |
def sparse_tensor(shape, levels=None, ordering=None, dtype=None, *,
                  pointer_bitwidth=32, index_bitwidth=32):
    """Return a ``RankedTensorType`` carrying a sparse-tensor encoding.

    Args:
      shape: Dimension sizes of the tensor; its length is the rank.
      levels: One ``st.DimLevelType`` per dimension. When omitted (or empty),
        every dimension is ``compressed``.
      ordering: ``AffineMap`` giving the dimension ordering. When omitted, the
        identity map of the tensor's rank is used.
      dtype: Element type. When omitted, ``f32`` is used.
      pointer_bitwidth: Third argument to ``st.EncodingAttr.get`` (pointer
        storage bit width). Previously hard-coded to 32, which stays the
        default.
      index_bitwidth: Fourth argument to ``st.EncodingAttr.get`` (index
        storage bit width). Previously hard-coded to 32, which stays the
        default.

    Returns:
      A ranked tensor type with the sparse encoding attached.

    Raises:
      ValueError: If ``levels`` is given but its length differs from the rank.
    """
    rank = len(shape)
    if not levels:
        levels = [st.DimLevelType.compressed] * rank
    elif len(levels) != rank:
        # Fail early with a clear message instead of handing MLIR an
        # encoding whose level count does not match the tensor rank.
        raise ValueError(
            f"expected {rank} dimension level types, got {len(levels)}")
    if not ordering:
        ordering = AffineMap.get_identity(rank)
    encoding = st.EncodingAttr.get(
        levels, ordering, pointer_bitwidth, index_bitwidth)
    element_type = dtype if dtype else F32Type.get()
    return RankedTensorType.get(shape, element_type, encoding=encoding)
def dense_tensor(shape, dtype=None):
    """Return a plain (unencoded) ``RankedTensorType`` of the given shape.

    Args:
      shape: Dimension sizes of the tensor.
      dtype: Element type; when falsy, ``f32`` is used.
    """
    element_type = dtype or F32Type.get()
    return RankedTensorType.get(shape, element_type)
def create_sample_fc_module():
    """Build a module containing one TOSA fully-connected function ``fc``.

    ``fc`` takes a dense 256x1024 input, a sparse 64x1024 weight and a dense
    64-element bias (all f32 by default). It returns the result of a
    ``FullyConnectedOp``, whose type is dense
    ``[input dim 0, weight dim 0]``.

    Expects to be called with an MLIR ``Context`` and ``Location`` active, as
    the script at the bottom of this file arranges.
    """
    module = Module.create()
    with InsertionPoint(module.body):
        input_type = dense_tensor([256, 1024])
        weight_type = sparse_tensor([64, 1024])
        bias_type = dense_tensor([64])

        # The decorator emits the function into ``module`` as a side effect;
        # the Python name ``fc`` is used as the MLIR symbol name.
        @FuncOp.from_py_func(input_type, weight_type, bias_type)
        def fc(inputs, weights, bias):
            batch = RankedTensorType(inputs.type).get_dim_size(0)
            out_features = RankedTensorType(weights.type).get_dim_size(0)
            fc_op = FullyConnectedOp(
                dense_tensor([batch, out_features]),
                input=inputs,
                weight=weights,
                bias=bias,
                quantization_info=None,
            )
            return fc_op.result
    return module
# Script entry: build the sample module, print it, run the TOSA -> Linalg
# (on tensors) lowering over it, then print the lowered module.
with Context() as ctx:
    with Location.unknown():
        m = create_sample_fc_module()
        print("// Input module", m, sep="\n")
        pm = PassManager.parse("func(tosa-to-linalg-on-tensors)")
        # m is not reassigned; the second print shows the module after the
        # pipeline has run on it.
        pm.run(m)
        print("\n\n// Post linalg conversion", m, sep="\n")
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.