Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jpienaar/9d22b710db6122a9a41e8cd3d824a39e to your computer and use it in GitHub Desktop.
Save jpienaar/9d22b710db6122a9a41e8cd3d824a39e to your computer and use it in GitHub Desktop.
Move python return types
From f492f516e190a017e69283ca88c6ed40899918bf Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar@google.com>
Date: Tue, 3 May 2022 21:26:38 -0700
Subject: [PATCH] [mlir][python] Move result types to end in generated
__init__.
Previously result types were specified in multiple different ways at the front
of the builder args. But this meant that when build-time type inference was
added, the Python API changed and folks needed to update call sites. Instead
move result types to keyword args and make it optional in the case where the
result type can be inferred. This uniformly uses results for the arguments name
rather than attempt to expand or special case it (did consider still supporting
result and results and could be convinced that is better). This allows adding
type inference without a breaking change, and also brings the arguments to the
fore, more closely aligned with the ODS definition. All other arguments are left
unchanged.
---
mlir/python/mlir/dialects/_func_ops_ext.py | 8 +-
mlir/python/mlir/dialects/_memref_ops_ext.py | 2 +-
mlir/python/mlir/dialects/_pdl_ops_ext.py | 18 ++--
.../dialects/linalg/opdsl/lang/emitter.py | 22 ++---
mlir/test/python/dialects/vector.py | 8 +-
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 99 +++++--------------
6 files changed, 52 insertions(+), 105 deletions(-)
diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py
index 6fe3ff5302e2..a540e6ec87e4 100644
--- a/mlir/python/mlir/dialects/_func_ops_ext.py
+++ b/mlir/python/mlir/dialects/_func_ops_ext.py
@@ -19,7 +19,7 @@ class ConstantOp:
"""Specialization for the constant op class."""
def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
- super().__init__(result, value, loc=loc, ip=ip)
+ super().__init__(value, results=[result], loc=loc, ip=ip)
@property
def type(self):
@@ -273,11 +273,11 @@ class CallOp:
"to a function")
super().__init__(
- calleeOrResults.type.results,
FlatSymbolRefAttr.get(
calleeOrResults.name.value,
context=_get_default_loc_context(loc)),
argumentsOrCallee,
+ results=calleeOrResults.type.results,
loc=loc,
ip=ip)
return
@@ -289,12 +289,12 @@ class CallOp:
if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
super().__init__(
- calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
+ argumentsOrCallee, arguments, results=calleeOrResults, loc=loc, ip=ip)
elif isinstance(argumentsOrCallee, str):
super().__init__(
- calleeOrResults,
FlatSymbolRefAttr.get(
argumentsOrCallee, context=_get_default_loc_context(loc)),
arguments,
+ results=calleeOrResults,
loc=loc,
ip=ip)
diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py
index 9cc22a21c628..1e8cb1c5a5c4 100644
--- a/mlir/python/mlir/dialects/_memref_ops_ext.py
+++ b/mlir/python/mlir/dialects/_memref_ops_ext.py
@@ -34,4 +34,4 @@ class LoadOp:
indices_resolved = [] if indices is None else _get_op_results_or_values(
indices)
return_type = MemRefType(memref_resolved.type).element_type
- super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip)
+ super().__init__(memref, indices_resolved, results=[return_type], loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py
index fb5b519c7c02..d8fdf52f5931 100644
--- a/mlir/python/mlir/dialects/_pdl_ops_ext.py
+++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py
@@ -79,7 +79,7 @@ class ApplyNativeRewriteOp:
ip=None):
name = _get_str_attr(name)
args = _get_values(args)
- super().__init__(results, name, args, loc=loc, ip=ip)
+ super().__init__(name, args, results=results, loc=loc, ip=ip)
class AttributeOp:
@@ -93,7 +93,7 @@ class AttributeOp:
ip=None):
type = type if type is None else _get_value(type)
result = pdl.AttributeType.get()
- super().__init__(result, type, value, loc=loc, ip=ip)
+ super().__init__(type, value, results=[result], loc=loc, ip=ip)
class EraseOp:
@@ -118,7 +118,7 @@ class OperandOp:
ip=None):
type = type if type is None else _get_value(type)
result = pdl.ValueType.get()
- super().__init__(result, type, loc=loc, ip=ip)
+ super().__init__(type, results=[result], loc=loc, ip=ip)
class OperandsOp:
@@ -131,7 +131,7 @@ class OperandsOp:
ip=None):
types = types if types is None else _get_value(types)
result = pdl.RangeType.get(pdl.ValueType.get())
- super().__init__(result, types, loc=loc, ip=ip)
+ super().__init__(types, results=[result], loc=loc, ip=ip)
class OperationOp:
@@ -155,7 +155,7 @@ class OperationOp:
attributeNames = ArrayAttr.get(attributeNames)
types = _get_values(types)
result = pdl.OperationType.get()
- super().__init__(result, name, args, attributeValues, attributeNames, types, loc=loc, ip=ip)
+ super().__init__(name, args, attributeValues, attributeNames, types, results=[result], loc=loc, ip=ip)
class PatternOp:
@@ -207,7 +207,7 @@ class ResultOp:
index = _get_int_attr(32, index)
parent = _get_value(parent)
result = pdl.ValueType.get()
- super().__init__(result, parent, index, loc=loc, ip=ip)
+ super().__init__(parent, index, results=[result], loc=loc, ip=ip)
class ResultsOp:
@@ -222,7 +222,7 @@ class ResultsOp:
ip=None):
parent = _get_value(parent)
index = index if index is None else _get_int_attr(32, index)
- super().__init__(result, parent, index, loc=loc, ip=ip)
+ super().__init__(parent, index, results=[result], loc=loc, ip=ip)
class RewriteOp:
@@ -261,7 +261,7 @@ class TypeOp:
ip=None):
type = type if type is None else _get_type_attr(type)
result = pdl.TypeType.get()
- super().__init__(result, type, loc=loc, ip=ip)
+ super().__init__(type, results=[result], loc=loc, ip=ip)
class TypesOp:
@@ -275,4 +275,4 @@ class TypesOp:
types = _get_array_attr([_get_type_attr(ty) for ty in types])
types = None if not types else types
result = pdl.RangeType.get(pdl.TypeType.get())
- super().__init__(result, types, loc=loc, ip=ip)
+ super().__init__(types, results=[result], loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 2e71e561a7f5..8d6b00d9b1b6 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -196,7 +196,7 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
[AffineMapAttr.get(am) for am in indexing_maps])
generic_op = linalg.GenericOp(
- result_tensors=result_types,
+ results=result_types,
inputs=ins,
outputs=outs,
indexing_maps=indexing_maps_attr,
@@ -342,18 +342,18 @@ class _BodyBuilder:
operand_type = operand.type
if _is_floating_point_type(operand_type):
if is_unsigned_cast:
- return arith.FPToUIOp(to_type, operand).result
- return arith.FPToSIOp(to_type, operand).result
+ return arith.FPToUIOp(operand, results=[to_type]).result
+ return arith.FPToSIOp(operand, results=[to_type]).result
if _is_index_type(operand_type):
- return arith.IndexCastOp(to_type, operand).result
+ return arith.IndexCastOp(operand, results=[to_type]).result
# Assume integer.
from_width = IntegerType(operand_type).width
if to_width > from_width:
if is_unsigned_cast:
- return arith.ExtUIOp(to_type, operand).result
- return arith.ExtSIOp(to_type, operand).result
+ return arith.ExtUIOp(operand, results=[to_type]).result
+ return arith.ExtSIOp(operand, results=[to_type]).result
elif to_width < from_width:
- return arith.TruncIOp(to_type, operand).result
+ return arith.TruncIOp(operand, results=[to_type]).result
raise ValueError(f"Unable to cast body expression from {operand_type} to "
f"{to_type}")
@@ -362,15 +362,15 @@ class _BodyBuilder:
operand_type = operand.type
if _is_integer_type(operand_type):
if is_unsigned_cast:
- return arith.UIToFPOp(to_type, operand).result
- return arith.SIToFPOp(to_type, operand).result
+ return arith.UIToFPOp(operand, results=[to_type]).result
+ return arith.SIToFPOp(operand, results=[to_type]).result
# Assume FloatType.
to_width = _get_floating_point_width(to_type)
from_width = _get_floating_point_width(operand_type)
if to_width > from_width:
- return arith.ExtFOp(to_type, operand).result
+ return arith.ExtFOp(operand, results=[to_type]).result
elif to_width < from_width:
- return arith.TruncFOp(to_type, operand).result
+ return arith.TruncFOp(operand, results=[to_type]).result
raise ValueError(f"Unable to cast body expression from {operand_type} to "
f"{to_type}")
diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py
index c31579545e6e..08a7a512190b 100644
--- a/mlir/test/python/dialects/vector.py
+++ b/mlir/test/python/dialects/vector.py
@@ -45,10 +45,10 @@ def testTransferReadOp():
F32Type.get(), mask_type], []))
with InsertionPoint(f.add_entry_block()):
A, zero, padding, mask = f.arguments
- vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
- padding, mask, None)
- vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
- padding, None, None)
+ vector.TransferReadOp(A, [zero, zero], identity_map_attr,
+ padding, mask, None, results=[vector_type])
+ vector.TransferReadOp(A, [zero, zero], identity_map_attr,
+ padding, None, None, results=[vector_type])
func.ReturnOp([])
# CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 16fccff973ca..5b7fac346529 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -490,7 +490,6 @@ static void emitAttributeAccessors(const Operator &op,
constexpr const char *initTemplate = R"Py(
def __init__(self, {0}):
operands = []
- results = []
attributes = {{}
regions = None
{1}
@@ -503,7 +502,6 @@ constexpr const char *initTemplate = R"Py(
/// {0} is the field name.
constexpr const char *singleOperandAppendTemplate =
"operands.append(_get_op_result_or_value({0}))";
-constexpr const char *singleResultAppendTemplate = "results.append({0})";
/// Template for appending an optional element to the operand/result list.
/// {0} is the field name.
@@ -512,8 +510,6 @@ constexpr const char *optionalAppendOperandTemplate =
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
"operands.append(_get_op_result_or_value({0}) if {0} is not None else "
"None)";
-constexpr const char *optionalAppendResultTemplate =
- "if {0} is not None: results.append({0})";
/// Template for appending a list of elements to the operand/result list.
/// {0} is the field name.
@@ -521,7 +517,6 @@ constexpr const char *multiOperandAppendTemplate =
"operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
"operands.append(_get_op_results_or_values({0}))";
-constexpr const char *multiResultAppendTemplate = "results.extend({0})";
/// Template for setting an attribute in the operation builder.
/// {0} is the attribute name;
@@ -576,30 +571,6 @@ static bool canInferType(const Operator &op) {
hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
}
-/// Populates `builderArgs` with result names if the builder is expected to
-/// accept them as arguments.
-static void
-populateBuilderArgsResults(const Operator &op,
- llvm::SmallVectorImpl<std::string> &builderArgs) {
- if (canInferType(op))
- return;
-
- for (int i = 0, e = op.getNumResults(); i < e; ++i) {
- std::string name = op.getResultName(i).str();
- if (name.empty()) {
- if (op.getNumResults() == 1) {
- // Special case for one result, make the default name be 'result'
- // to properly match the built-in result accessor.
- name = "result";
- } else {
- name = llvm::formatv("_gen_res_{0}", i);
- }
- }
- name = sanitizeName(name);
- builderArgs.push_back(name);
- }
-}
-
/// Populates `builderArgs` with the Python-compatible names of builder function
/// arguments using intermixed attributes and operands in the same order as they
/// appear in the `arguments` field of the op definition. Additionally,
@@ -724,25 +695,25 @@ populateBuilderLinesOperand(const Operator &op,
/// attribute:
/// - {0} is the name of the attribute from which to derive the types.
constexpr const char *deriveTypeFromAttrTemplate =
- R"PY(_ods_result_type_source_attr = attributes["{0}"]
-_ods_derived_result_type = (
+ R"PY( _ods_result_type_source_attr = attributes["{0}"]
+ _ods_derived_result_type = (
_ods_ir.TypeAttr(_ods_result_type_source_attr).value
if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
_ods_result_type_source_attr.type))PY";
/// Python code template appending {0} type {1} times to the results list.
-constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
+constexpr const char *setSameResultsTemplate = " results=[{0}] * {1}";
/// Python code template for inferring the operation results using the
/// corresponding interface:
/// - {0} is the name of the class for which the types are inferred.
constexpr const char *inferTypeInterfaceTemplate =
- R"PY(_ods_context = _ods_get_default_loc_context(loc)
-results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
- operands=operands,
- attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
- context=_ods_context,
- loc=loc)
+ R"PY( _ods_context = _ods_get_default_loc_context(loc)
+ results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
+ operands=operands,
+ attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
+ context=_ods_context,
+ loc=loc)
)PY";
/// Appends the given multiline string as individual strings into
@@ -761,13 +732,11 @@ static void appendLineByLine(StringRef string,
/// builder to set up op results.
static void
populateBuilderLinesResult(const Operator &op,
- llvm::ArrayRef<std::string> names,
llvm::SmallVectorImpl<std::string> &builderLines) {
- bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
-
+ builderLines.push_back("if not results:");
if (hasSameArgumentAndResultTypes(op)) {
builderLines.push_back(llvm::formatv(
- appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
+ setSameResultsTemplate, "operands[0].type", op.getNumResults()));
return;
}
@@ -778,7 +747,7 @@ populateBuilderLinesResult(const Operator &op,
appendLineByLine(
llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
builderLines);
- builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
+ builderLines.push_back(llvm::formatv(setSameResultsTemplate,
"_ods_derived_result_type",
op.getNumResults()));
return;
@@ -790,31 +759,6 @@ populateBuilderLinesResult(const Operator &op,
builderLines);
return;
}
-
- // For each element, find or generate a name.
- for (int i = 0, e = op.getNumResults(); i < e; ++i) {
- const NamedTypeConstraint &element = op.getResult(i);
- std::string name = names[i];
-
- // Choose the formatting string based on the element kind.
- llvm::StringRef formatString;
- if (!element.isVariableLength()) {
- formatString = singleResultAppendTemplate;
- } else if (element.isOptional()) {
- formatString = optionalAppendResultTemplate;
- } else {
- assert(element.isVariadic() && "unhandled element group type");
- // If emitting with sizedSegments, then we add the actual list-typed
- // element. Otherwise, we extend the actual operands.
- if (sizedSegments) {
- formatString = singleResultAppendTemplate;
- } else {
- formatString = multiResultAppendTemplate;
- }
- }
-
- builderLines.push_back(llvm::formatv(formatString.data(), name));
- }
}
/// If the operation has variadic regions, adds a builder argument to specify
@@ -854,21 +798,24 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
llvm::SmallVector<std::string> successorArgNames;
builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
op.getNumNativeAttributes() + op.getNumSuccessors());
- populateBuilderArgsResults(op, builderArgs);
- size_t numResultArgs = builderArgs.size();
populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
populateBuilderLinesOperand(op, operandArgNames, builderLines);
- populateBuilderLinesAttr(
- op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs),
- builderLines);
- populateBuilderLinesResult(
- op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs),
- builderLines);
+ populateBuilderLinesAttr(op, llvm::makeArrayRef(builderArgs), builderLines);
populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
populateBuilderRegions(op, builderArgs, builderLines);
builderArgs.push_back("*");
+ if (op.getNumResults()) {
+ if (canInferType(op)) {
+ builderArgs.push_back("results=[]");
+ populateBuilderLinesResult(op, builderLines);
+ } else {
+ builderArgs.push_back("results");
+ }
+ } else {
+ builderLines.push_back("results=[]");
+ }
builderArgs.push_back("loc=None");
builderArgs.push_back("ip=None");
os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "),
--
2.35.1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment