Skip to content

Instantly share code, notes, and snippets.

@ScottTodd
Created June 10, 2024 21:45
Show Gist options
  • Save ScottTodd/1e95795e79d17964078217ca98a3a398 to your computer and use it in GitHub Desktop.
Save ScottTodd/1e95795e79d17964078217ca98a3a398 to your computer and use it in GitHub Desktop.
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import logging
import os
import pathlib

# import pytest
import iree.runtime as ireert
import shark_turbine.aot as aot
import torch
# Module-level logger; the tests below log the exported MLIR text and the
# IREE outputs through it (visible only when INFO logging is enabled).
logger = logging.getLogger(__name__)
class TestIndexPut:
    """End-to-end tests for in-place ``Tensor.index_put_`` through IREE.

    Each test builds a tiny module whose ``forward`` calls ``index_put_`` on
    its input and returns it. The test then:

    1. checks the eager PyTorch result against a hand-written expected tensor,
    2. exports the module to MLIR with ``shark_turbine.aot.export``,
    3. compiles it for the ``llvm-cpu`` backend and runs it on IREE's
       ``local-sync`` driver,
    4. checks that IREE produces the same expected tensor.

    Set the ``IREE_INDEX_PUT_MLIR_DIR`` environment variable to also write each
    exported module to ``<dir>/<test name>.mlir`` for manual inspection.
    """

    @staticmethod
    def _make_index_put_module(indices, values):
        """Build a module whose forward does ``x.index_put_(indices, values)``.

        Args:
            indices: list of Python index lists, one per indexed dimension;
                each is turned into a ``torch.tensor`` inside ``forward``.
            values: Python list of values passed (as a tensor) to
                ``index_put_``.
        """

        class IndexPutModule(torch.nn.Module):
            def forward(self, x):
                # The tensors are created inside forward, so the exporter
                # sees them as in-graph constants.
                x.index_put_(
                    indices=[torch.tensor(index) for index in indices],
                    values=torch.tensor(values),
                )
                return x

        return IndexPutModule()

    @staticmethod
    def _maybe_save_mlir(exported_module, name):
        """Save the exported MLIR to $IREE_INDEX_PUT_MLIR_DIR/<name>.mlir.

        Does nothing when the environment variable is unset or empty. This
        replaces a machine-specific absolute path that broke the tests
        elsewhere.
        """
        out_dir = os.environ.get("IREE_INDEX_PUT_MLIR_DIR")
        if not out_dir:
            return
        out_path = pathlib.Path(out_dir) / f"{name}.mlir"
        out_path.parent.mkdir(parents=True, exist_ok=True)
        exported_module.save_mlir(str(out_path))
        logger.info("Saved MLIR to %s", out_path)

    @staticmethod
    def _run_on_iree(exported_module, test_input):
        """Compile ``exported_module`` for llvm-cpu and run its ``main``.

        Returns:
            The IREE result, copied to the host and converted to a tensor.
        """
        compiled_module = exported_module.compile(
            save_to=None, target_backends="llvm-cpu"
        )
        config = ireert.Config("local-sync")
        vm_module = ireert.load_vm_module(
            ireert.VmModule.wrap_buffer(
                config.vm_instance, compiled_module.map_memory()
            ),
            config,
        )
        iree_output = vm_module.main(test_input.numpy())
        # torch.tensor copies the host array, so the result does not alias
        # IREE-owned memory.
        return torch.tensor(iree_output.to_host())

    def _check_index_put(self, torch_module, test_input, expected_output, name):
        """Check ``torch_module`` on eager PyTorch and on IREE.

        Args:
            torch_module: module whose forward mutates and returns its input.
            test_input: input tensor. It is never mutated, because every
                consumer gets its own clone.
            expected_output: expected result of running the module on
                ``test_input``.
            name: base file name used when saving the exported MLIR.

        Raises:
            AssertionError: if either the PyTorch or the IREE result differs
                from ``expected_output``.
        """
        # forward() mutates its argument in place. Each stage therefore gets
        # a fresh clone. Otherwise the IREE run would start from a tensor that
        # already holds the expected values, and it would pass even if the
        # compiled index_put did nothing.
        pytorch_output = torch_module(test_input.clone())
        torch.testing.assert_close(pytorch_output, expected_output)

        # NOTE: aot.FxProgramsBuilder with @fxb.export_program(args=...) is an
        # alternative export path; these tests use aot.export, which is based
        # on torch.export.
        exported_module = aot.export(torch_module, args=(test_input.clone(),))
        logger.info(exported_module.mlir_module)
        self._maybe_save_mlir(exported_module, name)

        iree_output = self._run_on_iree(exported_module, test_input.clone())
        logger.info(iree_output)
        torch.testing.assert_close(iree_output, expected_output)

    def test_single_value(self):
        torch_module = self._make_index_put_module(
            # Insert at position [1, 2]
            indices=[[1], [2]],
            # Insert a single value
            values=[0.5],
        )
        test_input = torch.zeros(3, 4)
        expected_output = torch.tensor(
            [  # col 0   col 1   col 2   col 3
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 0
                [0.0000, 0.0000, 0.5000, 0.0000],  # row 1
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 2
            ]
        )
        self._check_index_put(
            torch_module, test_input, expected_output, "index_put_single_value"
        )

    def test_multiple_values(self):
        torch_module = self._make_index_put_module(
            # Insert at positions [0, 3], [1, 4], and [2, 5]
            indices=[[0, 1, 2], [3, 4, 5]],
            # Insert a unique value to each position
            values=[0.1, 0.2, 0.3],
        )
        test_input = torch.zeros(3, 6)
        expected_output = torch.tensor(
            [  # col 0   col 1   col 2   col 3   col 4   col 5
                [0.0000, 0.0000, 0.0000, 0.1000, 0.0000, 0.0000],  # row 0
                [0.0000, 0.0000, 0.0000, 0.0000, 0.2000, 0.0000],  # row 1
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3000],  # row 2
            ]
        )
        self._check_index_put(
            torch_module, test_input, expected_output, "index_put_multiple_values"
        )

    def test_broadcast_value_along_axis(self):
        torch_module = self._make_index_put_module(
            # Insert at positions [1, *] (broadcast)
            indices=[[1]],
            # Insert a single value to all positions (broadcast)
            values=[0.5],
        )
        test_input = torch.zeros(3, 4)
        expected_output = torch.tensor(
            [  # col 0   col 1   col 2   col 3
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 0
                [0.5000, 0.5000, 0.5000, 0.5000],  # row 1
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 2
            ]
        )
        self._check_index_put(
            torch_module,
            test_input,
            expected_output,
            "index_put_broadcast_value_along_axis",
        )

    def test_broadcast_value_along_indices(self):
        torch_module = self._make_index_put_module(
            # Insert at positions [0, 3], [1, 4], and [2, 5]
            indices=[[0, 1, 2], [3, 4, 5]],
            # Insert a single value to all positions (broadcast)
            values=[0.5],
        )
        test_input = torch.zeros(3, 6)
        expected_output = torch.tensor(
            [  # col 0   col 1   col 2   col 3   col 4   col 5
                [0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000],  # row 0
                [0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000],  # row 1
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000],  # row 2
            ]
        )
        self._check_index_put(
            torch_module,
            test_input,
            expected_output,
            "index_put_broadcast_value_along_indices",
        )

    def test_broadcast_values_along_axis(self):
        torch_module = self._make_index_put_module(
            # Insert at positions [1, *] (broadcast)
            indices=[[1]],
            # Insert a unique value to each position
            values=[0.1, 0.2, 0.3, 0.4],
        )
        test_input = torch.zeros(3, 4)
        expected_output = torch.tensor(
            [  # col 0   col 1   col 2   col 3
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 0
                [0.1000, 0.2000, 0.3000, 0.4000],  # row 1
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 2
            ]
        )
        self._check_index_put(
            torch_module,
            test_input,
            expected_output,
            "index_put_broadcast_values_along_axis",
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment