-
-
Save ScottTodd/1e95795e79d17964078217ca98a3a398 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2024 Advanced Micro Devices, Inc. | |
# | |
# Licensed under the Apache License v2.0 with LLVM Exceptions. | |
# See https://llvm.org/LICENSE.txt for license information. | |
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |
import logging
import os
from pathlib import Path

import iree.runtime as ireert
# import pytest
import shark_turbine.aot as aot
import torch
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
class TestIndexPut:
    """End-to-end tests for ``torch.Tensor.index_put_`` compiled through IREE.

    Each test builds a tiny module whose ``forward`` calls ``index_put_`` on
    its input, then:

    1. runs it eagerly in PyTorch and checks the result against a
       hand-written expectation,
    2. exports it to MLIR with ``aot.export``,
    3. compiles it for the ``llvm-cpu`` backend and runs it with the IREE
       runtime (``local-sync`` driver),
    4. checks the IREE result against the same expectation.

    Set the ``INDEX_PUT_MLIR_DIR`` environment variable to additionally write
    each test's exported MLIR into that directory.
    """

    @staticmethod
    def _assert_close_verbose(actual, expected, label):
        """Assert that ``actual`` is close to ``expected``, logging both on failure.

        Args:
            actual: Tensor produced by the backend under test.
            expected: Reference tensor.
            label: Short backend name used in the failure log message.

        Raises:
            AssertionError: Re-raised from ``torch.testing.assert_close`` after
                logging the mismatching tensors.
        """
        try:
            torch.testing.assert_close(actual, expected)
        except AssertionError:
            logger.error(
                "%s output mismatch.\nactual:\n%s\nexpected:\n%s",
                label,
                actual,
                expected,
            )
            raise

    @staticmethod
    def _maybe_save_mlir(exported_module, mlir_name):
        """Save exported MLIR to ``$INDEX_PUT_MLIR_DIR/<mlir_name>.mlir`` if set.

        Saving is opt-in so the tests do not depend on a machine-specific path.
        """
        output_dir = os.environ.get("INDEX_PUT_MLIR_DIR")
        if not output_dir:
            return
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        exported_module.save_mlir(str(output_path / f"{mlir_name}.mlir"))

    def _check_index_put(self, torch_module, test_input, expected_output, mlir_name):
        """Run ``torch_module`` eagerly and through IREE; compare both to expected.

        Args:
            torch_module: Module whose ``forward`` mutates and returns its input.
            test_input: Pristine input tensor. It is never mutated here.
            expected_output: Tensor both backends must produce.
            mlir_name: Base file name used when saving the exported MLIR.
        """
        # index_put_ mutates its argument in place. Run eager PyTorch on a copy;
        # otherwise test_input would already hold the expected values and the
        # IREE run below could not detect a broken lowering.
        pytorch_output = torch_module.forward(test_input.clone())
        self._assert_close_verbose(pytorch_output, expected_output, "PyTorch")

        # Export to MLIR using aot.export (based on torch.export).
        exported_module = aot.export(torch_module, args=(test_input.clone(),))
        logger.info(exported_module.mlir_module)
        self._maybe_save_mlir(exported_module, mlir_name)

        # Compile through IREE and run on the CPU.
        compiled_module = exported_module.compile(
            save_to=None, target_backends="llvm-cpu"
        )
        config = ireert.Config("local-sync")
        vm_module = ireert.load_vm_module(
            ireert.VmModule.wrap_buffer(
                config.vm_instance, compiled_module.map_memory()
            ),
            config,
        )
        # torch.tensor copies the host array, so a read-only buffer is fine.
        iree_output = torch.tensor(vm_module.main(test_input.numpy()).to_host())
        logger.info(iree_output)
        self._assert_close_verbose(iree_output, expected_output, "IREE")

    def test_single_value(self):
        class TorchModule(torch.nn.Module):
            def forward(self, input):
                input.index_put_(
                    # Insert at position [1, 2]
                    indices=[torch.tensor([1]), torch.tensor([2])],
                    # Insert a single value
                    values=torch.tensor([0.5]),
                )
                return input

        expected_output = torch.tensor(
            [  # col 0   col 1   col 2   col 3
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 0
                [0.0000, 0.0000, 0.5000, 0.0000],  # row 1
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 2
            ]
        )
        self._check_index_put(
            TorchModule(),
            torch.zeros(3, 4),
            expected_output,
            "index_put_single_value",
        )

    def test_multiple_values(self):
        class TorchModule(torch.nn.Module):
            def forward(self, input):
                input.index_put_(
                    # Insert at positions [0, 3], [1, 4], and [2, 5]
                    indices=[torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])],
                    # Insert a unique value to each position
                    values=torch.tensor([0.1, 0.2, 0.3]),
                )
                return input

        expected_output = torch.tensor(
            [  # col 0   col 1   col 2   col 3   col 4   col 5
                [0.0000, 0.0000, 0.0000, 0.1000, 0.0000, 0.0000],  # row 0
                [0.0000, 0.0000, 0.0000, 0.0000, 0.2000, 0.0000],  # row 1
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3000],  # row 2
            ]
        )
        self._check_index_put(
            TorchModule(),
            torch.zeros(3, 6),
            expected_output,
            "index_put_multiple_values",
        )

    def test_broadcast_value_along_axis(self):
        class TorchModule(torch.nn.Module):
            def forward(self, input):
                input.index_put_(
                    # Insert at positions [1, *] (broadcast)
                    indices=[torch.tensor([1])],
                    # Insert a single value to all positions (broadcast)
                    values=torch.tensor([0.5]),
                )
                return input

        expected_output = torch.tensor(
            [  # col 0   col 1   col 2   col 3
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 0
                [0.5000, 0.5000, 0.5000, 0.5000],  # row 1
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 2
            ]
        )
        self._check_index_put(
            TorchModule(),
            torch.zeros(3, 4),
            expected_output,
            "index_put_broadcast_value_along_axis",
        )

    def test_broadcast_value_along_indices(self):
        class TorchModule(torch.nn.Module):
            def forward(self, input):
                input.index_put_(
                    # Insert at positions [0, 3], [1, 4], and [2, 5]
                    indices=[torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])],
                    # Insert a single value to all positions (broadcast)
                    values=torch.tensor([0.5]),
                )
                return input

        expected_output = torch.tensor(
            [  # col 0   col 1   col 2   col 3   col 4   col 5
                [0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000],  # row 0
                [0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000],  # row 1
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000],  # row 2
            ]
        )
        self._check_index_put(
            TorchModule(),
            torch.zeros(3, 6),
            expected_output,
            "index_put_broadcast_value_along_indices",
        )

    def test_broadcast_values_along_axis(self):
        class TorchModule(torch.nn.Module):
            def forward(self, input):
                input.index_put_(
                    # Insert at positions [1, *] (broadcast)
                    indices=[torch.tensor([1])],
                    # Insert a unique value to each position
                    values=torch.tensor([0.1, 0.2, 0.3, 0.4]),
                )
                return input

        expected_output = torch.tensor(
            [  # col 0   col 1   col 2   col 3
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 0
                [0.1000, 0.2000, 0.3000, 0.4000],  # row 1
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 2
            ]
        )
        self._check_index_put(
            TorchModule(),
            torch.zeros(3, 4),
            expected_output,
            "index_put_broadcast_values_along_axis",
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment