Skip to content

Instantly share code, notes, and snippets.

@scturtle
Last active May 12, 2023 09:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save scturtle/11c007af25214e7edb24faaff0983532 to your computer and use it in GitHub Desktop.
dig into torch-mlir
import torch
from torch import nn
# import torch_mlir
from torch_mlir.passmanager import PassManager
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
# Toy MLP used as the compilation input. The dimensions are named so that the
# first Linear layer and the example input below always agree on IN_FEATURES.
BATCH_SIZE = 32
IN_FEATURES = 784
HIDDEN_FEATURES = 128
OUT_FEATURES = 10

torch.manual_seed(42)  # deterministic weight initialization across runs
module = nn.Sequential(
    nn.Linear(IN_FEATURES, HIDDEN_FEATURES),
    nn.ReLU(),
    nn.Linear(HIDDEN_FEATURES, OUT_FEATURES),
    nn.LogSoftmax(dim=1),
)
module = torch.jit.script(module)

# Example input; its shape and dtype are used to annotate the `forward`
# argument when the module is imported into MLIR.
arg = torch.ones(BATCH_SIZE, IN_FEATURES)

# Tracing alternatives to torch.jit.script (not used):
# module = torch.jit.trace(module, arg)
# module = torch.jit.trace_module(module, dict(forward=arg))
### module = torch_mlir.compile(module, torch.ones(32, 784), output_type="raw")
# Import the scripted module into MLIR by hand. This is presumably what the
# high-level call above does for output_type="raw".
script_type = module._c._type()

# Export nothing by default, then export only the `forward` method.
class_annotator = ClassAnnotator()
class_annotator.exportNone(script_type)
class_annotator.exportPath(script_type, ["forward"])

# One entry per `forward` argument. `self` gets None; the tensor input gets
# (shape, dtype, flag). NOTE(review): the trailing bool is presumably
# ClassAnnotator's "has value semantics" flag — confirm against torch-mlir.
forward_arg_annotations = [
    None,
    (arg.shape, arg.dtype, True),
]
class_annotator.annotateArgs(script_type, ["forward"], forward_arg_annotations)

# Build the MLIR module from the annotated TorchScript module.
module_builder = ModuleBuilder()
module_builder.import_module(module._c, class_annotator)
module = module_builder.module
# module = torch_mlir.compile(module, torch.ones(32, 784), output_type="torch")
# Lower the imported TorchScript module to the torch backend form, in place.
# TORCH_BACKEND_OPTIONS is appended verbatim to the pipeline name. Set it to a
# pass-option string such as "{backend-legal-ops=aten.warn}" to configure the
# pipeline; the empty default runs it with no options.
TORCH_BACKEND_OPTIONS = ""
with module.context:
    pm = PassManager.parse(
        "builtin.module("
        f"torchscript-module-to-torch-backend-pipeline{TORCH_BACKEND_OPTIONS}"
        ")"
    )
    pm.run(module.operation)
### module = torch_mlir.compile(module, torch.ones(32, 784), output_type="linalg_on_tensors")
# Lower the torch-backend IR to Linalg-on-tensors. pm.run rewrites
# module.operation in place, so `module` now holds the lowered IR.
with module.context:
    pm = PassManager.parse("builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)")
    pm.run(module.operation)
# Local copy of the reference backend's lowering pipeline. Upstream it can be
# imported instead (see the commented-out line below).
# NOTE(review): this appears to keep only the passes up to buffer
# deallocation — verify against refbackend.py if upstream changes.
# from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import LOWERING_PIPELINE
_LOWERING_PASSES = (
    # Run before the bufferization passes.
    "func.func(refback-generalize-tensor-pad)",
    "func.func(linalg-fuse-elementwise-ops)",
    # Bufferize.
    "func.func(scf-bufferize)",
    "func.func(tm-tensor-bufferize)",
    "func.func(empty-tensor-to-alloc-tensor)",
    "func.func(linalg-bufferize)",
    "func-bufferize",
    "arith-bufferize",
    "refback-mlprogram-bufferize",
    "func.func(tensor-bufferize)",
    "func.func(finalizing-bufferize)",
    "func.func(buffer-deallocation)",
)
# Textual pass-pipeline: every pass nested under the top-level builtin.module.
LOWERING_PIPELINE = f"builtin.module({','.join(_LOWERING_PASSES)})"
# Apply LOWERING_PIPELINE (the bufferization passes) to the module in place.
with module.context:
    pm = PassManager.parse(LOWERING_PIPELINE)
    pm.run(module.operation)
# Print the final IR. large_elements_limit=10 elides element attributes with
# more than 10 elements (presumably the Linear weight constants here), which
# keeps the dump readable. Use print(module) to see them in full. Debug
# locations are omitted.
stripped = module.operation.get_asm(large_elements_limit=10, enable_debug_info=False)
print(stripped)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment