Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Schema Inference Benchmark Harness
# Plots and summarises the latency samples printed by the schema-inference
# benchmark (a CSV with `method` and `duration` columns).
library(ggplot2)
library(lubridate)

# Inference methods, in the order they should appear along the x axis.
methods <- c(
  "execute-query",
  "max-rows",
  "true = false",
  "true <> true",
  "limit 0",
  "prepared statement",
  "prepared statement (leaked)"
)

df <- data.table::fread("data-4.csv")

# Fix the factor level order so the plot follows `methods` rather than
# alphabetical order.
df$method <- factor(df$method, levels = methods)

# Convert the duration column to a plain number of seconds.
df$duration <- as.numeric(as.duration(df$duration), "seconds")
df

# Tick positions (in seconds) and their human-readable labels for the log axis.
y_breaks <- c(0.01, 0.1, 1, 10, 20)
y_labels <- c("10ms", "100ms", "1s", "10s", "20s")

p <- ggplot(df, aes(x = method, y = duration, color = method)) +
  geom_boxplot() +
  scale_y_log10(
    name = "latency",
    breaks = y_breaks,
    labels = y_labels,
    limits = c(0.01, 20)
  ) +
  xlab("inference method") +
  # Overlay the mean with a spread range on top of each box.
  stat_summary(fun.data = mean_sdl, geom = "pointrange") +
  theme(
    legend.position = "none",
    text = element_text(family = "Avenir", size = 16),
    axis.text.x = element_text(angle = 30, hjust = 1),
    axis.title.x = element_text(size = 20, vjust = -1),
    axis.title.y = element_text(size = 20)
  )
p

# Per-method mean, median and standard deviation of the latency in seconds.
aggregate(
  list(df$duration),
  list(df$method),
  FUN = function(x) {
    c(mean = mean(x), median = median(x), sd = sd(x))
  }
)
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build for the schema-inference benchmark: compiles for Java 11 and
     packages a single runnable (shaded) jar whose entry point is SchemaInference. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <properties>
        <java.version>11</java.version>
    </properties>

    <groupId>ai.sisu.blog</groupId>
    <artifactId>schema-inference</artifactId>
    <version>1.0-SNAPSHOT</version>

    <dependencies>
        <!-- Redshift JDBC 4.2 driver; resolved from the "redshift" repository below -->
        <dependency>
            <groupId>com.amazon.redshift</groupId>
            <artifactId>redshift-jdbc42</artifactId>
            <version>1.2.10.1009</version>
        </dependency>
    </dependencies>

    <repositories>
        <repository>
            <id>redshift</id>
            <!-- HTTPS endpoint of the Redshift driver bucket. Maven 3.8.1+ blocks
                 plain-http repositories by default, so the http:// website URL
                 can no longer be used to resolve the driver. -->
            <url>https://s3.amazonaws.com/redshift-maven-repository/release</url>
        </repository>
    </repositories>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.8.0</version>
                <configuration>
                    <release>11</release>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>3.1.1</version>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <filters>
                                <filter>
                                    <artifact>*:*</artifact>
                                    <!-- Excludes manifest signature files -->
                                    <!-- We're transforming the manifest after build for the main class -->
                                    <!-- So we can't include these - they're no longer valid -->
                                    <excludes>
                                        <exclude>META-INF/*.SF</exclude>
                                        <exclude>META-INF/*.DSA</exclude>
                                        <exclude>META-INF/*.RSA</exclude>
                                    </excludes>
                                </filter>
                            </filters>
                            <transformers>
                                <!-- Sets the entry point -->
                                <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                                    <mainClass>SchemaInference</mainClass>
                                </transformer>
                                <!-- Merges the jdbc driver definitions instead of having them overwrite each other -->
                                <transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
                                    <resource>META-INF/services/java.sql.Driver</resource>
                                </transformer>
                            </transformers>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.time.Duration;
import java.util.LinkedHashMap;
import java.util.concurrent.Callable;
public class SchemaInference {
private static String USAGE = "usage: java jar <jar> <jdbc-uri> <query>";
public static void main(String[] args) throws Exception {
if (args.length != 2) {
System.err.println(USAGE);
System.exit(1);
}
Connection conn = DriverManager.getConnection(args[0]);
String query = args[1];
// Turn off the query cache to get more reproducible benchmarks.
if (args[0].toLowerCase().contains("redshift")) {
System.err.println("disabling redshift result cache...");
conn.createStatement().execute("SET enable_result_cache_for_session TO off");
}
System.out.println("method,duration");
bench("execute-query", () -> inferSchemaExecute(conn, query, false));
bench("max-rows", () -> inferSchemaExecute(conn, query, true));
bench("true = false", () -> inferSchemaExecute(conn, "SELECT * FROM (" + query + ") WHERE true = false", false));
bench("true <> true", () -> inferSchemaExecute(conn, "SELECT * FROM (" + query + ") WHERE true <> true", false));
bench("limit 0", () -> inferSchemaExecute(conn, "SELECT * FROM (" + query + ") LIMIT 0", false));
bench("prepared statement", () -> inferSchemaPrepare(conn, query, false));
bench("prepared statement (leaked)", () -> inferSchemaPrepare(conn, query, true));
}
/**
* Executes the provided query, and returns the inferred result set schema.
*/
private static LinkedHashMap<String, String> inferSchemaExecute(
Connection conn,
String query,
boolean limitMaxRows
) throws SQLException {
try (Statement statement = conn.createStatement()) {
if (limitMaxRows) {
statement.setMaxRows(0);
}
statement.execute(query);
return metadataToSchema(statement.getResultSet().getMetaData());
}
}
/**
* Creates a prepared statement from the provided query, and returns the inferred result set schema.
* Optionally leaks the prepared statement.
*/
private static LinkedHashMap<String, String> inferSchemaPrepare(
Connection conn,
String query,
boolean leak
) throws SQLException {
PreparedStatement statement = conn.prepareStatement(query);
LinkedHashMap<String, String> schema = metadataToSchema(statement.getMetaData());
if (!leak) {
statement.close();
}
return schema;
}
/**
* Transforms the result set metadata into a schema map of column name to type.
*/
private static LinkedHashMap<String, String> metadataToSchema(
ResultSetMetaData metadata
) throws SQLException {
LinkedHashMap<String, String> schema = new LinkedHashMap<>();
for (int colIdx = 1; colIdx <= metadata.getColumnCount(); colIdx++) {
schema.put(metadata.getColumnName(colIdx), metadata.getColumnTypeName(colIdx));
}
return schema;
}
private static void bench(String method, Callable<?> f) throws Exception {
// Pre-warm query compilation.
for (int i = 0; i < 3; i++) {
f.call();
}
for (int i = 0; i < 20; i++) {
long start = System.nanoTime();
f.call();
long elapsed = System.nanoTime() - start;
System.out.println(method + "," + Duration.ofNanos(elapsed));
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.