Skip to content

Instantly share code, notes, and snippets.

@okorz001
Last active November 18, 2015 03:41
Show Gist options
  • Save okorz001/d5597ecc81e46aa8b9de to your computer and use it in GitHub Desktop.
Save okorz001/d5597ecc81e46aa8b9de to your computer and use it in GitHub Desktop.
OO SQL Query Builder
import com.google.common.base.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
* Simple builder for SQL select queries.
*/
/**
 * Simple builder for SQL select queries.
 *
 * <p>Filters added with {@code addWhere} use JDBC {@code ?} placeholders, and every placeholder is associated with a
 * parameter name. A name may be used for more than one placeholder (within one filter or across filters); every
 * occurrence is then bound to the same value. Values are supplied with the {@code setParam} overloads and bound
 * with {@link #bindParams(PreparedStatement)}, or implicitly by the {@code execute} methods.
 */
public class QueryBuilder {
    private static final Logger LOG = LoggerFactory.getLogger(QueryBuilder.class);

    /**
     * A parameter value together with the JDBC type needed to bind it.
     *
     * @param <T> The Java type of the value.
     */
    private abstract static class QueryParam<T> {
        private final int sqlType; // needed for PreparedStatement.setNull
        private final T value;

        QueryParam(final int sqlType, final T value) {
            this.sqlType = sqlType;
            this.value = value;
        }

        /**
         * Renders the value as a SQL literal for diagnostic output (used by {@code QueryBuilder.toString}).
         *
         * @return {@code null} for a null value, otherwise the value in single quotes.
         */
        public String getValue() {
            if (this.value == null) {
                return "null";
            }
            // Double embedded quotes so the rendered text is a well-formed SQL string literal.
            return "'" + this.value.toString().replace("'", "''") + "'";
        }

        /**
         * Binds this value to a statement parameter, using setNull with the stored SQL type for null values.
         *
         * @param stmt The statement to bind to.
         * @param n The 1-based JDBC parameter index.
         * @throws SQLException From the JDBC driver.
         */
        public void bind(final PreparedStatement stmt, final int n) throws SQLException {
            if (this.value == null) {
                stmt.setNull(n, this.sqlType);
            } else {
                this.bindNonNull(stmt, n, this.value);
            }
        }

        /** Binds a value known to be non-null using the type-specific PreparedStatement setter. */
        protected abstract void bindNonNull(PreparedStatement stmt, int n, T value) throws SQLException;
    }

    private static final class QueryParamString extends QueryParam<String> {
        QueryParamString(final String value) {
            super(Types.VARCHAR, value);
        }

        @Override
        protected void bindNonNull(final PreparedStatement stmt, final int n, final String value)
                throws SQLException {
            stmt.setString(n, value);
        }
    }

    private static final class QueryParamInt extends QueryParam<Integer> {
        QueryParamInt(final Integer value) {
            super(Types.INTEGER, value);
        }

        @Override
        protected void bindNonNull(final PreparedStatement stmt, final int n, final Integer value)
                throws SQLException {
            stmt.setInt(n, value);
        }
    }

    private static final class QueryParamShort extends QueryParam<Short> {
        QueryParamShort(final Short value) {
            super(Types.SMALLINT, value);
        }

        @Override
        protected void bindNonNull(final PreparedStatement stmt, final int n, final Short value)
                throws SQLException {
            stmt.setShort(n, value);
        }
    }

    private static final class QueryParamByte extends QueryParam<Byte> {
        QueryParamByte(final Byte value) {
            super(Types.TINYINT, value);
        }

        @Override
        protected void bindNonNull(final PreparedStatement stmt, final int n, final Byte value)
                throws SQLException {
            stmt.setByte(n, value);
        }
    }

    private final List<String> selects;
    private final String table;
    private final List<String> joins;
    private final List<String> wheres;
    // Parameter name for each '?' placeholder in the WHERE clauses, in placeholder order. A name appears once per
    // placeholder, so a name reused in several places is listed several times.
    private final List<String> paramOrder;
    // Declared parameter names mapped to their values; a null value means "declared but not yet set".
    private final Map<String, QueryParam<?>> params;

    /**
     * Starts a new query.
     *
     * @param table The table to query.
     * @throws NullPointerException If table is null.
     */
    public QueryBuilder(final String table) {
        Preconditions.checkNotNull(table, "table is null");
        this.selects = new ArrayList<>();
        this.table = table;
        this.joins = new ArrayList<>();
        this.wheres = new ArrayList<>();
        this.paramOrder = new ArrayList<>();
        this.params = new LinkedHashMap<>();
    }

    /**
     * Adds a field to the query result.
     *
     * @param select The field to add to the result.
     * @return This QueryBuilder.
     * @throws NullPointerException If select is null.
     */
    public QueryBuilder addSelect(final String select) {
        Preconditions.checkNotNull(select, "select is null");
        this.selects.add(select);
        return this;
    }

    /**
     * Adds a join to the query.
     *
     * @param join Join to add to the query.
     * @return This QueryBuilder.
     * @throws NullPointerException If join is null.
     */
    public QueryBuilder addJoin(final String join) {
        Preconditions.checkNotNull(join, "join is null");
        this.joins.add(join);
        return this;
    }

    /**
     * Adds a filter to the query.
     *
     * @param where The filter to add to the query.
     * @param paramNames The names of the parameters in this filter, one per '?' placeholder, in order.
     * @return This QueryBuilder.
     * @throws NullPointerException If where or any paramName is null.
     * @throws IllegalArgumentException If the size of paramNames does not match the number of parameters.
     */
    public QueryBuilder addWhere(final String where, final String... paramNames) {
        return this.addWhere(where, Arrays.asList(paramNames));
    }

    /**
     * Adds a filter to the query.
     *
     * <p>A name that was already declared by an earlier filter refers to the same parameter; a value already set
     * for it is kept.
     *
     * @param where The filter to add to the query.
     * @param paramNames The names of the parameters in this filter, one per '?' placeholder, in order.
     * @return This QueryBuilder.
     * @throws NullPointerException If where, paramNames or any paramName is null.
     * @throws IllegalArgumentException If the size of paramNames does not match the number of parameters.
     */
    public QueryBuilder addWhere(final String where, final List<String> paramNames) {
        Preconditions.checkNotNull(where, "where is null");
        Preconditions.checkNotNull(paramNames, "paramNames is null");
        for (int i = 0; i < paramNames.size(); ++i) {
            Preconditions.checkNotNull(paramNames.get(i), "paramName[%s] is null", i);
        }
        // Count the question marks.
        final long paramCount = where.codePoints().filter(c -> c == '?').count();
        Preconditions.checkArgument(paramNames.size() == paramCount,
                "paramNames wrong size: %s != %s", paramNames.size(), paramCount);
        // All validation is done; only now mutate state so a failed call leaves the builder unchanged.
        this.wheres.add(where);
        for (final String name : paramNames) {
            this.paramOrder.add(name);
            // Declare the name without clobbering a value that was already set for it.
            if (!this.params.containsKey(name)) {
                this.params.put(name, null);
            }
        }
        return this;
    }

    /**
     * Builds a SQL query.
     *
     * @return A SQL query with '?' placeholders for all parameters.
     */
    public String build() {
        final StringBuilder buf = new StringBuilder();
        buf.append("SELECT ");
        if (this.selects.isEmpty()) {
            // Default to everything.
            buf.append('*');
        } else {
            buf.append(String.join(", ", this.selects));
        }
        buf.append("\nFROM ");
        buf.append(this.table);
        if (!this.joins.isEmpty()) {
            buf.append('\n');
            buf.append(String.join("\n", this.joins));
        }
        if (!this.wheres.isEmpty()) {
            buf.append("\nWHERE ");
            buf.append(String.join("\nAND ", this.wheres));
        }
        return buf.toString();
    }

    /**
     * Sets the value of the named parameter.
     *
     * @param name The parameter name.
     * @param value The parameter value.
     * @return This QueryBuilder.
     * @throws NullPointerException If name is null.
     * @throws IllegalArgumentException If the named parameter does not exist.
     */
    public QueryBuilder setParam(final String name, final String value) {
        return this.putParam(name, new QueryParamString(value));
    }

    /**
     * Sets the value of the named parameter.
     *
     * @param name The parameter name.
     * @param value The parameter value.
     * @return This QueryBuilder.
     * @throws NullPointerException If name is null.
     * @throws IllegalArgumentException If the named parameter does not exist.
     */
    public QueryBuilder setParam(final String name, final Integer value) {
        return this.putParam(name, new QueryParamInt(value));
    }

    /**
     * Sets the value of the named parameter.
     *
     * @param name The parameter name.
     * @param value The parameter value.
     * @return This QueryBuilder.
     * @throws NullPointerException If name is null.
     * @throws IllegalArgumentException If the named parameter does not exist.
     */
    public QueryBuilder setParam(final String name, final Short value) {
        return this.putParam(name, new QueryParamShort(value));
    }

    /**
     * Sets the value of the named parameter.
     *
     * @param name The parameter name.
     * @param value The parameter value.
     * @return This QueryBuilder.
     * @throws NullPointerException If name is null.
     * @throws IllegalArgumentException If the named parameter does not exist.
     */
    public QueryBuilder setParam(final String name, final Byte value) {
        return this.putParam(name, new QueryParamByte(value));
    }

    /** Validates name and stores the value for a parameter previously declared by addWhere. */
    private QueryBuilder putParam(final String name, final QueryParam<?> param) {
        Preconditions.checkNotNull(name, "name is null");
        Preconditions.checkArgument(this.params.containsKey(name), "invalid name: %s", name);
        this.params.put(name, param);
        return this;
    }

    /**
     * Binds the parameter values to a PreparedStatement instance.
     *
     * <p>This method assumes the PreparedStatement was created with the SQL query generated by this QueryBuilder.
     * The i-th '?' placeholder is bound to the value of the parameter named at that position.
     *
     * <p>All parameters are checked before anything is bound, so an IllegalStateException leaves stmt untouched.
     * An SQLException thrown by the driver part-way through may leave some parameters bound (no rollback).
     *
     * @param stmt The PreparedStatement to set variables on.
     * @throws NullPointerException If stmt is null.
     * @throws IllegalStateException If one or more named parameters has not been set.
     * @throws SQLException From the JDBC driver.
     */
    public void bindParams(final PreparedStatement stmt) throws SQLException {
        Preconditions.checkNotNull(stmt, "stmt is null");
        for (final String name : this.paramOrder) {
            if (this.params.get(name) == null) {
                throw new IllegalStateException("parameter has not been set: " + name);
            }
        }
        for (int i = 0; i < this.paramOrder.size(); ++i) {
            // JDBC parameters start from 1, not 0!
            this.params.get(this.paramOrder.get(i)).bind(stmt, i + 1);
        }
    }

    /**
     * Executes the query and returns user defined data generated from the rows.
     *
     * <p>This method builds the query, binds the parameters values and executes it with the given Connection. The
     * RowVisitor is called to generate a value for every SQL row. These values are collected and returned to the
     * caller.
     *
     * @param <T> The data type created from every row.
     * @param conn The JDBC connection.
     * @param rv A function that will be executed once for every row.
     * @return A list of objects returned by the row visitor for every row, in row order.
     * @throws NullPointerException If conn or rv is null.
     * @throws IllegalStateException If one or more named parameters has not been set.
     * @throws SQLException From the JDBC driver.
     */
    public <T> List<T> execute(final Connection conn, final RowVisitor<T> rv) throws SQLException {
        Preconditions.checkNotNull(conn, "conn is null");
        Preconditions.checkNotNull(rv, "rv is null");
        LOG.debug("Executing query:\n{}", this);
        final List<T> ret = new ArrayList<>();
        try (final PreparedStatement stmt = conn.prepareStatement(this.build())) {
            this.bindParams(stmt);
            // The ResultSet is closed together with its Statement when this try block exits.
            final ResultSet rs = stmt.executeQuery();
            while (rs.next()) {
                ret.add(rv.visit(rs));
            }
        }
        return ret;
    }

    /**
     * Executes the query and returns user defined data generated from the rows.
     *
     * <p>This method builds the query, binds the parameters values and executes it with the given Connection. The
     * RowVisitors are called to generate a key and a value for every SQL row. These values are collected and
     * returned to the caller.
     *
     * @param <K> The key type created from every row.
     * @param <V> The value type created from every row.
     * @param conn The JDBC connection.
     * @param rvKey A function that will be executed once for every row to create a key.
     * @param rvValue A function that will be executed once for every row to create a value.
     * @return A map of objects created and indexed by the row visitors for every row, iterating in the order keys
     *         were first produced. If two rows produce equal keys, the later row's value replaces the earlier one.
     * @throws NullPointerException If conn, rvKey or rvValue is null.
     * @throws IllegalStateException If one or more named parameters has not been set.
     * @throws SQLException From the JDBC driver.
     */
    public <K, V> Map<K, V> execute(final Connection conn, final RowVisitor<K> rvKey, final RowVisitor<V> rvValue)
            throws SQLException {
        Preconditions.checkNotNull(conn, "conn is null");
        Preconditions.checkNotNull(rvKey, "rvKey is null");
        Preconditions.checkNotNull(rvValue, "rvValue is null");
        LOG.debug("Executing query:\n{}", this);
        final Map<K, V> ret = new LinkedHashMap<>();
        try (final PreparedStatement stmt = conn.prepareStatement(this.build())) {
            this.bindParams(stmt);
            // The ResultSet is closed together with its Statement when this try block exits.
            final ResultSet rs = stmt.executeQuery();
            while (rs.next()) {
                final K key = rvKey.visit(rs);
                final V value = rvValue.visit(rs);
                ret.put(key, value);
            }
        }
        return ret;
    }

    /**
     * Renders the query with set parameter values inlined, for diagnostics only (never execute this string).
     *
     * <p>Placeholders are substituted in order, the same way bindParams assigns JDBC indices. A placeholder whose
     * parameter is not yet set stays as '?', and any '?' beyond the declared parameters is copied unchanged.
     */
    @Override
    public String toString() {
        final String sql = this.build();
        final StringBuilder ret = new StringBuilder(sql.length());
        int next = 0; // index into paramOrder of the next placeholder to substitute
        for (int pos = 0; pos < sql.length(); ++pos) {
            final char c = sql.charAt(pos);
            if (c == '?' && next < this.paramOrder.size()) {
                final QueryParam<?> param = this.params.get(this.paramOrder.get(next++));
                // Some parameters may still be unbound.
                ret.append(param == null ? "?" : param.getValue());
            } else {
                ret.append(c);
            }
        }
        return ret.toString();
    }
}
import org.testng.annotations.Test;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.testng.Assert.assertEquals;
/**
 * Unit tests for QueryBuilder: SQL generation, argument validation, parameter binding, execution and the
 * diagnostic string form.
 */
public class QueryBuilderTest {
    @Test(expectedExceptions = { NullPointerException.class })
    public void testNullTable() {
        new QueryBuilder(null);
    }

    @Test
    public void testBuild() {
        final String sql = new QueryBuilder("foo").build();
        assertEquals(sql, "SELECT *\nFROM foo", "Incorrect SQL");
    }

    @Test
    public void testAddSelect() {
        final String sql = new QueryBuilder("foo").addSelect("bar").addSelect("baz").build();
        assertEquals(sql, "SELECT bar, baz\nFROM foo", "Incorrect SQL");
    }

    @Test(expectedExceptions = { NullPointerException.class })
    public void testAddSelectNull() {
        new QueryBuilder("foo").addSelect(null);
    }

    @Test
    public void testAddJoin() {
        final String sql = new QueryBuilder("foo")
                .addJoin("JOIN bar ON foo.id = bar.id")
                .addJoin("LEFT JOIN baz ON foo.id = baz.id")
                .build();
        final String expectedSql = "SELECT *\nFROM foo"
                + "\nJOIN bar ON foo.id = bar.id"
                + "\nLEFT JOIN baz ON foo.id = baz.id";
        assertEquals(sql, expectedSql, "Incorrect SQL");
    }

    @Test(expectedExceptions = { NullPointerException.class })
    public void testAddJoinNull() {
        new QueryBuilder("foo").addJoin(null);
    }

    @Test
    public void testAddWhere() {
        final String sql = new QueryBuilder("foo").addWhere("price >= 5").addWhere("price < 10").build();
        final String expectedSql = "SELECT *\nFROM foo\nWHERE price >= 5\nAND price < 10";
        assertEquals(sql, expectedSql, "Incorrect SQL");
    }

    @Test(expectedExceptions = { NullPointerException.class })
    public void testAddWhereNull() {
        new QueryBuilder("foo").addWhere(null);
    }

    @Test(expectedExceptions = { NullPointerException.class })
    public void testAddWhereNullParamName() {
        new QueryBuilder("foo").addWhere("id = ?", (String) null);
    }

    @Test(expectedExceptions = { IllegalArgumentException.class })
    public void testAddWhereMissingParamName() {
        new QueryBuilder("foo").addWhere("id = ?");
    }

    @Test(expectedExceptions = { IllegalArgumentException.class })
    public void testAddWhereExtraParamNames() {
        new QueryBuilder("foo").addWhere("id = ?", "id", "price");
    }

    @Test(expectedExceptions = { IllegalArgumentException.class })
    public void testSetParamUnknownName() {
        new QueryBuilder("foo").setParam("x", 1);
    }

    @Test(expectedExceptions = { NullPointerException.class })
    public void testSetParamNullName() {
        new QueryBuilder("foo").addWhere("x = ?", "x").setParam(null, 1);
    }

    @Test
    public void testBindParams() throws SQLException {
        final QueryBuilder qb = new QueryBuilder("foo")
                .addWhere("a = ?", "a")
                .setParam("a", "1")
                .addWhere("b = ?", "b")
                .setParam("b", (String) null)
                .addWhere("c = ?", "c")
                .setParam("c", 1)
                .addWhere("d = ?", "d")
                .setParam("d", (Integer) null)
                .addWhere("e = ?", "e")
                .setParam("e", (short) 1)
                .addWhere("f = ?", "f")
                .setParam("f", (Short) null)
                .addWhere("g = ?", "g")
                .setParam("g", (byte) 1)
                .addWhere("h = ?", "h")
                .setParam("h", (Byte) null);
        final PreparedStatement statement = mock(PreparedStatement.class);
        qb.bindParams(statement);
        verify(statement).setString(1, "1");
        verify(statement).setNull(2, Types.VARCHAR);
        verify(statement).setInt(3, 1);
        verify(statement).setNull(4, Types.INTEGER);
        verify(statement).setShort(5, (short) 1);
        verify(statement).setNull(6, Types.SMALLINT);
        verify(statement).setByte(7, (byte) 1);
        verify(statement).setNull(8, Types.TINYINT);
        verifyNoMoreInteractions(statement);
    }

    @Test(expectedExceptions = { IllegalStateException.class })
    public void testBindParamsUnset() throws SQLException {
        new QueryBuilder("foo").addWhere("x = ?", "x").bindParams(mock(PreparedStatement.class));
    }

    @Test
    public void testExecute() throws SQLException {
        final QueryBuilder qb = new QueryBuilder("foo").addWhere("x = ?", "x").setParam("x", 1);
        final String sql = qb.build();
        final Connection conn = mock(Connection.class);
        final PreparedStatement statement = mock(PreparedStatement.class);
        when(conn.prepareStatement(anyString())).thenReturn(statement);
        final ResultSet rs = mock(ResultSet.class);
        when(statement.executeQuery()).thenReturn(rs);
        // Pretend there are 2 rows.
        when(rs.next()).thenReturn(true).thenReturn(true).thenReturn(false);
        final List<Integer> results = qb.execute(conn, rs2 -> {
            assertEquals(rs2, rs, "Got different ResultSet??");
            return 1;
        });
        final List<Integer> expectedResults = Arrays.asList(1, 1);
        assertEquals(results, expectedResults, "Wrong results");
        verify(conn).prepareStatement(sql);
        verify(statement).setInt(1, 1);
        verify(statement).executeQuery();
        verify(statement).close();
        verify(rs, times(3)).next();
        verifyNoMoreInteractions(conn, statement, rs);
    }

    @Test
    public void testExecuteMap() throws SQLException {
        final QueryBuilder qb = new QueryBuilder("foo").addWhere("x = ?", "x").setParam("x", 1);
        final String sql = qb.build();
        final Connection conn = mock(Connection.class);
        final PreparedStatement statement = mock(PreparedStatement.class);
        when(conn.prepareStatement(anyString())).thenReturn(statement);
        final ResultSet rs = mock(ResultSet.class);
        when(statement.executeQuery()).thenReturn(rs);
        // Pretend there are 2 rows.
        when(rs.next()).thenReturn(true).thenReturn(true).thenReturn(false);
        when(rs.getInt("foo")).thenReturn(1).thenReturn(2);
        when(rs.getString("bar")).thenReturn("spam").thenReturn("eggs");
        final Map<Integer, String> results = qb.execute(conn, rs2 -> {
            assertEquals(rs2, rs, "Got different ResultSet??");
            return rs2.getInt("foo");
        }, rs2 -> {
            assertEquals(rs2, rs, "Got different ResultSet??");
            return rs2.getString("bar");
        });
        final Map<Integer, String> expectedResults = new HashMap<>();
        expectedResults.put(1, "spam");
        expectedResults.put(2, "eggs");
        assertEquals(results, expectedResults, "Wrong results");
        verify(conn).prepareStatement(sql);
        verify(statement).setInt(1, 1);
        verify(statement).executeQuery();
        verify(statement).close();
        verify(rs, times(3)).next();
        // This is from the test, not the actual code.
        verify(rs, times(2)).getInt("foo");
        verify(rs, times(2)).getString("bar");
        verifyNoMoreInteractions(conn, statement, rs);
    }

    @Test
    public void testToString() {
        final QueryBuilder qb = new QueryBuilder("foo").addWhere("y = ? AND x = ?", "y", "x").setParam("x", 1);
        final String expected = "SELECT *\nFROM foo\nWHERE y = ? AND x = '1'";
        assertEquals(qb.toString(), expected, "Wrong string representation");
    }

    @Test
    public void testToStringExtra() {
        final QueryBuilder qb = new QueryBuilder("foo").addWhere("y = ? AND ? = x", "y", "x").setParam("x", 1);
        final String expected = "SELECT *\nFROM foo\nWHERE y = ? AND '1' = x";
        assertEquals(qb.toString(), expected, "Wrong string representation");
    }
}
import java.sql.ResultSet;
import java.sql.SQLException;
/**
 * Converts the row a {@link ResultSet} is currently positioned on into a caller-defined value.
 *
 * <p>This exists instead of {@code java.util.function.Function} because {@link #visit} is declared to throw
 * {@link SQLException}: implementations may call ResultSet getters directly and let any failure propagate rather
 * than wrapping it in an unchecked exception.
 *
 * @param <T> The type produced for each row.
 */
@FunctionalInterface
public interface RowVisitor<T> {
    /**
     * Produces a value from the current row of the given result set.
     *
     * @param rs The result set, positioned on the row to convert.
     * @return The value built from that row.
     * @throws SQLException If reading the row fails or its contents are not what the implementation expects.
     */
    T visit(ResultSet rs) throws SQLException;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment