Created: December 13, 2017 15:52
Save lvv83/dbe68dd68b7caf8f045c9081e4b59856 to your computer and use it in GitHub Desktop.
PostgreSQL sequence update with spring-test-dbunit integration
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.example;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

import org.dbunit.DatabaseUnitException;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.IDataSet;
import org.dbunit.operation.AbstractOperation;
public class PostgreUpdateSequenceOperation extends AbstractOperation { | |
class SequenceData | |
{ | |
String tableName; | |
String columnName; | |
String sequenceName; | |
} | |
private static final String PG_GET_SEQUENCES_SQL = | |
"SELECT table_name, column_name, pg_get_serial_sequence(table_schema || '.' || table_name, column_name) as seq_name " + | |
"FROM information_schema.columns " + | |
"WHERE column_default LIKE 'nextval(%' AND table_schema = ?"; | |
private List<SequenceData> getSequenceData(IDatabaseConnection connection) throws DatabaseUnitException, SQLException | |
{ | |
List<SequenceData> result = new ArrayList<SequenceData>(); | |
// Find sequences only in schema specified in dbunit connection | |
PreparedStatement sequenceStatement = connection.getConnection().prepareStatement(PG_GET_SEQUENCES_SQL); | |
sequenceStatement.setString(1, connection.getSchema()); | |
try | |
{ | |
ResultSet rs = sequenceStatement.executeQuery(); | |
while (rs.next()) | |
{ | |
SequenceData sd = new SequenceData(); | |
sd.tableName = rs.getString(1); | |
sd.columnName = rs.getString(2); | |
sd.sequenceName = rs.getString(3); | |
result.add(sd); | |
} | |
} | |
catch (SQLException e) { | |
throw new DatabaseUnitException("Error during sequence data fetch", e); | |
} | |
finally | |
{ | |
sequenceStatement.close(); | |
} | |
return result; | |
} | |
@Override | |
public void execute(IDatabaseConnection connection, IDataSet dataSet) throws DatabaseUnitException, SQLException { | |
List<SequenceData> sequences = getSequenceData(connection); | |
Connection sqlConn = connection.getConnection(); | |
String[] tableNames = dataSet.getTableNames(); | |
boolean oldAutoCommit = sqlConn.getAutoCommit(); | |
try | |
{ | |
sqlConn.setAutoCommit(false); | |
for (SequenceData seq : sequences) | |
{ | |
boolean inDataset = false; | |
for (String tableName : tableNames) | |
{ | |
if (seq.tableName.equalsIgnoreCase(tableName)) { | |
inDataset = true; | |
break; | |
} | |
} | |
if (inDataset) | |
{ | |
String qualifiedTableName = getQualifiedName(connection.getSchema(), seq.tableName, connection); | |
Statement lockStatement = sqlConn.createStatement(); | |
lockStatement.execute("LOCK TABLE " + qualifiedTableName + " IN EXCLUSIVE MODE"); | |
String sql = "SELECT setval('" + seq.sequenceName + "', COALESCE((SELECT MAX(" + seq.columnName + ") + 1 FROM " + qualifiedTableName + "), 1), false)"; | |
Statement setValStatement = sqlConn.createStatement(); | |
setValStatement.executeQuery(sql); | |
sqlConn.commit(); | |
lockStatement.close(); | |
setValStatement.close(); | |
} | |
} | |
} | |
catch (SQLException e) { | |
sqlConn.rollback(); | |
throw new DatabaseUnitException("Error during sequence reset", e); | |
} | |
finally | |
{ | |
sqlConn.setAutoCommit(oldAutoCommit); | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.example; | |
import java.util.HashMap; | |
import java.util.Map; | |
import org.dbunit.operation.CompositeOperation; | |
import com.github.springtestdbunit.annotation.DatabaseOperation; | |
import com.github.springtestdbunit.operation.DefaultDatabaseOperationLookup; | |
/** Usage with spring-test-dbunit | |
* <code> | |
* {@literal @DbUnitConfiguration(databaseOperationLookup = PostgreUpdateSequenceOperationLookup.class)} | |
* <br> | |
* {@literal @DatabaseSetup(type = DatabaseOperation.CLEAN_INSERT, value = { "/dbunit/data.xml" }) } | |
* </code> | |
*/ | |
public class PostgreUpdateSequenceOperationLookup extends DefaultDatabaseOperationLookup { | |
private static Map<DatabaseOperation, org.dbunit.operation.DatabaseOperation> LOOKUP; | |
static { | |
PostgreUpdateSequenceOperation op1 = new PostgreUpdateSequenceOperation(); | |
CompositeOperation op2 = new CompositeOperation(new org.dbunit.operation.DatabaseOperation[] { org.dbunit.operation.DatabaseOperation.DELETE_ALL, org.dbunit.operation.DatabaseOperation.INSERT, op1 }); | |
LOOKUP = new HashMap<DatabaseOperation, org.dbunit.operation.DatabaseOperation>(); | |
LOOKUP.put(DatabaseOperation.CLEAN_INSERT, op2); | |
} | |
@Override | |
public org.dbunit.operation.DatabaseOperation get(DatabaseOperation operation) { | |
if (LOOKUP.containsKey(operation)) { | |
return LOOKUP.get(operation); | |
} | |
return super.get(operation); | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.