
@c21
Created May 14, 2021 04:57
Example code-gen for left anti sort merge join
val df1 = spark.range(10).select($"id".as("k1"))
val df2 = spark.range(4).select($"id".as("k2"))
df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_anti")
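The physical plan and generated Java below match the output format of Spark's codegen debugger. One way to reproduce a dump like this (the gist does not show the exact command used, so treat this as an assumption):

import org.apache.spark.sql.execution.debug._

// Prints each whole-stage-codegen subtree of the physical plan together
// with the Java source Spark generates for it.
df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_anti").debugCodegen()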
== Subtree 5 / 5 (maxMethodCodeSize:296; maxConstantPoolSize:156(0.24% used); numInnerClasses:0) ==
*(5) Project [id#0L AS k1#2L]
+- *(5) SortMergeJoin [id#0L], [k2#6L], LeftAnti
   :- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#0L, 5), ENSURE_REQUIREMENTS, [id=#27]
   :     +- *(1) Range (0, 10, step=1, splits=2)
   +- *(4) Sort [k2#6L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(k2#6L, 5), ENSURE_REQUIREMENTS, [id=#33]
         +- *(3) Project [id#4L AS k2#6L]
            +- *(3) Range (0, 4, step=1, splits=2)
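
For reference, a left anti join keeps only the streamed-side (left) rows that have no match on the buffered-side (right). With these inputs, ids 0 through 3 match rows in df2, so the expected result is k1 = 4..9:

df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_anti")
  .orderBy("k1")
  .show()
// +---+
// | k1|
// +---+
// |  4|
// |  5|
// |  6|
// |  7|
// |  8|
// |  9|
// +---+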
Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage5(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=5
/* 006 */ final class GeneratedIteratorForCodegenStage5 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator smj_streamedInput_0;
/* 010 */   private scala.collection.Iterator smj_bufferedInput_0;
/* 011 */   private InternalRow smj_streamedRow_0;
/* 012 */   private InternalRow smj_bufferedRow_0;
/* 013 */   private long smj_value_2;
/* 014 */   private org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray smj_matches_0;
/* 015 */   private long smj_value_3;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] smj_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage5(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */     smj_streamedInput_0 = inputs[0];
/* 026 */     smj_bufferedInput_0 = inputs[1];
/* 027 */
/* 028 */     smj_matches_0 = new org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray(1, 2147483647);
/* 029 */     smj_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */     smj_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 031 */
/* 032 */   }
/* 033 */
/* 034 */   private boolean findNextJoinRows(
/* 035 */     scala.collection.Iterator streamedIter,
/* 036 */     scala.collection.Iterator bufferedIter) {
/* 037 */     smj_streamedRow_0 = null;
/* 038 */     int comp = 0;
/* 039 */     while (smj_streamedRow_0 == null) {
/* 040 */       if (!streamedIter.hasNext()) return false;
/* 041 */       smj_streamedRow_0 = (InternalRow) streamedIter.next();
/* 042 */       long smj_value_0 = smj_streamedRow_0.getLong(0);
/* 043 */       if (false) {
/* 044 */         if (!smj_matches_0.isEmpty()) {
/* 045 */           smj_matches_0.clear();
/* 046 */         }
/* 047 */         return false;
/* 048 */
/* 049 */       }
/* 050 */       if (!smj_matches_0.isEmpty()) {
/* 051 */         comp = 0;
/* 052 */         if (comp == 0) {
/* 053 */           comp = (smj_value_0 > smj_value_3 ? 1 : smj_value_0 < smj_value_3 ? -1 : 0);
/* 054 */         }
/* 055 */
/* 056 */         if (comp == 0) {
/* 057 */           return true;
/* 058 */         }
/* 059 */         smj_matches_0.clear();
/* 060 */       }
/* 061 */
/* 062 */       do {
/* 063 */         if (smj_bufferedRow_0 == null) {
/* 064 */           if (!bufferedIter.hasNext()) {
/* 065 */             smj_value_3 = smj_value_0;
/* 066 */             return !smj_matches_0.isEmpty();
/* 067 */           }
/* 068 */           smj_bufferedRow_0 = (InternalRow) bufferedIter.next();
/* 069 */           long smj_value_1 = smj_bufferedRow_0.getLong(0);
/* 070 */           if (false) {
/* 071 */             smj_bufferedRow_0 = null;
/* 072 */             continue;
/* 073 */           }
/* 074 */           smj_value_2 = smj_value_1;
/* 075 */         }
/* 076 */
/* 077 */         comp = 0;
/* 078 */         if (comp == 0) {
/* 079 */           comp = (smj_value_0 > smj_value_2 ? 1 : smj_value_0 < smj_value_2 ? -1 : 0);
/* 080 */         }
/* 081 */
/* 082 */         if (comp > 0) {
/* 083 */           smj_bufferedRow_0 = null;
/* 084 */         } else if (comp < 0) {
/* 085 */           if (!smj_matches_0.isEmpty()) {
/* 086 */             smj_value_3 = smj_value_0;
/* 087 */             return true;
/* 088 */           } else {
/* 089 */             return false;
/* 090 */           }
/* 091 */         } else {
/* 092 */           if (smj_matches_0.isEmpty()) {
/* 093 */             smj_matches_0.add((UnsafeRow) smj_bufferedRow_0);
/* 094 */           }
/* 095 */
/* 096 */           smj_bufferedRow_0 = null;
/* 097 */         }
/* 098 */       } while (smj_streamedRow_0 != null);
/* 099 */     }
/* 100 */     return false; // unreachable
/* 101 */   }
/* 102 */
/* 103 */   protected void processNext() throws java.io.IOException {
/* 104 */     while (smj_streamedInput_0.hasNext()) {
/* 105 */       findNextJoinRows(smj_streamedInput_0, smj_bufferedInput_0);
/* 106 */
/* 107 */       long smj_value_4 = -1L;
/* 108 */       smj_value_4 = smj_streamedRow_0.getLong(0);
/* 109 */       scala.collection.Iterator<UnsafeRow> smj_iterator_0 = smj_matches_0.generateIterator();
/* 110 */
/* 111 */       boolean wholestagecodegen_hasOutputRow_0 = false;
/* 112 */
/* 113 */       while (!wholestagecodegen_hasOutputRow_0 && smj_iterator_0.hasNext()) {
/* 114 */         InternalRow smj_bufferedRow_1 = (InternalRow) smj_iterator_0.next();
/* 115 */
/* 116 */         wholestagecodegen_hasOutputRow_0 = true;
/* 117 */       }
/* 118 */
/* 119 */       if (!wholestagecodegen_hasOutputRow_0) {
/* 120 */         // load all values of streamed row, because the values not in join condition are not
/* 121 */         // loaded yet.
/* 122 */
/* 123 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 124 */
/* 125 */         // common sub-expressions
/* 126 */
/* 127 */         smj_mutableStateArray_0[1].reset();
/* 128 */
/* 129 */         smj_mutableStateArray_0[1].write(0, smj_value_4);
/* 130 */         append((smj_mutableStateArray_0[1].getRow()).copy());
/* 131 */
/* 132 */       }
/* 133 */       if (shouldStop()) return;
/* 134 */     }
/* 135 */     ((org.apache.spark.sql.execution.joins.SortMergeJoinExec) references[1] /* plan */).cleanupResources();
/* 136 */   }
/* 137 */
/* 138 */ }
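
Two details worth noting in the generated code: smj_matches_0 buffers at most one matching row (the isEmpty() guard at /* 092 */), since an anti join only needs to know whether any match exists; and the if (false) blocks are dead branches left over from join-key null checks, which constant-fold away here because range() produces non-nullable ids. Stripped of the codegen plumbing, the algorithm that findNextJoinRows and processNext implement boils down to a merge of two sorted key streams. A minimal sketch in plain Scala (not Spark internals, illustration only):

// Merge two sorted streams of Long keys; emit a streamed key only when the
// buffered side has no equal key (left anti join semantics).
def sortMergeLeftAnti(streamed: Iterator[Long], buffered: Iterator[Long]): Iterator[Long] = {
  val buf = buffered.buffered  // BufferedIterator: peek at the head without consuming it
  streamed.filter { key =>
    // Advance the buffered side past keys smaller than the current streamed key,
    // mirroring the comp > 0 branch of the generated do/while loop.
    while (buf.hasNext && buf.head < key) buf.next()
    // Keep the streamed row only if no buffered key equals it.
    !(buf.hasNext && buf.head == key)
  }
}

// sortMergeLeftAnti((0L until 10).iterator, (0L until 4).iterator).toList
// => List(4, 5, 6, 7, 8, 9)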