Skip to content

Instantly share code, notes, and snippets.

@kelemen
Created June 5, 2020 21:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kelemen/1bafe46e898326252cfda224b98a0e07 to your computer and use it in GitHub Desktop.
Save kelemen/1bafe46e898326252cfda224b98a0e07 to your computer and use it in GitHub Desktop.
Replicating Spark slowness when aggregating many derived fields packed into a struct column
package sparktest;
import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import static org.apache.spark.sql.functions.*;
import static org.apache.spark.sql.types.DataTypes.*;
public class SparkTest {

    // Scratch directory for Spark's local spill/temp files; cleaned up by JUnit after each test.
    @Rule
    public final TemporaryFolder tmpFolder = new TemporaryFolder();

    /**
     * Creates a local 4-thread Spark session whose {@code spark.local.dir}
     * points into this test's temporary folder.
     *
     * @throws IOException if the temporary folder cannot be created
     */
    private SparkSession newSession() throws IOException {
        SparkConf conf = new SparkConf();
        conf.set("spark.local.dir", tmpFolder.newFolder("spark-test").toString());
        return SparkSession.builder()
                .master("local[4]")
                .appName("SparkTest")
                .config(conf)
                .getOrCreate();
    }

    /**
     * Builds a struct of ten derived copies of column {@code A}, stacks several
     * per-field transformations on top of it, then aggregates per group and
     * prints how long plan construction and execution take.
     * NOTE(review): the "agg" timer covers only building the (lazy) plan; the
     * "show" timer covers analysis + execution — presumably intentional for
     * isolating where the slowness lives.
     */
    @Test
    public void testSparkStructs() throws IOException {
        try (SparkSession session = newSession()) {
            // Tiny two-group input with schema (G: string, A: double).
            Dataset<Row> input = session.createDataFrame(
                    Arrays.asList(
                            RowFactory.create("G1", 1.0),
                            RowFactory.create("G1", 2.0),
                            RowFactory.create("G2", 3.0)),
                    createStructType(Arrays.asList(
                            createStructField("G", StringType, true),
                            createStructField("A", DoubleType, true))));

            Column baseColumn = col("A");

            // Seed struct: ten shifted variants of A, i.e. A+0 .. A+9.
            Column[] shifted = new Column[10];
            for (int i = 0; i < shifted.length; i++) {
                shifted[i] = baseColumn.plus((double) i);
            }
            StructWrapper seed = new StructWrapper(DoubleType, shifted);

            DataType decimalType = createDecimalType(38, 18);

            // Chain several field-wise transformations, ending in a per-field SUM aggregate.
            StructWrapper aggregated = seed
                    .mapFields(a -> a.multiply(a))
                    .mapFields(a -> a.divide(a.plus(1.0)))
                    .mapFields(a -> coalesce(a, lit(100.0)))
                    .mapFields(decimalType, a -> a.cast(decimalType))
                    .mapFields(functions::sum);

            long startTime = System.nanoTime();
            Dataset<Row> result = input.groupBy("G").agg(aggregated.getSrcCol().as("A"));
            System.out.println("Elapsed (agg): " + TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - startTime) + " s");

            startTime = System.nanoTime();
            result.show(false);
            System.out.println("Elapsed (show): " + TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - startTime) + " s");
        }
    }

    /**
     * Wraps a struct column whose fields are named {@code A_0 .. A_{n-1}} and
     * share a single data type, so one transformation can be applied uniformly
     * to every field via {@link #mapFields}.
     */
    private static final class StructWrapper {
        private final Column structColumn;   // struct(...) over the aliased fields
        private final DataType fieldType;    // common type of every field
        private final StructType schema;     // mirrors structColumn's field names/types

        public StructWrapper(DataType fieldType, Column... fields) {
            this.fieldType = fieldType;
            // Alias each input column as A_<index> before packing it into the struct.
            Column[] named = new Column[fields.length];
            for (int i = 0; i < fields.length; i++) {
                named[i] = fields[i].as("A_" + i);
            }
            this.structColumn = struct(named);
            // Keep a matching schema so mapFields can enumerate the field names later.
            this.schema = createStructType(IntStream
                    .range(0, fields.length)
                    .mapToObj(i -> createStructField("A_" + i, fieldType, true))
                    .collect(Collectors.toList()));
        }

        /** Returns the wrapped struct column. */
        public Column getSrcCol() {
            return structColumn;
        }

        /** Applies {@code fieldMapper} to every field, keeping the current field type. */
        public StructWrapper mapFields(Function<Column, Column> fieldMapper) {
            return mapFields(fieldType, fieldMapper);
        }

        /**
         * Applies {@code fieldMapper} to every field and rewraps the results
         * as a new struct whose fields have {@code resultFieldType}.
         */
        public StructWrapper mapFields(DataType resultFieldType, Function<Column, Column> fieldMapper) {
            String[] names = schema.fieldNames();
            Column[] mapped = new Column[names.length];
            for (int i = 0; i < names.length; i++) {
                mapped[i] = fieldMapper.apply(structColumn.getField(names[i]));
            }
            return new StructWrapper(resultFieldType, mapped);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment