Skip to content

Instantly share code, notes, and snippets.

@antiagainst
Created September 9, 2022 23:05
Show Gist options
  • Save antiagainst/0d2cc3463299b497b797662ddbffa3db to your computer and use it in GitHub Desktop.
Save antiagainst/0d2cc3463299b497b797662ddbffa3db to your computer and use it in GitHub Desktop.
# Copied from https://colab.sandbox.google.com/github/iree-org/iree/blob/main/samples/colab/resnet.ipynb
# Run the following commands to install the required packages:
# pip install --upgrade iree-compiler iree-runtime iree-tools-tf -f https://github.com/iree-org/iree/releases
# pip install --upgrade tf-nightly
from iree import runtime as ireert
from iree import compiler as ireec
from iree.tf.support import module_utils
import tensorflow as tf
from absl import app
# Leading dimension is the batch; the remaining (224, 224, 3) is the
# per-image shape handed to ResNet50 below.
INPUT_SHAPE = [1, 224, 224, 3]

print("TensorFlow version: ", tf.__version__)

# ResNet50 with ImageNet weights and its top classification layer. Keras'
# `input_shape` excludes the batch dimension, hence the [1:] slice.
tf_model = tf.keras.applications.resnet50.ResNet50(
    input_shape=tuple(INPUT_SHAPE[1:]),
    include_top=True,
    weights="imagenet",
)
# Wrap the model in a tf.Module to compile it with IREE.
class ResNetModule(tf.Module):
    """tf.Module exposing the module-level ResNet50 model as `predict`.

    `predict` is a tf.function with a fixed input signature of
    `INPUT_SHAPE` / float32, and it runs the model with `training=False`.
    """

    def __init__(self):
        super().__init__()
        # Hold the Keras model as an attribute so tf.Module tracks it.
        self.m = tf_model
        # Wrap the inference call directly. The original code replaced the
        # Keras model's own `predict` method on the shared module-level model
        # and then wrapped the global instead of `self.m`. The traced
        # computation here is the same.
        self.predict = tf.function(
            input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
                lambda x: self.m.call(x, training=False))
def main(argv, *, artifacts_dir="./iree-resnet50-artifacts",
         target_triple="rdna2-unknown-linux"):
    """Import ResNet50 via IREE's TF frontend and recompile it for Vulkan.

    Args:
      argv: Command-line arguments forwarded by `absl.app.run`. Only the
        program name is accepted.
      artifacts_dir: Directory that receives the TF import artifacts and the
        final `amd-resnet50.vmfb` module.
      target_triple: Value passed to `--iree-vulkan-target-triple`.

    Raises:
      app.UsageError: If unexpected positional arguments are supplied.
    """
    from pathlib import Path

    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    artifacts = Path(artifacts_dir)

    backend = module_utils.BackendInfo("iree_vulkan")
    # Import the TF model into MLIR. This writes several artifacts into
    # `artifacts_dir`, one per level of model representation; only the
    # imported mhlo input (iree_input.mlir) is needed. This call also
    # compiles for Vulkan, but with default options that cannot be
    # controlled here, so that output is discarded and the model is
    # recompiled below.
    backend.compile_from_class(ResNetModule,
                               exported_names=["predict"],
                               artifacts_dir=str(artifacts_dir))

    # Read the imported mhlo representation of the model.
    mhlo_source = (artifacts / "iree_input.mlir").read_text(encoding="utf-8")

    # Recompile through the iree-compile wrapper, which accepts the full set
    # of developer command-line options.
    compilation_args = [
        f"--iree-vulkan-target-triple={target_triple}",
        "--mlir-print-debuginfo=false",
        # Add more iree-compile command-line options here, for example:
        # "--mlir-print-ir-after=iree-hal-materialize-interfaces",
        # "--mlir-elide-elementsattrs-if-larger=8",
    ]
    blob = ireec.compile_str(mhlo_source,
                             target_backends=["vulkan"],
                             extra_args=compilation_args,
                             input_type="mhlo")

    # Write out the compiled IREE module.
    (artifacts / "amd-resnet50.vmfb").write_bytes(blob)


if __name__ == "__main__":
    app.run(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment