Last active
November 18, 2015 03:41
-
-
Save okorz001/d5597ecc81e46aa8b9de to your computer and use it in GitHub Desktop.
OO SQL Query Builder
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.google.common.base.Preconditions; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import java.sql.Connection; | |
import java.sql.PreparedStatement; | |
import java.sql.ResultSet; | |
import java.sql.SQLException; | |
import java.sql.Types; | |
import java.util.ArrayList; | |
import java.util.Arrays; | |
import java.util.LinkedHashMap; | |
import java.util.List; | |
import java.util.Map; | |
/** | |
* Simple builder for SQL select queries. | |
*/ | |
public class QueryBuilder { | |
private static final Logger LOG = LoggerFactory.getLogger(QueryBuilder.class); | |
private static abstract class QueryParam<T> { | |
private final int sqlType; // needed for PreparedStatement.setNull | |
private final T value; | |
public QueryParam(final int sqlType, final T value) { | |
this.sqlType = sqlType; | |
this.value = value; | |
} | |
public String getValue() { | |
return this.value == null ? "null" : "'" + this.value.toString() + "'"; | |
} | |
public void bind(final PreparedStatement stmt, final int n) throws SQLException { | |
if (this.value == null) { | |
stmt.setNull(n, sqlType); | |
} else { | |
this.bind(stmt, n, value); | |
} | |
} | |
abstract void bind(final PreparedStatement stmt, final int n, final T value) throws SQLException; | |
} | |
private static class QueryParamString extends QueryParam<String> { | |
public QueryParamString(final String value) { | |
super(Types.VARCHAR, value); | |
} | |
@Override | |
protected void bind(final PreparedStatement stmt, final int n, final String value) throws SQLException { | |
stmt.setString(n, value); | |
} | |
} | |
private static class QueryParamInt extends QueryParam<Integer> { | |
public QueryParamInt(final Integer value) { | |
super(Types.INTEGER, value); | |
} | |
@Override | |
protected void bind(final PreparedStatement stmt, final int n, final Integer value) throws SQLException { | |
stmt.setInt(n, value); | |
} | |
} | |
private static class QueryParamShort extends QueryParam<Short> { | |
public QueryParamShort(final Short value) { | |
super(Types.SMALLINT, value); | |
} | |
@Override | |
void bind(final PreparedStatement stmt, final int n, final Short value) throws SQLException { | |
stmt.setShort(n, value); | |
} | |
} | |
private static class QueryParamByte extends QueryParam<Byte> { | |
public QueryParamByte(final Byte value) { | |
super(Types.TINYINT, value); | |
} | |
@Override | |
void bind(final PreparedStatement stmt, final int n, final Byte value) throws SQLException { | |
stmt.setByte(n, value); | |
} | |
} | |
private final List<String> selects; | |
private final String table; | |
private final List<String> joins; | |
private final List<String> wheres; | |
// Order is important. Values can be modified without affecting the key order. | |
private final LinkedHashMap<String, QueryParam> params; | |
/** | |
* Starts a new query. | |
* | |
* @param table The table to query. | |
* @throws NullPointerException If table is null. | |
*/ | |
public QueryBuilder(final String table) { | |
Preconditions.checkNotNull(table, "table is null"); | |
this.selects = new ArrayList<>(); | |
this.table = table; | |
this.joins = new ArrayList<>(); | |
this.wheres = new ArrayList<>(); | |
this.params = new LinkedHashMap<>(); | |
} | |
/** | |
* Adds a field to the query result. | |
* | |
* @param select The field to add to the result. | |
* @return This QueryBuilder. | |
* @throws NullPointerException If select is null. | |
*/ | |
public QueryBuilder addSelect(final String select) { | |
Preconditions.checkNotNull(select, "select is null"); | |
this.selects.add(select); | |
return this; | |
} | |
/** | |
* Adds a join to the query. | |
* | |
* @param join Join to add to the query. | |
* @return This QueryBuilder. | |
* @throws NullPointerException If join is null. | |
*/ | |
public QueryBuilder addJoin(final String join) { | |
Preconditions.checkNotNull(join, "join is null"); | |
this.joins.add(join); | |
return this; | |
} | |
/** | |
* Adds a filter to the query. | |
* | |
* @param where The filter to add to the query. | |
* @param paramNames The names of the parameters in this filter. | |
* @return This QueryBuilder. | |
* @throws NullPointerException If where or any paramName is null. | |
* @throws IllegalArgumentException If the size of paramNames does not match the number of parameters. | |
*/ | |
public QueryBuilder addWhere(final String where, final String... paramNames) { | |
return this.addWhere(where, Arrays.asList(paramNames)); | |
} | |
/** | |
* Adds a filter to the query. | |
* | |
* @param where The filter to add to the query. | |
* @param paramNames The names of the parameters in this filter. | |
* @return This QueryBuilder. | |
* @throws NullPointerException If where or any paramName is null. | |
* @throws IllegalArgumentException If the size of paramNames does not match the number of parameters. | |
*/ | |
public QueryBuilder addWhere(final String where, final List<String> paramNames) { | |
Preconditions.checkNotNull(where, "where is null"); | |
for (int i = 0; i < paramNames.size(); ++i) { | |
Preconditions.checkNotNull(paramNames.get(i), "paramName[" + i + "] is null"); | |
} | |
// Count the question marks. | |
final long paramCount = where.codePoints().filter(c -> c == '?').count(); | |
Preconditions.checkArgument(paramNames.size() == paramCount, | |
"paramNames wrong size: " + paramNames.size() + " != " + paramCount); | |
this.wheres.add(where); | |
paramNames.forEach(paramName -> this.params.put(paramName, null)); | |
return this; | |
} | |
/** | |
* Builds a SQL query. | |
* | |
* @return A SQL query. | |
*/ | |
public String build() { | |
final StringBuilder buf = new StringBuilder(); | |
buf.append("SELECT "); | |
if (this.selects.isEmpty()) { | |
// Default to everything. | |
buf.append('*'); | |
} | |
else { | |
buf.append(String.join(", ", this.selects)); | |
} | |
buf.append("\nFROM "); | |
buf.append(this.table); | |
if (!this.joins.isEmpty()) { | |
buf.append('\n'); | |
buf.append(String.join("\n", this.joins)); | |
} | |
if (!this.wheres.isEmpty()) { | |
buf.append("\nWHERE "); | |
buf.append(String.join("\nAND ", this.wheres)); | |
} | |
final String sql = buf.toString(); | |
//LOG.debug("Built query:\n{}", sql); | |
return sql; | |
} | |
/** | |
* Sets the value of the named parameter. | |
* | |
* @param name The parameter name. | |
* @param value The parameter value. | |
* @return This QueryBuilder. | |
* @throws NullPointerException If name is null. | |
* @throws IllegalArgumentException If the named parameter does not exist. | |
*/ | |
public QueryBuilder setParam(final String name, final String value) { | |
Preconditions.checkNotNull(name, "name is null"); | |
Preconditions.checkArgument(this.params.containsKey(name), "invalid name: " + name); | |
this.params.put(name, new QueryParamString(value)); | |
return this; | |
} | |
/** | |
* Sets the value of the named parameter. | |
* | |
* @param name The parameter name. | |
* @param value The parameter value. | |
* @return This QueryBuilder. | |
* @throws NullPointerException If name is null. | |
* @throws IllegalArgumentException If the named parameter does not exist. | |
*/ | |
public QueryBuilder setParam(final String name, final Integer value) { | |
Preconditions.checkNotNull(name, "name is null"); | |
Preconditions.checkArgument(this.params.containsKey(name), "invalid name: " + name); | |
this.params.put(name, new QueryParamInt(value)); | |
return this; | |
} | |
/** | |
* Sets the value of the named parameter. | |
* | |
* @param name The parameter name. | |
* @param value The parameter value. | |
* @return This QueryBuilder. | |
* @throws NullPointerException If name is null. | |
* @throws IllegalArgumentException If the named parameter does not exist. | |
*/ | |
public QueryBuilder setParam(final String name, final Short value) { | |
Preconditions.checkNotNull(name, "name is null"); | |
Preconditions.checkArgument(this.params.containsKey(name), "invalid name: " + name); | |
this.params.put(name, new QueryParamShort(value)); | |
return this; | |
} | |
/** | |
* Sets the value of the named parameter. | |
* | |
* @param name The parameter name. | |
* @param value The parameter value. | |
* @return This QueryBuilder. | |
* @throws NullPointerException If name is null. | |
* @throws IllegalArgumentException If the named parameter does not exist. | |
*/ | |
public QueryBuilder setParam(final String name, final Byte value) { | |
Preconditions.checkNotNull(name, "name is null"); | |
Preconditions.checkArgument(this.params.containsKey(name), "invalid name: " + name); | |
this.params.put(name, new QueryParamByte(value)); | |
return this; | |
} | |
/** | |
* Binds the parameter values to a PreparedStatement instance. | |
* | |
* This method assumes the PreparedStatement was created with the SQL query generated by this QueryBuilder. | |
* | |
* This method has a weak exception guarantee: partial state changes will not be rolled back. | |
* | |
* @param stmt The PreparedStatement to set variables on. | |
* @throws NullPointerException If stmt is null. | |
* @throws IllegalStateException If one or more named parameters has not been set. | |
* @throws SQLException From the JDBC driver. | |
*/ | |
public void bindParams(final PreparedStatement stmt) throws SQLException { | |
Preconditions.checkNotNull(stmt, "stmt is null"); | |
int i = 1; // JDBC parameters start from 1, not 0! | |
for (final String name : this.params.keySet()) { | |
final QueryParam param = this.params.get(name); | |
if (param == null) { | |
throw new IllegalStateException("parameter has not been set: " + name); | |
} | |
param.bind(stmt, i++); | |
} | |
} | |
/** | |
* Executes the query and returns user defined data generated from the rows. | |
* | |
* This method builds the query, binds the parameters values and executes it with the given Connection. The | |
* RowVisitor is called to generate a value for every SQL row. These values are collected and returned to the | |
* caller. | |
* | |
* @param <T> The data type created from every row. | |
* @param conn The JDBC connection. | |
* @param rv A function that will be executed once for every row. | |
* @return A list of objects returned by the row visitor for every row. | |
* @throws NullPointerException If conn or rv is null. | |
* @throws IllegalStateException If one or more named parameters has not been set. | |
* @throws SQLException From the JDBC driver. | |
*/ | |
public <T> List<T> execute(final Connection conn, final RowVisitor<T> rv) throws SQLException { | |
Preconditions.checkNotNull(conn, "conn is null"); | |
Preconditions.checkNotNull(rv, "rv is null"); | |
LOG.debug("Executing query:\n{}", this); | |
final List<T> ret = new ArrayList<>(); | |
try (final PreparedStatement stmt = conn.prepareStatement(this.build())) { | |
this.bindParams(stmt); | |
final ResultSet rs = stmt.executeQuery(); | |
while (rs.next()) { | |
final T data = rv.visit(rs); | |
ret.add(data); | |
} | |
} | |
return ret; | |
} | |
/** | |
* Executes the query and returns user defined data generated from the rows. | |
* | |
* This method builds the query, binds the parameters values and executes it with the given Connection. The | |
* RowVisitors are called to generate a key and a value for every SQL row. These values are collected and returned | |
* to the caller. | |
* | |
* @param <K> The key type created from every row. | |
* @param <V> The value type created from every row. | |
* @param conn The JDBC connection. | |
* @param rvKey A function that will be executed once for every row to create a key. | |
* @param rvValue A function that will be executed once for every row to create a value. | |
* @return A map of objects created and indexed by the row visitors for every row. | |
* @throws NullPointerException If conn or rv is null. | |
* @throws IllegalStateException If one or more named parameters has not been set. | |
* @throws SQLException From the JDBC driver. | |
*/ | |
public <K, V> Map<K, V> execute(final Connection conn, final RowVisitor<K> rvKey, final RowVisitor<V> rvValue) throws SQLException { | |
Preconditions.checkNotNull(conn, "conn is null"); | |
Preconditions.checkNotNull(rvKey, "rvKey is null"); | |
Preconditions.checkNotNull(rvValue, "rvValue is null"); | |
LOG.debug("Executing query:\n{}", this); | |
final Map<K, V> ret = new LinkedHashMap<>(); | |
try (final PreparedStatement stmt = conn.prepareStatement(this.build())) { | |
this.bindParams(stmt); | |
final ResultSet rs = stmt.executeQuery(); | |
while (rs.next()) { | |
final K key = rvKey.visit(rs); | |
final V value = rvValue.visit(rs); | |
ret.put(key, value); | |
} | |
} | |
return ret; | |
} | |
@Override | |
public String toString() { | |
// Split on parameter placeholders. | |
final String[] parts = this.build().split("\\?"); | |
// Join with parameter values. | |
int i = 0; | |
final StringBuffer ret = new StringBuffer(); | |
for (final String name : this.params.keySet()) { | |
ret.append(parts[i++]); | |
final QueryParam param = this.params.get(name); | |
// Some parameters may still be unbound. | |
final String value = param == null ? "?" : param.getValue(); | |
ret.append(value); | |
} | |
// If the query does not end with '?', we need the join the remainder. | |
if (i < parts.length) { | |
ret.append(parts[i++]); | |
} | |
return ret.toString(); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.testng.annotations.Test; | |
import java.sql.Connection; | |
import java.sql.PreparedStatement; | |
import java.sql.ResultSet; | |
import java.sql.SQLException; | |
import java.sql.Types; | |
import java.util.Arrays; | |
import java.util.HashMap; | |
import java.util.List; | |
import java.util.Map; | |
import static org.mockito.Matchers.anyString; | |
import static org.mockito.Mockito.mock; | |
import static org.mockito.Mockito.times; | |
import static org.mockito.Mockito.verify; | |
import static org.mockito.Mockito.verifyNoMoreInteractions; | |
import static org.mockito.Mockito.when; | |
import static org.testng.Assert.assertEquals; | |
public class QueryBuilderTest { | |
@Test(expectedExceptions = { NullPointerException.class }) | |
public void testNullTable() { | |
new QueryBuilder(null); | |
} | |
@Test | |
public void testBuild() { | |
final String sql = new QueryBuilder("foo").build(); | |
assertEquals(sql, "SELECT *\nFROM foo", "Incorrect SQL"); | |
} | |
@Test | |
public void testAddSelect() { | |
final String sql = new QueryBuilder("foo").addSelect("bar").addSelect("baz").build(); | |
assertEquals(sql, "SELECT bar, baz\nFROM foo", "Incorrect SQL"); | |
} | |
@Test(expectedExceptions = { NullPointerException.class }) | |
public void testAddSelectNull() { | |
new QueryBuilder("foo").addSelect(null); | |
} | |
@Test | |
public void testAddJoin() { | |
final String sql = new QueryBuilder("foo") | |
.addJoin("JOIN bar ON foo.id = bar.id") | |
.addJoin("LEFT JOIN baz ON foo.id = baz.id") | |
.build(); | |
final String expectedSql = "SELECT *\nFROM foo" | |
+ "\nJOIN bar ON foo.id = bar.id" | |
+ "\nLEFT JOIN baz ON foo.id = baz.id"; | |
assertEquals(sql, expectedSql, "Incorrect SQL"); | |
} | |
@Test(expectedExceptions = { NullPointerException.class }) | |
public void testAddJoinNull() { | |
new QueryBuilder("foo").addJoin(null); | |
} | |
@Test | |
public void testAddWhere() { | |
final String sql = new QueryBuilder("foo").addWhere("price >= 5").addWhere("price < 10").build(); | |
final String expectedSql = "SELECT *\nFROM foo\nWHERE price >= 5\nAND price < 10"; | |
assertEquals(sql, expectedSql, "Incorrect SQL"); | |
} | |
@Test(expectedExceptions = { NullPointerException.class }) | |
public void testAddWhereNull() { | |
new QueryBuilder("foo").addWhere(null); | |
} | |
@Test(expectedExceptions = { IllegalArgumentException.class }) | |
public void testAddWhereMissingParamName() { | |
new QueryBuilder("foo").addWhere("id = ?"); | |
} | |
@Test(expectedExceptions = { IllegalArgumentException.class }) | |
public void testAddWhereExtraParamNames() { | |
new QueryBuilder("foo").addWhere("id = ?", "id", "price"); | |
} | |
@Test | |
public void testBindParams() throws SQLException { | |
final QueryBuilder qb = new QueryBuilder("foo") | |
.addWhere("a = ?", "a") | |
.setParam("a", "1") | |
.addWhere("b = ?", "b") | |
.setParam("b", (String) null) | |
.addWhere("c = ?", "c") | |
.setParam("c", 1) | |
.addWhere("d = ?", "d") | |
.setParam("d", (Integer) null) | |
.addWhere("e = ?", "e") | |
.setParam("e", (short) 1) | |
.addWhere("f = ?", "f") | |
.setParam("f", (Short) null) | |
.addWhere("g = ?", "g") | |
.setParam("g", (byte) 1) | |
.addWhere("h = ?", "h") | |
.setParam("h", (Byte) null); | |
final PreparedStatement statement = mock(PreparedStatement.class); | |
qb.bindParams(statement); | |
verify(statement).setString(1, "1"); | |
verify(statement).setNull(2, Types.VARCHAR); | |
verify(statement).setInt(3, 1); | |
verify(statement).setNull(4, Types.INTEGER); | |
verify(statement).setShort(5, (short) 1); | |
verify(statement).setNull(6, Types.SMALLINT); | |
verify(statement).setByte(7, (byte) 1); | |
verify(statement).setNull(8, Types.TINYINT); | |
verifyNoMoreInteractions(statement); | |
} | |
@Test | |
public void testExecute() throws SQLException { | |
final QueryBuilder qb = new QueryBuilder("foo").addWhere("x = ?", "x").setParam("x", 1); | |
final String sql = qb.build(); | |
final Connection conn = mock(Connection.class); | |
final PreparedStatement statement = mock(PreparedStatement.class); | |
when(conn.prepareStatement(anyString())).thenReturn(statement); | |
final ResultSet rs = mock(ResultSet.class); | |
when(statement.executeQuery()).thenReturn(rs); | |
// Pretend there are 2 rows. | |
when(rs.next()).thenReturn(true).thenReturn(true).thenReturn(false); | |
List<Integer> results = qb.execute(conn, rs2 -> { | |
assertEquals(rs2, rs, "Got different ResultSet??"); | |
return 1; | |
}); | |
List<Integer> expectedResults = Arrays.asList(1, 1); | |
assertEquals(results, expectedResults, "Wrong results"); | |
verify(conn).prepareStatement(sql); | |
verify(statement).setInt(1, 1); | |
verify(statement).executeQuery(); | |
verify(statement).close(); | |
verify(rs, times(3)).next(); | |
verifyNoMoreInteractions(conn, statement, rs); | |
} | |
@Test | |
public void testExecuteMap() throws SQLException { | |
final QueryBuilder qb = new QueryBuilder("foo").addWhere("x = ?", "x").setParam("x", 1); | |
final String sql = qb.build(); | |
final Connection conn = mock(Connection.class); | |
final PreparedStatement statement = mock(PreparedStatement.class); | |
when(conn.prepareStatement(anyString())).thenReturn(statement); | |
final ResultSet rs = mock(ResultSet.class); | |
when(statement.executeQuery()).thenReturn(rs); | |
// Pretend there are 2 rows. | |
when(rs.next()).thenReturn(true).thenReturn(true).thenReturn(false); | |
when(rs.getInt("foo")).thenReturn(1).thenReturn(2); | |
when(rs.getString("bar")).thenReturn("spam").thenReturn("eggs"); | |
Map<Integer, String> results = qb.execute(conn, rs2 -> { | |
assertEquals(rs2, rs, "Got different ResultSet??"); | |
return rs2.getInt("foo"); | |
}, rs2 -> { | |
assertEquals(rs2, rs, "Got different ResultSet??"); | |
return rs2.getString("bar"); | |
}); | |
final Map<Integer, String> expectedResults = new HashMap<>(); | |
expectedResults.put(1, "spam"); | |
expectedResults.put(2, "eggs"); | |
assertEquals(results, expectedResults, "Wrong results"); | |
verify(conn).prepareStatement(sql); | |
verify(statement).setInt(1, 1); | |
verify(statement).executeQuery(); | |
verify(statement).close(); | |
verify(rs, times(3)).next(); | |
// This is from the test, not the actual code. | |
verify(rs, times(2)).getInt("foo"); | |
verify(rs, times(2)).getString("bar"); | |
verifyNoMoreInteractions(conn, statement, rs); | |
} | |
@Test | |
public void testToString() { | |
final QueryBuilder qb = new QueryBuilder("foo").addWhere("y = ? AND x = ?", "y", "x").setParam("x", 1); | |
final String expected = "SELECT *\nFROM foo\nWHERE y = ? AND x = '1'"; | |
assertEquals(qb.toString(), expected, "Wrong string representation"); | |
} | |
@Test | |
public void testToStringExtra() { | |
final QueryBuilder qb = new QueryBuilder("foo").addWhere("y = ? AND ? = x", "y", "x").setParam("x", 1); | |
final String expected = "SELECT *\nFROM foo\nWHERE y = ? AND '1' = x"; | |
assertEquals(qb.toString(), expected, "Wrong string representation"); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.sql.ResultSet; | |
import java.sql.SQLException; | |
/**
 * Converts a row of a {@link ResultSet} into a caller-defined object.
 *
 * <p>Unlike {@link java.util.function.Function}, {@link #visit(ResultSet)} declares {@link SQLException}. Implementations
 * can therefore call {@code ResultSet} accessors directly and let failures caused by an invalid or unexpected row
 * propagate to the caller.
 *
 * @param <T> The type of object produced for each row.
 */
@FunctionalInterface
public interface RowVisitor<T> {

    /**
     * Produces an object from the row the given result set currently points at.
     *
     * @param rs The result set, positioned on the row to convert.
     * @return The caller-defined object for this row.
     * @throws SQLException If reading from the result set fails.
     */
    T visit(ResultSet rs) throws SQLException;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment