Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save MikaelAmborn/828164 to your computer and use it in GitHub Desktop.
Save MikaelAmborn/828164 to your computer and use it in GitHub Desktop.
package se.avega.dbunit;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import javax.sql.DataSource;
import org.dbunit.database.DatabaseConfig;
import org.dbunit.database.DatabaseDataSourceConnection;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.xml.FlatXmlDataSetBuilder;
import org.dbunit.ext.hsqldb.HsqldbDataTypeFactory;
import org.dbunit.operation.DatabaseOperation;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.jdbc.datasource.DataSourceUtils;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.support.AbstractTestExecutionListener;
import org.springframework.test.context.transaction.TransactionConfiguration;
import org.springframework.test.context.transaction.TransactionConfigurationAttributes;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.annotation.AnnotationTransactionAttributeSource;
import org.springframework.transaction.interceptor.DefaultTransactionAttribute;
import org.springframework.transaction.interceptor.TransactionAttributeSource;
/**
 * Spring {@code TestExecutionListener} that wraps an entire test class in one transaction.
 * <p>
 * In {@link #beforeTestClass(TestContext)} a transaction is started with the configured
 * {@link PlatformTransactionManager}. Then the DbUnit flat XML file
 * {@code testData/<TestClassSimpleName>.xml} is inserted through the {@code "dataSource"} bean,
 * using the HSQLDB data type factory.
 * <p>
 * In {@link #afterTestClass(TestContext)} the transaction is ended. Because
 * {@link #isRollback(TestContext)} always returns {@code true}, it is rolled back.
 * <p>
 * NOTE(review): the inserted rows only take part in that transaction if
 * {@code SpringDatabaseDataSourceConnection} obtains its connection through Spring's
 * transaction-aware {@code DataSourceUtils}. Confirm this against that class.
 */
public class DbUnitTransactionPerTestClassListener extends AbstractTestExecutionListener {

    /** Annotation-based transaction attribute source; not read within this class itself. */
    protected final TransactionAttributeSource attributeSource = new AnnotationTransactionAttributeSource();

    /** {@code @TransactionConfiguration} settings, resolved lazily and then cached. */
    private TransactionConfigurationAttributes configurationAttributes;

    /** Count of transactions started. Atomic, because {@code ++} on a volatile field is not. */
    private final AtomicInteger transactionsStarted = new AtomicInteger();

    /** Currently open transaction for each test class, keyed by class identity. */
    private final Map<Class<?>, TransactionContext> transactionContextCache =
            Collections.synchronizedMap(new IdentityHashMap<Class<?>, TransactionContext>());

    /**
     * Starts the class-level transaction and inserts the test class's DbUnit data set.
     *
     * @param testContext context of the test class being prepared
     * @throws IllegalStateException if a transaction is already registered for this test class
     *         (the stale entry is discarded before throwing)
     * @throws Exception if the transaction cannot be started or the data cannot be inserted
     */
    @Override
    public void beforeTestClass(TestContext testContext) throws Exception {
        final Class<?> testClass = testContext.getTestClass();
        if (this.transactionContextCache.remove(testClass) != null) {
            throw new IllegalStateException("Cannot start new transaction without ending existing transaction: " +
                    "Invoke endTransaction() before startNewTransaction().");
        }
        PlatformTransactionManager tm = getTransactionManager(testContext);
        TransactionContext txContext = new TransactionContext(tm, new DefaultTransactionAttribute());
        startNewTransaction(testContext, txContext);
        // Register the transaction before inserting data. If the insert fails, a later
        // afterTestClass call can still find the transaction and end it.
        this.transactionContextCache.put(testClass, txContext);
        insertData(testContext);
    }

    /**
     * Inserts the rows from {@code testData/<TestClassSimpleName>.xml} with a DbUnit INSERT
     * through the {@code "dataSource"} bean. The JDBC connection is always released afterwards.
     */
    private void insertData(TestContext testContext) throws Exception {
        DataSource ds = (DataSource) testContext.getApplicationContext().getBean("dataSource");
        DatabaseDataSourceConnection jdbcConnection = new SpringDatabaseDataSourceConnection(ds);
        jdbcConnection.getConfig().setProperty(DatabaseConfig.PROPERTY_DATATYPE_FACTORY, new HsqldbDataTypeFactory());
        try {
            String filename = testContext.getTestClass().getSimpleName() + ".xml";
            DatabaseOperation.INSERT.execute(jdbcConnection, createDataSetFromFile(filename));
        } finally {
            DataSourceUtils.releaseConnection(jdbcConnection.getConnection(), ds);
        }
    }

    /**
     * Loads a DbUnit flat XML data set from the {@code testData} directory.
     *
     * @param fileName file name relative to {@code testData/}
     * @return the parsed data set
     * @throws FileNotFoundException if the file does not exist
     * @throws DataSetException if the XML cannot be parsed as a flat XML data set
     */
    private IDataSet createDataSetFromFile(String fileName) throws DataSetException, FileNotFoundException {
        // Our DbUnit XML files are in a directory called testData, so prefix the file name with it.
        FileInputStream in = new FileInputStream("testData/" + fileName);
        try {
            // build() returns a FlatXmlDataSet, which DbUnit loads eagerly (it is a CachedDataSet),
            // so the stream is no longer needed once build() returns.
            return new FlatXmlDataSetBuilder().build(in);
        } finally {
            try {
                in.close();
            } catch (IOException ignored) {
                // Best-effort close. Either the data set has already been read, or build() failed
                // and its exception is the one worth propagating.
            }
        }
    }

    /**
     * Ends the class-level transaction, if this listener started one for this test class
     * and it has not been completed yet.
     */
    @Override
    public void afterTestClass(TestContext testContext) throws Exception {
        Class<?> testClass = testContext.getTestClass();
        // If the transaction is still active...
        TransactionContext txContext = this.transactionContextCache.remove(testClass);
        if (txContext != null && !txContext.transactionStatus.isCompleted()) {
            endTransaction(testContext, txContext);
        }
    }

    /** Begins the transaction held by {@code txContext} and bumps the started-transaction counter. */
    private void startNewTransaction(TestContext testContext, TransactionContext txContext) throws Exception {
        txContext.startTransaction();
        this.transactionsStarted.incrementAndGet();
    }

    /** Ends the transaction held by {@code txContext}, committing or rolling back per {@link #isRollback}. */
    private void endTransaction(TestContext testContext, TransactionContext txContext) throws Exception {
        boolean rollback = isRollback(testContext);
        txContext.endTransaction(rollback);
    }

    /**
     * Looks up the transaction manager bean named by {@code @TransactionConfiguration}.
     * If the test class has no such annotation, the annotation's default name is used.
     */
    protected final PlatformTransactionManager getTransactionManager(TestContext testContext) {
        String tmName = retrieveConfigurationAttributes(testContext).getTransactionManagerName();
        return testContext.getApplicationContext().getBean(tmName, PlatformTransactionManager.class);
    }

    /**
     * Always returns {@code true}: the class-level transaction is always rolled back.
     * The {@code defaultRollback} value resolved from {@code @TransactionConfiguration} is not consulted.
     */
    protected final boolean isRollback(TestContext testContext) throws Exception {
        return true;
    }

    /**
     * Resolves {@code @TransactionConfiguration} on the test class, falling back to the annotation's
     * default values. The result is cached in an instance field on the first call and reused for
     * every later call on this listener instance.
     */
    private TransactionConfigurationAttributes retrieveConfigurationAttributes(TestContext testContext) {
        if (this.configurationAttributes == null) {
            Class<?> clazz = testContext.getTestClass();
            Class<TransactionConfiguration> annotationType = TransactionConfiguration.class;
            TransactionConfiguration config = clazz.getAnnotation(annotationType);
            String transactionManagerName;
            boolean defaultRollback;
            if (config != null) {
                transactionManagerName = config.transactionManager();
                defaultRollback = config.defaultRollback();
            }
            else {
                transactionManagerName = (String) AnnotationUtils.getDefaultValue(annotationType, "transactionManager");
                defaultRollback = (Boolean) AnnotationUtils.getDefaultValue(annotationType, "defaultRollback");
            }
            TransactionConfigurationAttributes configAttributes =
                    new TransactionConfigurationAttributes(transactionManagerName, defaultRollback);
            this.configurationAttributes = configAttributes;
        }
        return this.configurationAttributes;
    }

    /** Holds a transaction manager, the definition to start with, and the resulting status. */
    private static class TransactionContext {

        private final PlatformTransactionManager transactionManager;

        private final TransactionDefinition transactionDefinition;

        /** Set by {@link #startTransaction()}; {@code null} until then. */
        private TransactionStatus transactionStatus;

        public TransactionContext(PlatformTransactionManager transactionManager, TransactionDefinition transactionDefinition) {
            this.transactionManager = transactionManager;
            this.transactionDefinition = transactionDefinition;
        }

        /** Obtains a transaction from the manager using the stored definition. */
        public void startTransaction() {
            this.transactionStatus = this.transactionManager.getTransaction(this.transactionDefinition);
        }

        /** Rolls back when {@code rollback} is true, otherwise commits. */
        public void endTransaction(boolean rollback) {
            if (rollback) {
                this.transactionManager.rollback(this.transactionStatus);
            }
            else {
                this.transactionManager.commit(this.transactionStatus);
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment