
@qedawkins
Created July 29, 2022 18:09
Save qedawkins/3cb1bb7bf297840e93f3d99ffd23ba1a to your computer and use it in GitHub Desktop.
import numpy as np

from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model

# Download the v_diffusion model as linalg-dialect MLIR, along with sample
# inputs and a reference ("golden") output.
mlir_model, func_name, inputs, golden_out = download_torch_model("v_diffusion")

shark_module = SharkInference(
    mlir_model, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward(inputs)

print("The result obtained via SHARK is:", result)
print("The golden result is:", golden_out)

# Reshape both outputs to the expected (2, 3, 256, 256) image batch shape
# before comparing.
golden_out = np.reshape(golden_out[0], (2, 3, 256, 256))
result = np.reshape(result, (2, 3, 256, 256))

# Passes before commit cf21395c35aa6650df558010cc54ee0c5fbee646
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)
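When `assert_allclose` fails, it only reports the mismatch fraction, so it can help to summarize the error magnitudes directly. Below is a minimal, hypothetical sketch of such a summary using NumPy alone; the `error_summary` helper and the stand-in arrays are illustrative, not part of the SHARK API.

```python
import numpy as np

def error_summary(golden, actual):
    """Return the max absolute and max relative error between two arrays."""
    golden = np.asarray(golden, dtype=np.float64)
    actual = np.asarray(actual, dtype=np.float64)
    abs_err = np.abs(golden - actual)
    # Guard against division by zero in the relative error.
    rel_err = abs_err / np.maximum(np.abs(golden), np.finfo(np.float64).tiny)
    return abs_err.max(), rel_err.max()

# Stand-in arrays with the same shape as the reshaped diffusion outputs,
# since the real SHARK module isn't available in this sketch.
golden = np.ones((2, 3, 256, 256))
actual = golden + 1e-4  # simulated small numerical deviation
max_abs, max_rel = error_summary(golden, actual)
print(f"max abs err: {max_abs:.2e}, max rel err: {max_rel:.2e}")
```

With the tolerances used above (`rtol=1e-02`, `atol=1e-03`), a maximum absolute error of `1e-4` against unit-magnitude values would pass comfortably.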