Skip to content

Instantly share code, notes, and snippets.

@BruceDai003
Created May 9, 2024 05:56
Show Gist options
  • Save BruceDai003/3ed46af13cee64e182f1e763c4ccaf28 to your computer and use it in GitHub Desktop.
Save BruceDai003/3ed46af13cee64e182f1e763c4ccaf28 to your computer and use it in GitHub Desktop.
Simple model that computes [0, 1, 2, 3] * 2, exported and run with IREE
import torch
import torch.nn as nn
import os
import numpy as np
import shark_turbine.aot as aot
# Pin every RNG the script could touch (torch CPU, torch CUDA, NumPy global)
# to one shared seed so reruns are reproducible.
_SEED = 42
torch.manual_seed(_SEED)
torch.cuda.manual_seed(_SEED)  # safe to call even when CUDA is unavailable
np.random.seed(_SEED)
class MLP(nn.Module):
    """Toy module that doubles its input element-wise.

    Despite the name, it has no layers or parameters. It only gives the
    export/compile pipeline a trivial graph: one multiply by 2.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        """Return ``x * 2``."""
        return x * 2
# Trace the toy model. `example_x` is uninitialized (torch.empty): only its
# shape (4,) and dtype float32 matter for export, not its values.
model = MLP()
example_x = torch.empty((4,), dtype=torch.float32)
exported = aot.export(model, example_x)
exported.print_readable()

# Write the exported MLIR into the same directory infer() writes its .npy files to.
out_dir = "notes/mlir/mul_2"
os.makedirs(out_dir, exist_ok=True)
mlir_path = f"{out_dir}/mul_2.mlir"
exported.save_mlir(mlir_path)

# save_to=None: nothing is written to disk here. infer() later loads the result
# through compiled_binary.map_memory().
compiled_binary = exported.compile(save_to=None)
def infer():
    """Run the compiled module on IREE's ``local-task`` driver and cross-check it.

    Steps:

    * Wrap the in-memory ``compiled_binary`` as a VM module.
    * Call its ``main`` on the fixed float32 input ``[0, 1, 2, 3]``.
    * Save the input and output as ``input.npy`` / ``output.npy`` under ``out_dir``.
    * Assert the result matches eager PyTorch (``model``) and print both.

    If ``output_cuda.npy`` already exists in ``out_dir`` (this script never
    writes it), load and print it too for side-by-side comparison.

    NOTE(review): no call to ``infer()`` is visible in this gist. Confirm it is
    invoked elsewhere (e.g. interactively).
    """
    import iree.runtime as rt

    config = rt.Config("local-task")
    wrapped = rt.VmModule.wrap_buffer(
        config.vm_instance, compiled_binary.map_memory()
    )
    vmm = rt.load_vm_module(wrapped, config)

    # Definitive inputs: 0, 1, 2, 3 as float32.
    x = np.arange(4).astype(np.float32)
    y = vmm.main(x)
    # Bring the device result to a NumPy array once and reuse it below.
    y_host = y.to_host()

    np.save(f"{out_dir}/input.npy", x)
    np.save(f"{out_dir}/output.npy", y_host)

    y_torch = model(torch.from_numpy(x))
    torch.testing.assert_close(y_host, y_torch.detach().numpy())
    print("y_torch = ", y_torch)
    print("y_iree_cpu = ", y_host)

    cuda_path = f'{out_dir}/output_cuda.npy'
    if os.path.exists(cuda_path):
        with open(cuda_path, 'rb') as f:
            y_iree_cuda = np.load(f)
        print("y_iree_cuda = ", y_iree_cuda)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment