Skip to content

Instantly share code, notes, and snippets.

@ScottTodd
Created June 12, 2024 15:34
Show Gist options
  • Save ScottTodd/a0c0e68d1abeb3240f782045c4c70e80 to your computer and use it in GitHub Desktop.
Save ScottTodd/a0c0e68d1abeb3240f782045c4c70e80 to your computer and use it in GitHub Desktop.
D:\dev\projects\SHARK-TestSuite\iree_tests\pytorch\operators (index-put-tests)
(.venv) λ pytest --log-cli-level=info index_put_test.py::TestIndexPut::test_single_value
======================================= test session starts =======================================
platform win32 -- Python 3.11.2, pytest-8.2.2, pluggy-1.5.0
rootdir: D:\dev\projects\SHARK-TestSuite\iree_tests
configfile: pytest.ini
plugins: reportlog-0.4.0, retry-1.6.3, timeout-2.3.1, xdist-3.6.1
collected 1 item
index_put_test.py::TestIndexPut::test_single_value
------------------------------------------ live log call ------------------------------------------
INFO index_put_test:index_put_test.py:62 module @module {
func.func @main(%arg0: !torch.tensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> attributes {torch.assume_strict_symbolic_shapes} {
%0 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%1 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%2 = torch.vtensor.literal(dense<5.000000e-01> : tensor<1xf32>) : !torch.vtensor<[1],f32>
%3 = torch.copy.to_vtensor %arg0 : !torch.vtensor<[4,4],f32>
%none = torch.constant.none
%4 = torch.aten.clone %0, %none : !torch.vtensor<[1],si64>, !torch.none -> !torch.vtensor<[1],si64>
%none_0 = torch.constant.none
%5 = torch.aten.clone %1, %none_0 : !torch.vtensor<[1],si64>, !torch.none -> !torch.vtensor<[1],si64>
%none_1 = torch.constant.none
%6 = torch.aten.clone %2, %none_1 : !torch.vtensor<[1],f32>, !torch.none -> !torch.vtensor<[1],f32>
%7 = torch.prim.ListConstruct %4, %5 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<optional<vtensor>>
%false = torch.constant.bool false
%8 = torch.aten.index_put %3, %7, %6, %false : !torch.vtensor<[4,4],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[1],f32>, !torch.bool -> !torch.vtensor<[4,4],f32>
torch.overwrite.tensor.contents %8 overwrites %arg0 : !torch.vtensor<[4,4],f32>, !torch.tensor<[4,4],f32>
return %8 : !torch.vtensor<[4,4],f32>
}
}
INFO index_put_test:index_put_test.py:82 [[0. 0. 0. 0. ]
[0. 0. 0. 0. ]
[0. 0. 0.5 0. ]
[0. 0. 0. 0. ]]
PASSED [100%]
======================================== 1 passed in 7.68s ========================================
Exception Code: 0xC0000005
#0 0x00007ffa79428b7b PyInit__runtime (D:\dev\projects\SHARK-TestSuite\iree_tests\pytorch\operators\.venv\Lib\site-packages\iree\_runtime_libs\_runtime.cp311-win_amd64.pyd+0x48b7b)
#1 0x00007ffa7949a1f1 PyInit__runtime (D:\dev\projects\SHARK-TestSuite\iree_tests\pytorch\operators\.venv\Lib\site-packages\iree\_runtime_libs\_runtime.cp311-win_amd64.pyd+0xba1f1)
#2 0x00007ffa7949a9a7 PyInit__runtime (D:\dev\projects\SHARK-TestSuite\iree_tests\pytorch\operators\.venv\Lib\site-packages\iree\_runtime_libs\_runtime.cp311-win_amd64.pyd+0xba9a7)
#3 0x00007ffa7949af44 PyInit__runtime (D:\dev\projects\SHARK-TestSuite\iree_tests\pytorch\operators\.venv\Lib\site-packages\iree\_runtime_libs\_runtime.cp311-win_amd64.pyd+0xbaf44)
#4 0x00007ffa7940f5ad PyInit__runtime (D:\dev\projects\SHARK-TestSuite\iree_tests\pytorch\operators\.venv\Lib\site-packages\iree\_runtime_libs\_runtime.cp311-win_amd64.pyd+0x2f5ad)
#5 0x00007ffa794b600c nanobind::python_error::what(void) const (D:\dev\projects\SHARK-TestSuite\iree_tests\pytorch\operators\.venv\Lib\site-packages\iree\_runtime_libs\_runtime.cp311-win_amd64.pyd+0xd600c)
#6 0x00007ffa768576ce C:\Program Files\Python311\python311.dll 0xc76ce C:\Program Files\Python311\python311.dll 0xc7d86
#7 0x00007ffa768576ce C:\Program Files\Python311\python311.dll 0xc811a C:\Program Files\Python311\python311.dll 0xd1a2
#8 0x00007ffa768576ce C:\Program Files\Python311\python311.dll 0xc829 C:\Program Files\Python311\python311.dll 0x872ee
#9 0x00007ffa768576ce C:\Program Files\Python311\python311.dll 0x87671 C:\Program Files\Python311\python311.dll 0x85eaa
#10 0x00007ffa768576ce C:\Program Files\Python311\python311.dll 0x3d21 C:\Program Files\Python311\python.exe 0x1230
#11 0x00007ffa768576ce (C:\Program Files\Python311\python311.dll+0xc76ce)
#12 0x00007ffa76857d86 (C:\Program Files\Python311\python311.dll+0xc7d86)
0x00007FFA79428B7B, D:\dev\projects\SHARK-TestSuite\iree_tests\pytorch\operators\.venv\Lib\site-packages\iree\_runtime_libs\_runtime.cp311-win_amd64.pyd(0x00007FFA793E0000) + 0x48B7B byte(s), PyInit__runtime() + 0x470CB byte(s)
0x00007FFA7949A1F1, D:\dev\projects\SHARK-TestSuite\iree_tests\pytorch\operators\.venv\Lib\site-packages\iree\_runtime_libs\_runtime.cp311-win_amd64.pyd(0x00007FFA793E0000) + 0xBA1F1 byte(s), PyInit__runtime() + 0xB8741 byte(s)
0x00007FFA7949A9A7, D:\dev\projects\SHARK-TestSuite\iree_tests\pytorch\operators\.venv\Lib\site-packages\iree\_runtime_libs\_runtime.cp311-win_amd64.pyd(0x00007FFA793E0000) + 0xBA9A7 byte(s), PyInit__runtime() + 0xB8EF7 byte(s)
0x00007FFA7949AF44, D:\dev\projects\SHARK-TestSuite\iree_tests\pytorch\operators\.venv\Lib\site-packages\iree\_runtime_libs\_runtime.cp311-win_amd64.pyd(0x00007FFA793E0000) + 0xBAF44 byte(s), PyInit__runtime() + 0xB9494 byte(s)
0x00007FFA7940F5AD, D:\dev\projects\SHARK-TestSuite\iree_tests\pytorch\operators\.venv\Lib\site-packages\iree\_runtime_libs\_runtime.cp311-win_amd64.pyd(0x00007FFA793E0000) + 0x2F5AD byte(s), PyInit__runtime() + 0x2DAFD byte(s)
0x00007FFA794B600C, D:\dev\projects\SHARK-TestSuite\iree_tests\pytorch\operators\.venv\Lib\site-packages\iree\_runtime_libs\_runtime.cp311-win_amd64.pyd(0x00007FFA793E0000) + 0xD600C byte(s), ?what@python_error@nanobind@@UEBAPEBDXZ() + 0x904C byte(s)
0x00007FFA768576CE, C:\Program Files\Python311\python311.dll(0x00007FFA76790000) + 0xC76CE byte(s), _PyBytes_Repeat() + 0x24E byte(s)
0x00007FFA76857D86, C:\Program Files\Python311\python311.dll(0x00007FFA76790000) + 0xC7D86 byte(s), _PyBytes_Repeat() + 0x906 byte(s)
0x00007FFA7685811A, C:\Program Files\Python311\python311.dll(0x00007FFA76790000) + 0xC811A byte(s), PyObject_ClearWeakRefs() + 0x1B2 byte(s)
0x00007FFA7679D1A2, C:\Program Files\Python311\python311.dll(0x00007FFA76790000) + 0xD1A2 byte(s), PyLong_AsLongLong() + 0xA82 byte(s)
0x00007FFA7679C829, C:\Program Files\Python311\python311.dll(0x00007FFA76790000) + 0xC829 byte(s), PyLong_AsLongLong() + 0x109 byte(s)
0x00007FFA768172EE, C:\Program Files\Python311\python311.dll(0x00007FFA76790000) + 0x872EE byte(s), PyGC_Collect() + 0x6A byte(s)
0x00007FFA76817671, C:\Program Files\Python311\python311.dll(0x00007FFA76790000) + 0x87671 byte(s), Py_FinalizeEx() + 0x99 byte(s)
0x00007FFA76815EAA, C:\Program Files\Python311\python311.dll(0x00007FFA76790000) + 0x85EAA byte(s), Py_RunMain() + 0x1A byte(s)
0x00007FFA76793D21, C:\Program Files\Python311\python311.dll(0x00007FFA76790000) + 0x3D21 byte(s), Py_Main() + 0x25 byte(s)
0x00007FF6D2AB1230, C:\Program Files\Python311\python.exe(0x00007FF6D2AB0000) + 0x1230 byte(s)
0x00007FFB81117344, C:\WINDOWS\System32\KERNEL32.DLL(0x00007FFB81100000) + 0x17344 byte(s), BaseThreadInitThunk() + 0x14 byte(s)
0x00007FFB821626B1, C:\WINDOWS\SYSTEM32\ntdll.dll(0x00007FFB82110000) + 0x526B1 byte(s), RtlUserThreadStart() + 0x21 byte(s)
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import logging
import pathlib
import tempfile

import iree.runtime as ireert
# import pytest
import shark_turbine.aot as aot
import torch
# Module-scoped logger; records are attributed to this module's dotted name.
logger: logging.Logger = logging.getLogger(name=__name__)
class TestIndexPut:
    """Tests for ``Tensor.index_put_`` in eager PyTorch and compiled through IREE."""

    def test_single_value(self):
        """Write one value into a 4x4 zero tensor at position [2, 2].

        The module runs three ways:

        1. Eagerly in PyTorch.
        2. Exported to MLIR with ``aot.export``.
        3. Compiled for the ``llvm-cpu`` backend and run with the
           ``local-sync`` driver.

        The eager result and the IREE result are both compared against the
        same expected tensor.
        """

        class TorchModule(torch.nn.Module):
            def forward(self, input):
                # NOTE: mutates ``input`` in place and returns it.
                input.index_put_(
                    # Insert at position [2, 2] (row indices, column indices).
                    indices=[torch.tensor([2]), torch.tensor([2])],
                    # Insert a single value.
                    values=torch.tensor([0.5]),
                )
                return input

        torch_module = TorchModule()
        test_input = torch.zeros(4, 4)
        expected_output = torch.tensor(
            [  # col 0   col 1   col 2   col 3
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 0
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 1
                [0.0000, 0.0000, 0.5000, 0.0000],  # row 2
                [0.0000, 0.0000, 0.0000, 0.0000],  # row 3
            ]
        )

        # forward() mutates its argument, so every stage gets its own copy of
        # the zero input. Reusing one tensor would let later stages start from
        # the already-updated values and pass even if they did nothing.
        pytorch_output = torch_module.forward(test_input.clone())
        # TODO(scotttodd): helper function to assert and print if assert failed
        torch.testing.assert_close(pytorch_output, expected_output)

        # ---------------------------------------------------------------------
        # Export to MLIR using aot.export (based on torch.export).
        exported_module = aot.export(torch_module, args=(test_input.clone(),))
        logger.info("%s", exported_module.mlir_module)
        # Save into a throwaway directory instead of a machine-specific path.
        with tempfile.TemporaryDirectory() as tmp_dir:
            exported_module.save_mlir(
                str(pathlib.Path(tmp_dir) / "index_put_single_value.mlir")
            )

        # ---------------------------------------------------------------------
        # Compile through IREE and run.
        compiled_module = exported_module.compile(
            save_to=None, target_backends="llvm-cpu"
        )
        config = ireert.Config("local-sync")
        vm_module = ireert.load_vm_module(
            ireert.VmModule.wrap_buffer(
                config.vm_instance, compiled_module.map_memory()
            ),
            config,
        )
        iree_output = vm_module.main(test_input.clone().numpy())
        # Copy into an independent host array so it stays valid after the
        # runtime objects below are released.
        iree_host_output = iree_output.to_host().copy()
        logger.info("%s", iree_host_output)

        # NOTE(review): the recorded run crashed with an access violation
        # (0xC0000005) inside the IREE runtime during interpreter-shutdown GC.
        # Releasing the runtime objects here, dependents first (the VM module
        # wraps memory mapped from compiled_module), is an attempted
        # mitigation. Confirm it resolves the crash.
        del iree_output, vm_module, config, compiled_module

        torch.testing.assert_close(
            torch.from_numpy(iree_host_output), expected_output
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment