Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created March 27, 2021 19:33
Show Gist options
  • Save pashu123/4276bcfe9e352a3c4a32f9a1c1948ada to your computer and use it in GitHub Desktop.
Save pashu123/4276bcfe9e352a3c4a32f9a1c1948ada to your computer and use it in GitHub Desktop.
import numpy as np
import ctypes
import gc, sys
from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
class MemRefDescriptor(ctypes.Structure):
    """ctypes view of a rank-1 memref descriptor whose element type is f32.

    Fields are laid out sequentially in declaration order: two pointers, an
    offset, then one size and one stride per dimension. The order is
    presumably required to match MLIR's ranked memref descriptor layout for
    ``llvm.emit_c_interface`` functions; keep it as-is.
    """

    _fields_ = [
        # NOTE(review): held as a 64-bit integer rather than a pointer type;
        # this only matches a pointer field on 64-bit targets.
        ("allocated", ctypes.c_longlong),
        # Typed pointer to the (aligned) start of the float data.
        ("aligned", ctypes.POINTER(ctypes.c_float)),
        ("offset", ctypes.c_longlong),
        # Fixed-length arrays of one entry: this descriptor describes rank 1.
        ("sizes", ctypes.c_longlong * 1),
        ("strides", ctypes.c_longlong * 1),
    ]
# Reference: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.ctypes.html
def npToCtype(np_array):
    """Wrap a 1-D float32 NumPy array in a pointer to a ``MemRefDescriptor``.

    The descriptor borrows the array's buffer (no copy), so values written
    through the descriptor by compiled code are visible in ``np_array``.

    Args:
        np_array: one-dimensional ``numpy.float32`` array. ``MemRefDescriptor``
            has rank 1 and an ``f32`` data pointer, so no other rank or dtype
            can be described by it.

    Returns:
        ``ctypes.pointer`` to a populated ``MemRefDescriptor``. The array is
        kept alive for as long as that pointer is.

    Raises:
        ValueError: if ``np_array`` is not 1-D, is not float32, or has a
            stride that is not a whole number of elements.
    """
    if np_array.ndim != 1:
        raise ValueError(f"expected a 1-D array, got ndim={np_array.ndim}")
    if np_array.dtype != np.float32:
        raise ValueError(f"expected dtype float32, got {np_array.dtype}")

    itemsize = np_array.dtype.itemsize
    byte_stride = np_array.strides[0]
    if byte_stride % itemsize:
        raise ValueError(
            f"stride of {byte_stride} bytes is not a multiple of the "
            f"element size ({itemsize} bytes)"
        )

    x = MemRefDescriptor()
    x.allocated = np_array.ctypes.data
    x.aligned = np_array.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
    # MLIR counts offset and strides in elements, whereas NumPy strides are in
    # bytes. ``aligned`` already points at the array's first element, so the
    # element offset from it is 0.
    x.offset = 0
    # Assign element-wise so the values are converted to c_longlong instead of
    # relying on NumPy's c_intp array type matching the field type.
    x.sizes[0] = np_array.shape[0]
    x.strides[0] = byte_stride // itemsize
    # ctypes.pointer(x) keeps x alive, and x keeps the array (and therefore the
    # buffer the descriptor points into) alive.
    x._np_array = np_array
    return ctypes.pointer(x)
# All output goes to stderr and is flushed right away, so it stays in order
# with any errors/info MLIR itself emits on stderr.
def log(*args):
    """Print ``args`` (space-separated, newline-terminated) to stderr, then flush."""
    print(*args, file=sys.stderr, flush=True)
def run(f):
    """Run test function ``f`` and check that no MLIR ``Context`` outlives it.

    Logs a ``TEST: <name>`` header, calls ``f()`` with no arguments, forces a
    garbage collection, then asserts that the live ``Context`` count is zero.
    """
    name = f.__name__
    log("\nTEST:", name)
    f()
    # Collect first so Contexts kept alive only by reference cycles are freed
    # before they are counted.
    gc.collect()
    live = Context._get_live_count()
    assert live == 0
def lowerToLLVM(module):
    """Run the ``convert-std-to-llvm`` pass pipeline on ``module`` and return it.

    The same ``module`` object that was passed in is returned, so the call can
    be chained straight into ``ExecutionEngine(...)``.
    """
    # Imported only for its side effect; presumably this registers the
    # conversion passes so the pipeline string below can be parsed.
    import mlir.conversions
    PassManager.parse("convert-std-to-llvm").run(module)
    return module
def testInvokeMemrefAdd():
    """JIT-compile a memref-based ``@main`` and invoke it with NumPy buffers.

    ``@main`` loads element 0 of ``%arg0`` twice, adds the two loaded values
    with ``addf``, and stores the sum into element 0 of ``%arg1``.

    NOTE(review): the test only checks that invocation completes. ``res_arr``
    is never compared against the expected ``inp_arr + inp_arr``.
    """
    with Context():
        # Joined line by line; the resulting MLIR text is unchanged.
        mlir_source = "\n".join([
            "",
            "module {",
            "func @main(%arg0: memref<1xf32>, %arg1: memref<1xf32>) attributes { llvm.emit_c_interface } {",
            "%0 = constant 0 : index",
            "%1 = memref.load %arg0[%0] : memref<1xf32>",
            "%2 = memref.load %arg0[%0] : memref<1xf32>",
            "%3 = addf %1, %2 : f32",
            "memref.store %3, %arg1[%0] : memref<1xf32>",
            "return",
            "}",
            "} ",
        ])
        module = Module.parse(mlir_source)

        # One-element float32 buffers: the input, and the output slot that
        # @main writes into.
        inp_arr = np.random.rand(1).astype(np.float32)
        res_arr = np.random.rand(1).astype(np.float32)
        inp_ctype, res_ctype = npToCtype(inp_arr), npToCtype(res_arr)

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", inp_ctype, res_ctype)
# Only run the test when executed as a script, so importing this module (for
# example to reuse npToCtype or lowerToLLVM) does not JIT-compile and invoke
# code as a side effect.
if __name__ == "__main__":
    run(testInvokeMemrefAdd)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment