Skip to content

Instantly share code, notes, and snippets.

@janbols
Last active March 27, 2018 07:03
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save janbols/9af3655233e851cc341c605747243909 to your computer and use it in GitHub Desktop.
Save janbols/9af3655233e851cc341c605747243909 to your computer and use it in GitHub Desktop.
Returning sql data as Observable
import org.slf4j.Logger;
import org.springframework.jdbc.core.*;
import org.springframework.jdbc.core.namedparam.EmptySqlParameterSource;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.jdbc.datasource.DataSourceUtils;
import org.springframework.jdbc.support.JdbcUtils;
import rx.Observable;
import rx.Observer;
import rx.observables.SyncOnSubscribe;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Map;
import static org.slf4j.LoggerFactory.getLogger;
/**
 * Extension of {@link NamedParameterJdbcTemplate} that allows you to return query results in an {@link Observable}.
 * <p>
 * The returned observables are lazy and backpressure aware: a new query state (connection, statement, result set)
 * is created by the {@link SyncOnSubscribe} state factory for each subscription. Rows are read from the
 * {@link ResultSet} one per {@link SyncOnSubscribe} "next" call. The JDBC resources are released by
 * {@code JdbcQueryState#cleanup}, which is registered as the {@link SyncOnSubscribe} unsubscribe action.
 */
public class ObservableJdbcTemplate extends NamedParameterJdbcTemplate {

    private static final Logger log = getLogger(ObservableJdbcTemplate.class);

    /** Plain template over the same data source; used for its data source and its SQL exception translator. */
    private final JdbcTemplate jdbcTemplate;

    public ObservableJdbcTemplate(DataSource dataSource) {
        super(dataSource);
        this.jdbcTemplate = new JdbcTemplate(dataSource);
    }

    /**
     * Returns a backpressure aware {@link Observable} for the given SQL and named parameters,
     * using the given {@link RowMapper} to convert each db row to a {@code T}.
     *
     * @param sql       SQL query to execute
     * @param paramMap  named arguments to bind to the query
     * @param rowMapper object that will map one object per row
     * @param <T>       the type of the mapped result objects
     * @return the result {@link Observable}, containing mapped objects
     */
    public <T> Observable<T> observableQuery(String sql, Map<String, ?> paramMap, RowMapper<T> rowMapper) {
        return Observable.create(SyncOnSubscribe.createSingleState(
                () -> createJdbcQueryState(sql, paramMap),
                (state, observer) -> state.handleNext(observer, rowMapper),
                JdbcQueryState::cleanup));
    }

    /**
     * Returns a backpressure aware {@link Observable} for the given SQL (without parameters),
     * using the given {@link RowMapper} to convert each db row to a {@code T}.
     *
     * @param sql       SQL query to execute
     * @param rowMapper object that will map one object per row
     * @param <T>       the type of the mapped result objects
     * @return the result {@link Observable}, containing mapped objects
     */
    public <T> Observable<T> observableQuery(String sql, RowMapper<T> rowMapper) {
        return Observable.create(SyncOnSubscribe.createSingleState(
                () -> createJdbcQueryState(sql),
                (state, observer) -> state.handleNext(observer, rowMapper),
                JdbcQueryState::cleanup));
    }

    private JdbcQueryState createJdbcQueryState(String sql) {
        return createJdbcQueryState(getPreparedStatementCreator(sql, EmptySqlParameterSource.INSTANCE));
    }

    private JdbcQueryState createJdbcQueryState(String sql, Map<String, ?> paramMap) {
        return createJdbcQueryState(getPreparedStatementCreator(sql, new MapSqlParameterSource(paramMap)));
    }

    /**
     * Obtains a connection, prepares and executes the statement, and wraps the open resources in a
     * {@link JdbcQueryState}.
     * <p>
     * On failure every resource acquired so far is released before the exception propagates.
     * An {@link SQLException} is translated with the template's exception translator; unchecked
     * exceptions are rethrown as-is.
     */
    private JdbcQueryState createJdbcQueryState(PreparedStatementCreator psc) {
        if (log.isDebugEnabled()) {
            log.debug("Executing prepared SQL statement [{}]", getSql(psc));
        }
        DataSource dataSource = jdbcTemplate.getDataSource();
        Connection con = DataSourceUtils.getConnection(dataSource);
        PreparedStatement ps = null;
        ResultSet rs = null;
        try {
            ps = psc.createPreparedStatement(con);
            rs = ps.executeQuery();
            return new JdbcQueryState(rs, con, psc, ps, dataSource);
        } catch (SQLException ex) {
            releaseResources(rs, psc, ps, con, dataSource);
            throw jdbcTemplate.getExceptionTranslator().translate("PreparedStatementCallback", getSql(psc), ex);
        } catch (RuntimeException ex) {
            // Without this branch an unchecked failure while creating/executing the statement
            // would leak the connection obtained above.
            releaseResources(rs, psc, ps, con, dataSource);
            throw ex;
        }
    }

    /** Returns the SQL of the creator when it exposes it via {@link SqlProvider}, otherwise {@code null}. */
    private static String getSql(PreparedStatementCreator psc) {
        return (psc instanceof SqlProvider) ? ((SqlProvider) psc).getSql() : null;
    }

    /** Lets the given object release bound parameter resources if it is a {@link ParameterDisposer}. */
    private static void cleanupParameters(Object paramDisposer) {
        if (paramDisposer instanceof ParameterDisposer) {
            ((ParameterDisposer) paramDisposer).cleanupParameters();
        }
    }

    /**
     * Releases the JDBC resources of one query: result set, parameters, statement and finally the connection.
     * Any of the resource arguments may be {@code null} (e.g. when statement creation failed).
     */
    private static void releaseResources(ResultSet rs, PreparedStatementCreator psc, PreparedStatement ps,
                                         Connection con, DataSource dataSource) {
        JdbcUtils.closeResultSet(rs);
        cleanupParameters(psc);
        cleanupParameters(ps);
        JdbcUtils.closeStatement(ps);
        DataSourceUtils.releaseConnection(con, dataSource);
    }

    /**
     * Holds the open JDBC resources and the current row index while iterating over the {@link ResultSet}
     * of one subscription.
     * <p>
     * NOTE(review): {@link SyncOnSubscribe} calls {@link #handleNext} serially, presumably possibly from
     * different requesting threads, which would explain why the row counter is {@code volatile}.
     */
    private static final class JdbcQueryState {
        private final ResultSet rs;
        private final Connection connection;
        private final PreparedStatementCreator psc;
        private final PreparedStatement ps;
        private final DataSource dataSource;
        private volatile int rowNum = 0;

        private JdbcQueryState(ResultSet rs, Connection connection, PreparedStatementCreator psc,
                               PreparedStatement ps, DataSource dataSource) {
            this.rs = rs;
            this.connection = connection;
            this.psc = psc;
            this.ps = ps;
            this.dataSource = dataSource;
        }

        /** Closes the result set and statement and releases the connection. */
        void cleanup() {
            releaseResources(rs, psc, ps, connection, dataSource);
        }

        /**
         * Emits at most one signal to the observer: the next mapped row, completion when the result set
         * is exhausted, or an error when reading or mapping the row fails with an {@link SQLException}.
         */
        <T> void handleNext(Observer<? super T> observer, RowMapper<T> rowMapper) {
            try {
                if (!rs.next()) {
                    log.debug("Ended streaming {} results using {}", rowNum, rowMapper);
                    observer.onCompleted();
                    return;
                }
                if (rowNum % 10000 == 0) {
                    log.debug("Streaming result #{} using {}", rowNum, rowMapper);
                }
                observer.onNext(rowMapper.mapRow(rs, rowNum++));
            } catch (SQLException e) {
                observer.onError(e);
            }
        }
    }
}
import com.bayer.service.genotyperepository.AbstractDbSpecification
import org.springframework.jdbc.core.ColumnMapRowMapper
import rx.Observer
import rx.exceptions.MissingBackpressureException
import rx.observers.TestSubscriber
import rx.schedulers.Schedulers
/**
 * Test that checks ObservableJdbcTemplate behaviour.
 * <p>
 * Every feature consumes the same 150-row query with a deliberately slow observer (100 ms per item).
 * The features compare the push-based ObservableJdbcTemplateWithoutBackpressure with the
 * backpressure-aware ObservableJdbcTemplate.
 */
class ObservableJdbcTemplateIT extends Specification {

    /** Observer that is much slower than the database: sleeps 100 ms before printing each item. */
    def slowObserver = new Observer() {
        @Override
        void onCompleted() {}

        @Override
        void onError(Throwable e) {}

        @Override
        void onNext(Object o) {
            sleep(100)
            println o
        }
    }

    // NOTE(review): placeholder left by the author — a real DataSource must be supplied here.
    def dataSource = ...

    def query = "select * from MY_TABLE WHERE ROWNUM <= 150"

    def "subscribing to an observable without backpressure fails with MissingBackpressureException when the consumer is slower than the producer"() {
        given:
        def jdbcTemplate = new ObservableJdbcTemplateWithoutBackpressure(dataSource)
        def subscriber = new TestSubscriber(slowObserver)

        when: 'rows are pushed across an asynchronous boundary to the slow observer'
        jdbcTemplate.query(query, new ColumnMapRowMapper())
                .observeOn(Schedulers.io())
                .subscribe(subscriber)
        subscriber.awaitTerminalEvent()

        then:
        subscriber.assertError(MissingBackpressureException)
    }

    def "subscribing to an observable without backpressure succeeds using callstack blocking"() {
        given:
        def jdbcTemplate = new ObservableJdbcTemplateWithoutBackpressure(dataSource)
        def subscriber = new TestSubscriber(slowObserver)

        when: 'no observeOn: each onNext blocks the producer until the observer is done'
        jdbcTemplate.query(query, new ColumnMapRowMapper())
                .subscribe(subscriber)
        subscriber.awaitTerminalEvent()

        then:
        subscriber.assertNoErrors()
    }

    def "when subscribing to an observable with backpressure all is fine"() {
        given:
        def jdbcTemplate = new ObservableJdbcTemplate(dataSource)
        def subscriber = new TestSubscriber(slowObserver)

        when:
        jdbcTemplate.observableQuery(query, new ColumnMapRowMapper())
                .observeOn(Schedulers.io())
                .subscribe(subscriber)
        subscriber.awaitTerminalEvent()

        then:
        subscriber.assertNoErrors()
    }
}
import org.slf4j.Logger;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.core.RowMapper;
import rx.Observable;
import rx.Subscriber;
import javax.sql.DataSource;
import java.sql.ResultSet;
import static org.slf4j.LoggerFactory.getLogger;
/**
 * Variant of {@link JdbcTemplate} that allows you to query and return an {@link Observable}.
 * <p>
 * This does not support backpressure: rows are pushed to the subscriber in a loop as fast as the
 * {@link ResultSet} yields them, regardless of how many items were requested. As a consequence it shouldn't
 * be used when backpressure is needed (e.g. behind an asynchronous {@code observeOn} with a slow consumer).
 */
public class ObservableJdbcTemplateWithoutBackpressure {

    private static final Logger log = getLogger(ObservableJdbcTemplateWithoutBackpressure.class);

    private final JdbcTemplate jdbcTemplate;

    public ObservableJdbcTemplateWithoutBackpressure(DataSource dataSource) {
        this.jdbcTemplate = new JdbcTemplate(dataSource);
    }

    /**
     * Executes a query given static SQL, mapping each row to a Java object via a RowMapper.
     * <p>
     * Nothing is executed by this method itself. The query runs on the subscribing thread, once per
     * subscription, and blocks that thread until all rows have been emitted or the subscriber has unsubscribed.
     * Any {@link DataAccessException} raised by the underlying {@link JdbcTemplate} is therefore thrown inside
     * the subscription rather than from this method.
     *
     * @param sql       SQL query to execute
     * @param rowMapper object that will map one object per row
     * @param <T>       the type of the mapped result objects
     * @return the result {@link Observable}, containing mapped objects
     */
    public <T> Observable<T> query(String sql, RowMapper<? extends T> rowMapper) {
        return Observable.create(subscriber ->
                this.jdbcTemplate.query(sql, extractResultToSubscriber(subscriber, rowMapper))
        );
    }

    /**
     * Returns a {@link ResultSetExtractor} that returns nothing but instead notifies the given
     * {@link Subscriber} of a {@code T} for each row.
     * Each item is created by iterating over the {@link ResultSet} and converting each row with the
     * given {@link RowMapper}. Iteration stops early once the subscriber has unsubscribed, and completion
     * is only signalled to a subscriber that is still subscribed.
     *
     * @param subscriber receiver of the mapped rows and of the completion signal
     * @param rowMapper  object that will map one object per row
     * @param <T>        the type of the mapped result objects
     * @return an extractor that always returns {@code null}
     */
    private static <T> ResultSetExtractor<Void> extractResultToSubscriber(Subscriber<? super T> subscriber,
                                                                          RowMapper<? extends T> rowMapper) {
        return rs -> {
            int rowNum = 0;
            // Check the subscription first so an unsubscribed subscriber doesn't advance the cursor.
            while (!subscriber.isUnsubscribed() && rs.next()) {
                if (rowNum % 10000 == 0) {
                    log.debug("Streaming result #{} using {}", rowNum, rowMapper);
                }
                subscriber.onNext(rowMapper.mapRow(rs, rowNum++));
            }
            log.debug("Ended streaming {} results using {}", rowNum, rowMapper);
            if (!subscriber.isUnsubscribed()) {
                subscriber.onCompleted();
            }
            return null;
        };
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment