Created
May 4, 2022 20:41
-
-
Save jpienaar/9d22b710db6122a9a41e8cd3d824a39e to your computer and use it in GitHub Desktop.
Move python return types
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
From f492f516e190a017e69283ca88c6ed40899918bf Mon Sep 17 00:00:00 2001 | |
From: Jacques Pienaar <jpienaar@google.com> | |
Date: Tue, 3 May 2022 21:26:38 -0700 | |
Subject: [PATCH] [mlir][python] Move result types to end in generated | |
__init__. | |
Previously, result types were specified in several different ways at the front | |
of the builder args. This meant that when build-time type inference was added, | |
the Python API changed and folks needed to update call sites. Instead, move | |
result types to a keyword arg and make it optional in the case where the result | |
types can be inferred. This uniformly uses `results` for the argument name | |
rather than attempting to expand or special-case it (did consider still | |
supporting both `result` and `results`, and could be convinced that is better). | |
This allows adding type inference without a breaking change, and also brings | |
the arguments to the fore, more closely aligned with the ODS def. All other | |
arguments are left unchanged. | |
--- | |
mlir/python/mlir/dialects/_func_ops_ext.py | 8 +- | |
mlir/python/mlir/dialects/_memref_ops_ext.py | 2 +- | |
mlir/python/mlir/dialects/_pdl_ops_ext.py | 18 ++-- | |
.../dialects/linalg/opdsl/lang/emitter.py | 22 ++--- | |
mlir/test/python/dialects/vector.py | 8 +- | |
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 99 +++++-------------- | |
6 files changed, 52 insertions(+), 105 deletions(-) | |
diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py | |
index 6fe3ff5302e2..a540e6ec87e4 100644 | |
--- a/mlir/python/mlir/dialects/_func_ops_ext.py | |
+++ b/mlir/python/mlir/dialects/_func_ops_ext.py | |
@@ -19,7 +19,7 @@ class ConstantOp: | |
"""Specialization for the constant op class.""" | |
def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None): | |
- super().__init__(result, value, loc=loc, ip=ip) | |
+ super().__init__(value, results=[result], loc=loc, ip=ip) | |
@property | |
def type(self): | |
@@ -273,11 +273,11 @@ class CallOp: | |
"to a function") | |
super().__init__( | |
- calleeOrResults.type.results, | |
FlatSymbolRefAttr.get( | |
calleeOrResults.name.value, | |
context=_get_default_loc_context(loc)), | |
argumentsOrCallee, | |
+ results=calleeOrResults.type.results, | |
loc=loc, | |
ip=ip) | |
return | |
@@ -289,12 +289,12 @@ class CallOp: | |
if isinstance(argumentsOrCallee, FlatSymbolRefAttr): | |
super().__init__( | |
- calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip) | |
+ argumentsOrCallee, arguments, results=calleeOrResults, loc=loc, ip=ip) | |
elif isinstance(argumentsOrCallee, str): | |
super().__init__( | |
- calleeOrResults, | |
FlatSymbolRefAttr.get( | |
argumentsOrCallee, context=_get_default_loc_context(loc)), | |
arguments, | |
+ results=calleeOrResults, | |
loc=loc, | |
ip=ip) | |
diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py | |
index 9cc22a21c628..1e8cb1c5a5c4 100644 | |
--- a/mlir/python/mlir/dialects/_memref_ops_ext.py | |
+++ b/mlir/python/mlir/dialects/_memref_ops_ext.py | |
@@ -34,4 +34,4 @@ class LoadOp: | |
indices_resolved = [] if indices is None else _get_op_results_or_values( | |
indices) | |
return_type = MemRefType(memref_resolved.type).element_type | |
- super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip) | |
+ super().__init__(memref, indices_resolved, results=[return_type], loc=loc, ip=ip) | |
diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py | |
index fb5b519c7c02..d8fdf52f5931 100644 | |
--- a/mlir/python/mlir/dialects/_pdl_ops_ext.py | |
+++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py | |
@@ -79,7 +79,7 @@ class ApplyNativeRewriteOp: | |
ip=None): | |
name = _get_str_attr(name) | |
args = _get_values(args) | |
- super().__init__(results, name, args, loc=loc, ip=ip) | |
+ super().__init__(name, args, results=results, loc=loc, ip=ip) | |
class AttributeOp: | |
@@ -93,7 +93,7 @@ class AttributeOp: | |
ip=None): | |
type = type if type is None else _get_value(type) | |
result = pdl.AttributeType.get() | |
- super().__init__(result, type, value, loc=loc, ip=ip) | |
+ super().__init__(type, value, results=[result], loc=loc, ip=ip) | |
class EraseOp: | |
@@ -118,7 +118,7 @@ class OperandOp: | |
ip=None): | |
type = type if type is None else _get_value(type) | |
result = pdl.ValueType.get() | |
- super().__init__(result, type, loc=loc, ip=ip) | |
+ super().__init__(type, results=[result], loc=loc, ip=ip) | |
class OperandsOp: | |
@@ -131,7 +131,7 @@ class OperandsOp: | |
ip=None): | |
types = types if types is None else _get_value(types) | |
result = pdl.RangeType.get(pdl.ValueType.get()) | |
- super().__init__(result, types, loc=loc, ip=ip) | |
+ super().__init__(types, results=[result], loc=loc, ip=ip) | |
class OperationOp: | |
@@ -155,7 +155,7 @@ class OperationOp: | |
attributeNames = ArrayAttr.get(attributeNames) | |
types = _get_values(types) | |
result = pdl.OperationType.get() | |
- super().__init__(result, name, args, attributeValues, attributeNames, types, loc=loc, ip=ip) | |
+ super().__init__(name, args, attributeValues, attributeNames, types, results=[result], loc=loc, ip=ip) | |
class PatternOp: | |
@@ -207,7 +207,7 @@ class ResultOp: | |
index = _get_int_attr(32, index) | |
parent = _get_value(parent) | |
result = pdl.ValueType.get() | |
- super().__init__(result, parent, index, loc=loc, ip=ip) | |
+ super().__init__(parent, index, results=[result], loc=loc, ip=ip) | |
class ResultsOp: | |
@@ -222,7 +222,7 @@ class ResultsOp: | |
ip=None): | |
parent = _get_value(parent) | |
index = index if index is None else _get_int_attr(32, index) | |
- super().__init__(result, parent, index, loc=loc, ip=ip) | |
+ super().__init__(parent, index, results=[result], loc=loc, ip=ip) | |
class RewriteOp: | |
@@ -261,7 +261,7 @@ class TypeOp: | |
ip=None): | |
type = type if type is None else _get_type_attr(type) | |
result = pdl.TypeType.get() | |
- super().__init__(result, type, loc=loc, ip=ip) | |
+ super().__init__(type, results=[result], loc=loc, ip=ip) | |
class TypesOp: | |
@@ -275,4 +275,4 @@ class TypesOp: | |
types = _get_array_attr([_get_type_attr(ty) for ty in types]) | |
types = None if not types else types | |
result = pdl.RangeType.get(pdl.TypeType.get()) | |
- super().__init__(result, types, loc=loc, ip=ip) | |
+ super().__init__(types, results=[result], loc=loc, ip=ip) | |
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py | |
index 2e71e561a7f5..8d6b00d9b1b6 100644 | |
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py | |
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py | |
@@ -196,7 +196,7 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, | |
[AffineMapAttr.get(am) for am in indexing_maps]) | |
generic_op = linalg.GenericOp( | |
- result_tensors=result_types, | |
+ results=result_types, | |
inputs=ins, | |
outputs=outs, | |
indexing_maps=indexing_maps_attr, | |
@@ -342,18 +342,18 @@ class _BodyBuilder: | |
operand_type = operand.type | |
if _is_floating_point_type(operand_type): | |
if is_unsigned_cast: | |
- return arith.FPToUIOp(to_type, operand).result | |
- return arith.FPToSIOp(to_type, operand).result | |
+ return arith.FPToUIOp(operand, results=[to_type]).result | |
+ return arith.FPToSIOp(operand, results=[to_type]).result | |
if _is_index_type(operand_type): | |
- return arith.IndexCastOp(to_type, operand).result | |
+ return arith.IndexCastOp(operand, results=[to_type]).result | |
# Assume integer. | |
from_width = IntegerType(operand_type).width | |
if to_width > from_width: | |
if is_unsigned_cast: | |
- return arith.ExtUIOp(to_type, operand).result | |
- return arith.ExtSIOp(to_type, operand).result | |
+ return arith.ExtUIOp(operand, results=[to_type]).result | |
+ return arith.ExtSIOp(operand, results=[to_type]).result | |
elif to_width < from_width: | |
- return arith.TruncIOp(to_type, operand).result | |
+ return arith.TruncIOp(operand, results=[to_type]).result | |
raise ValueError(f"Unable to cast body expression from {operand_type} to " | |
f"{to_type}") | |
@@ -362,15 +362,15 @@ class _BodyBuilder: | |
operand_type = operand.type | |
if _is_integer_type(operand_type): | |
if is_unsigned_cast: | |
- return arith.UIToFPOp(to_type, operand).result | |
- return arith.SIToFPOp(to_type, operand).result | |
+ return arith.UIToFPOp(operand, results=[to_type]).result | |
+ return arith.SIToFPOp(operand, results=[to_type]).result | |
# Assume FloatType. | |
to_width = _get_floating_point_width(to_type) | |
from_width = _get_floating_point_width(operand_type) | |
if to_width > from_width: | |
- return arith.ExtFOp(to_type, operand).result | |
+ return arith.ExtFOp(operand, results=[to_type]).result | |
elif to_width < from_width: | |
- return arith.TruncFOp(to_type, operand).result | |
+ return arith.TruncFOp(operand, results=[to_type]).result | |
raise ValueError(f"Unable to cast body expression from {operand_type} to " | |
f"{to_type}") | |
diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py | |
index c31579545e6e..08a7a512190b 100644 | |
--- a/mlir/test/python/dialects/vector.py | |
+++ b/mlir/test/python/dialects/vector.py | |
@@ -45,10 +45,10 @@ def testTransferReadOp(): | |
F32Type.get(), mask_type], [])) | |
with InsertionPoint(f.add_entry_block()): | |
A, zero, padding, mask = f.arguments | |
- vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr, | |
- padding, mask, None) | |
- vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr, | |
- padding, None, None) | |
+ vector.TransferReadOp(A, [zero, zero], identity_map_attr, | |
+ padding, mask, None, results=[vector_type]) | |
+ vector.TransferReadOp(A, [zero, zero], identity_map_attr, | |
+ padding, None, None, results=[vector_type]) | |
func.ReturnOp([]) | |
# CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index, | |
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | |
index 16fccff973ca..5b7fac346529 100644 | |
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | |
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | |
@@ -490,7 +490,6 @@ static void emitAttributeAccessors(const Operator &op, | |
constexpr const char *initTemplate = R"Py( | |
def __init__(self, {0}): | |
operands = [] | |
- results = [] | |
attributes = {{} | |
regions = None | |
{1} | |
@@ -503,7 +502,6 @@ constexpr const char *initTemplate = R"Py( | |
/// {0} is the field name. | |
constexpr const char *singleOperandAppendTemplate = | |
"operands.append(_get_op_result_or_value({0}))"; | |
-constexpr const char *singleResultAppendTemplate = "results.append({0})"; | |
/// Template for appending an optional element to the operand/result list. | |
/// {0} is the field name. | |
@@ -512,8 +510,6 @@ constexpr const char *optionalAppendOperandTemplate = | |
constexpr const char *optionalAppendAttrSizedOperandsTemplate = | |
"operands.append(_get_op_result_or_value({0}) if {0} is not None else " | |
"None)"; | |
-constexpr const char *optionalAppendResultTemplate = | |
- "if {0} is not None: results.append({0})"; | |
/// Template for appending a list of elements to the operand/result list. | |
/// {0} is the field name. | |
@@ -521,7 +517,6 @@ constexpr const char *multiOperandAppendTemplate = | |
"operands.extend(_get_op_results_or_values({0}))"; | |
constexpr const char *multiOperandAppendPackTemplate = | |
"operands.append(_get_op_results_or_values({0}))"; | |
-constexpr const char *multiResultAppendTemplate = "results.extend({0})"; | |
/// Template for setting an attribute in the operation builder. | |
/// {0} is the attribute name; | |
@@ -576,30 +571,6 @@ static bool canInferType(const Operator &op) { | |
hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); | |
} | |
-/// Populates `builderArgs` with result names if the builder is expected to | |
-/// accept them as arguments. | |
-static void | |
-populateBuilderArgsResults(const Operator &op, | |
- llvm::SmallVectorImpl<std::string> &builderArgs) { | |
- if (canInferType(op)) | |
- return; | |
- | |
- for (int i = 0, e = op.getNumResults(); i < e; ++i) { | |
- std::string name = op.getResultName(i).str(); | |
- if (name.empty()) { | |
- if (op.getNumResults() == 1) { | |
- // Special case for one result, make the default name be 'result' | |
- // to properly match the built-in result accessor. | |
- name = "result"; | |
- } else { | |
- name = llvm::formatv("_gen_res_{0}", i); | |
- } | |
- } | |
- name = sanitizeName(name); | |
- builderArgs.push_back(name); | |
- } | |
-} | |
- | |
/// Populates `builderArgs` with the Python-compatible names of builder function | |
/// arguments using intermixed attributes and operands in the same order as they | |
/// appear in the `arguments` field of the op definition. Additionally, | |
@@ -724,25 +695,25 @@ populateBuilderLinesOperand(const Operator &op, | |
/// attribute: | |
/// - {0} is the name of the attribute from which to derive the types. | |
constexpr const char *deriveTypeFromAttrTemplate = | |
- R"PY(_ods_result_type_source_attr = attributes["{0}"] | |
-_ods_derived_result_type = ( | |
+ R"PY( _ods_result_type_source_attr = attributes["{0}"] | |
+ _ods_derived_result_type = ( | |
_ods_ir.TypeAttr(_ods_result_type_source_attr).value | |
if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else | |
_ods_result_type_source_attr.type))PY"; | |
/// Python code template appending {0} type {1} times to the results list. | |
-constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; | |
+constexpr const char *setSameResultsTemplate = " results=[{0}] * {1}"; | |
/// Python code template for inferring the operation results using the | |
/// corresponding interface: | |
/// - {0} is the name of the class for which the types are inferred. | |
constexpr const char *inferTypeInterfaceTemplate = | |
- R"PY(_ods_context = _ods_get_default_loc_context(loc) | |
-results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes( | |
- operands=operands, | |
- attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), | |
- context=_ods_context, | |
- loc=loc) | |
+ R"PY( _ods_context = _ods_get_default_loc_context(loc) | |
+ results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes( | |
+ operands=operands, | |
+ attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), | |
+ context=_ods_context, | |
+ loc=loc) | |
)PY"; | |
/// Appends the given multiline string as individual strings into | |
@@ -761,13 +732,11 @@ static void appendLineByLine(StringRef string, | |
/// builder to set up op results. | |
static void | |
populateBuilderLinesResult(const Operator &op, | |
- llvm::ArrayRef<std::string> names, | |
llvm::SmallVectorImpl<std::string> &builderLines) { | |
- bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; | |
- | |
+ builderLines.push_back("if not results:"); | |
if (hasSameArgumentAndResultTypes(op)) { | |
builderLines.push_back(llvm::formatv( | |
- appendSameResultsTemplate, "operands[0].type", op.getNumResults())); | |
+ setSameResultsTemplate, "operands[0].type", op.getNumResults())); | |
return; | |
} | |
@@ -778,7 +747,7 @@ populateBuilderLinesResult(const Operator &op, | |
appendLineByLine( | |
llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), | |
builderLines); | |
- builderLines.push_back(llvm::formatv(appendSameResultsTemplate, | |
+ builderLines.push_back(llvm::formatv(setSameResultsTemplate, | |
"_ods_derived_result_type", | |
op.getNumResults())); | |
return; | |
@@ -790,31 +759,6 @@ populateBuilderLinesResult(const Operator &op, | |
builderLines); | |
return; | |
} | |
- | |
- // For each element, find or generate a name. | |
- for (int i = 0, e = op.getNumResults(); i < e; ++i) { | |
- const NamedTypeConstraint &element = op.getResult(i); | |
- std::string name = names[i]; | |
- | |
- // Choose the formatting string based on the element kind. | |
- llvm::StringRef formatString; | |
- if (!element.isVariableLength()) { | |
- formatString = singleResultAppendTemplate; | |
- } else if (element.isOptional()) { | |
- formatString = optionalAppendResultTemplate; | |
- } else { | |
- assert(element.isVariadic() && "unhandled element group type"); | |
- // If emitting with sizedSegments, then we add the actual list-typed | |
- // element. Otherwise, we extend the actual operands. | |
- if (sizedSegments) { | |
- formatString = singleResultAppendTemplate; | |
- } else { | |
- formatString = multiResultAppendTemplate; | |
- } | |
- } | |
- | |
- builderLines.push_back(llvm::formatv(formatString.data(), name)); | |
- } | |
} | |
/// If the operation has variadic regions, adds a builder argument to specify | |
@@ -854,21 +798,24 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { | |
llvm::SmallVector<std::string> successorArgNames; | |
builderArgs.reserve(op.getNumOperands() + op.getNumResults() + | |
op.getNumNativeAttributes() + op.getNumSuccessors()); | |
- populateBuilderArgsResults(op, builderArgs); | |
- size_t numResultArgs = builderArgs.size(); | |
populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames); | |
populateBuilderLinesOperand(op, operandArgNames, builderLines); | |
- populateBuilderLinesAttr( | |
- op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs), | |
- builderLines); | |
- populateBuilderLinesResult( | |
- op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs), | |
- builderLines); | |
+ populateBuilderLinesAttr(op, llvm::makeArrayRef(builderArgs), builderLines); | |
populateBuilderLinesSuccessors(op, successorArgNames, builderLines); | |
populateBuilderRegions(op, builderArgs, builderLines); | |
builderArgs.push_back("*"); | |
+ if (op.getNumResults()) { | |
+ if (canInferType(op)) { | |
+ builderArgs.push_back("results=[]"); | |
+ populateBuilderLinesResult(op, builderLines); | |
+ } else { | |
+ builderArgs.push_back("results"); | |
+ } | |
+ } else { | |
+ builderLines.push_back("results=[]"); | |
+ } | |
builderArgs.push_back("loc=None"); | |
builderArgs.push_back("ip=None"); | |
os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "), | |
-- | |
2.35.1 | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment