

@qedawkins
Created August 18, 2022 17:10

Python Script

The script can be found at examples/onnx_logsoftmax.py.

import torch
import torch_mlir

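class ToyModel(torch.nn.Module):
    """Single-op model: log-softmax over the channel dimension (dim=1)."""

    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor):
        # For an NCHW input, dim=1 normalizes across the 3 channels
        # independently at every (n, h, w) position.
        return torch.nn.functional.log_softmax(input, dim=1)

model = ToyModel()

# Compile with a 1x3x224x224 example input and emit ONNX-dialect IR.
# Note: output_type="onnx" may not be available in released torch-mlir
# builds; it is assumed to come from the experimental branch used here.
module = torch_mlir.compile(model, torch.ones(1, 3, 224, 224), output_type="onnx")
print(module)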
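As a plain-Python reference for what `ToyModel.forward` computes, the sketch below applies log-softmax over dim=1 of a small NCHW tensor stored as nested lists. It is not PyTorch or torch-mlir code: `log_softmax_dim1` is a hypothetical helper, and the 1x3x2x2 shape is an assumption chosen to keep the output readable.

```python
import math


def log_softmax_dim1(x):
    """Log-softmax over dim=1 of an NCHW tensor given as nested lists."""
    n_dim, c_dim = len(x), len(x[0])
    h_dim, w_dim = len(x[0][0]), len(x[0][0][0])
    # Output has the same shape as the input.
    out = [[[[0.0] * w_dim for _ in range(h_dim)] for _ in range(c_dim)]
           for _ in range(n_dim)]
    for n in range(n_dim):
        for h in range(h_dim):
            for w in range(w_dim):
                # Gather the C channel values at this (n, h, w) position.
                column = [x[n][c][h][w] for c in range(c_dim)]
                # log-sum-exp with the max subtracted to avoid overflow.
                m = max(column)
                lse = m + math.log(sum(math.exp(v - m) for v in column))
                for c in range(c_dim):
                    out[n][c][h][w] = column[c] - lse
    return out


# Scaled-down 1x3x2x2 inputs (hypothetical stand-ins for the script's input).
ones = [[[[1.0] * 2 for _ in range(2)] for _ in range(3)]]
varied = [[[[float(c * h + w) for w in range(2)] for h in range(2)]
           for c in range(3)]]

# With equal channels, every element is log(1/3) = -log(3).
print(log_softmax_dim1(ones)[0][0][0][0], -math.log(3))

# exp of the output sums to 1 across channels at each spatial position.
y = log_softmax_dim1(varied)
for h in range(2):
    for w in range(2):
        print(h, w, sum(math.exp(y[0][c][h][w]) for c in range(3)))
```

Because `torch.ones` makes all three channels equal, every element of the script's real output should likewise be -log 3 ≈ -1.0986.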

IR Dump

Before Decomposition

module {
  func.func @main_graph(%arg0: tensor<1x3x224x224xf32>) -> tensor<1x3x224x224xf32> attributes {input_names = ["input.1"], output_names = ["ret"]} {
    %0 = "onnx.LogSoftmax"(%arg0) {axis = 1 : si64, onnx_node_name = "LogSoftmax_0"} : (tensor<1x3x224x224xf32>) -> tensor<1x3x224x224xf32>
    return %0 : tensor<1x3x224x224xf32>
  }
  "onnx.EntryPoint"() {func = @main_graph} : () -> ()
}
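For reference, `onnx.LogSoftmax` with `axis = 1` normalizes across the channel dimension of the `1x3x224x224` input, assuming opset-13 semantics (per-axis normalization, matching PyTorch's `dim=1`; earlier opsets flattened all dimensions from `axis` onward). Writing $C = 3$ for the channel count, each output element is given by the following, where the last equality is the identity that a Softmax-then-Log decomposition relies on:

```latex
\mathrm{LogSoftmax}(x)_{n,c,h,w}
  = x_{n,c,h,w} - \log \sum_{k=1}^{C} \exp\left(x_{n,k,h,w}\right)
  = \log\left( \frac{\exp\left(x_{n,c,h,w}\right)}{\sum_{k=1}^{C} \exp\left(x_{n,k,h,w}\right)} \right)
  = \log\left(\mathrm{Softmax}(x)_{n,c,h,w}\right)
```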

After Decomposition

module {
  func.func @main_graph(%arg0: tensor<1x3x224x224xf32>) -> tensor<1x3x224x224xf32> attributes {input_names = ["input.1"], output_names = ["ret"]} {
    %0 = "onnx.Softmax"(%arg0) {axis = 1 : si64} : (tensor<1x3x224x224xf32>) -> tensor<*xf32>
    %1 = "onnx.Log"(%0) : (tensor<*xf32>) -> tensor<1x3x224x224xf32>
    return %1 : tensor<1x3x224x224xf32>
  }
  "onnx.EntryPoint"() {func = @main_graph} : () -> ()
}
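In exact arithmetic the decomposed IR computes the same values as `onnx.LogSoftmax`, but in floating point the two can differ: the intermediate softmax probability can underflow to zero, after which the log is -inf, whereas the fused form x - logsumexp(x) stays finite. The sketch below compares both on a single channel column, assuming a stable (max-subtracting) softmax for the `onnx.Softmax` step. `log_softmax_direct` and `log_of_softmax` are hypothetical helpers, not onnx-mlir code, and the example uses Python's float64; with the IR's f32 the underflow sets in at much smaller gaps.

```python
import math


def log_softmax_direct(v):
    """Fused, stable form: x - max - log(sum(exp(x - max)))."""
    m = max(v)
    lse = m + math.log(sum(math.exp(x - m) for x in v))
    return [x - lse for x in v]


def log_of_softmax(v):
    """Decomposed form, mirroring Softmax followed by Log."""
    m = max(v)
    exps = [math.exp(x - m) for x in v]  # math.exp returns 0.0 on underflow
    total = sum(exps)
    probs = [e / total for e in exps]
    # math.log raises on 0.0, so map an underflowed probability to -inf,
    # which is what an IEEE-754 log(0) would return.
    return [math.log(p) if p > 0.0 else -math.inf for p in probs]


# Moderate inputs: both forms agree up to rounding error.
moderate = [0.5, -1.25, 2.0]
print(log_softmax_direct(moderate))
print(log_of_softmax(moderate))

# Widely separated inputs: exp(-1003) underflows to 0.0, so the
# decomposed form yields -inf where the fused form stays finite.
extreme = [0.0, -1000.0, 3.0]
print(log_softmax_direct(extreme))
print(log_of_softmax(extreme))
```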