Skip to content

Instantly share code, notes, and snippets.

@qedawkins
Last active August 19, 2022 23:35
Show Gist options
  • Save qedawkins/d0129185e7187b8ee2e042a76cca8f70 to your computer and use it in GitHub Desktop.
Save qedawkins/d0129185e7187b8ee2e042a76cca8f70 to your computer and use it in GitHub Desktop.

Python Script

Script can be found at examples/onnx_add.py

import torch
import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

class ToyModel(torch.nn.Module):
    """Toy module whose forward pass is a single elementwise tensor add.

    There is no ``__init__`` override: the inherited
    ``torch.nn.Module.__init__`` does the same setup that an explicit
    ``super().__init__()`` call would.
    """

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Plain elementwise addition of the two inputs.
        total = x + y
        return total

model = ToyModel()

# Example inputs, one per forward() parameter. The compiled signature shown
# in the Linalg dump uses 1x3 tensors, matching these shapes.
x = torch.ones(1, 3)
y = torch.ones(1, 3)

# Pass each example input in the same order as forward(x, y), so the
# compiled function is specialized on the right argument for each slot.
module = torch_mlir.compile(model, (x, y), output_type="onnx")
print(module)

# Lower through the reference linalg-on-tensors backend, load the result,
# and run its `main_graph` entry point on NumPy views of the same inputs.
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)
print(jit_module.main_graph(x.numpy(), y.numpy()))

IR Dump

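ONNX Dialect (note: this dump does not match the ToyModel script above — it has a single 1x3x224x224 input, `input.1`, added to itself, rather than two 1x3 inputs `x` and `y`)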

module {
  func.func @main_graph(%arg0: tensor<1x3x224x224xf32>) -> tensor<1x3x224x224xf32> attributes {input_names = ["input.1"], output_names = ["1"]} {
    %0 = "onnx.Add"(%arg0, %arg0) {onnx_node_name = "Add_0"} : (tensor<1x3x224x224xf32>, tensor<1x3x224x224xf32>) -> tensor<1x3x224x224xf32>
    return %0 : tensor<1x3x224x224xf32>
  }
  "onnx.EntryPoint"() {func = @main_graph} : () -> ()
}

Linalg

{function_type = (!torch.tensor<[1,3],f32>, !torch.tensor<[1,3],f32>) -> !torch.tensor<[1,3],f32>, input_names = ["x.1", "y.1"], output_names = ["2"], sym_name = "main_graph"}
#map0 = affine_map<(d0, d1) -> (0, d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @main_graph(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {input_names = ["x.1", "y.1"], output_names = ["2"]} {
    %0 = linalg.init_tensor [1, 3] : tensor<1x3xf32>
    %1 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<1x3xf32>, tensor<1x3xf32>) outs(%0 : tensor<1x3xf32>) {
    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
      %2 = arith.addf %arg2, %arg3 : f32
      linalg.yield %2 : f32
    } -> tensor<1x3xf32>
    return %1 : tensor<1x3xf32>
  }
}

Output

[[2. 2. 2.]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment