Skip to content

Instantly share code, notes, and snippets.

@giovannicandido
Created September 4, 2023 21:16
Show Gist options
  • Save giovannicandido/fecd7c2a79519797d211b1cd76aa0a8b to your computer and use it in GitHub Desktop.
Save giovannicandido/fecd7c2a79519797d211b1cd76aa0a8b to your computer and use it in GitHub Desktop.
Spring service to reset the state of the database between unit tests
package com.kugelbit.mercado.estoque.test.listener;
import jakarta.transaction.Transactional;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.HashSet;
import java.util.Set;
/**
 * Resets the contents of a database schema so that each test starts from a clean state.
 *
 * <p>Every base table in the schema (except those listed in {@code TABLES_TO_IGNORE}) is truncated
 * with {@code CASCADE}, and every sequence in the schema is restarted at 1. The SQL issued
 * ({@code TRUNCATE ... CASCADE}, {@code SET CONSTRAINTS}) is PostgreSQL syntax.
 */
@Component
@Slf4j
@RequiredArgsConstructor
public class CleanUpDbService {

    /**
     * Tables that must survive a cleanup. The names match PostGIS metadata objects and Liquibase's
     * changelog bookkeeping tables.
     */
    private static final Set<String> TABLES_TO_IGNORE = Set.of("geography_columns", "geometry_columns",
            "spatial_ref_sys", "databasechangeloglock", "databasechangelog");

    /** Base tables only: TRUNCATE fails on views, which INFORMATION_SCHEMA.TABLES also lists. */
    private static final String SELECT_TABLES_SQL =
            "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_TYPE = 'BASE TABLE'";

    private static final String SELECT_SEQUENCES_SQL =
            "SELECT SEQUENCE_NAME FROM INFORMATION_SCHEMA.SEQUENCES WHERE SEQUENCE_SCHEMA = ?";

    private final DataSource dataSource;

    /**
     * Truncates all non-ignored base tables of {@code schemaName} and restarts its sequences at 1.
     *
     * <p>The connection is taken directly from the {@link DataSource}, so this method manages its own
     * JDBC transaction rather than relying on the annotation. Running inside a transaction is what
     * makes {@code SET CONSTRAINTS ALL DEFERRED} effective (PostgreSQL ignores it outside a
     * transaction block), and it makes the cleanup all-or-nothing: on any failure the work is rolled
     * back before the exception propagates.
     *
     * @param schemaName exact (case-sensitive) name of the schema to clean; it is passed as a bind
     *     parameter when listing objects and quoted when used as an identifier
     */
    @SneakyThrows
    @Transactional(Transactional.TxType.REQUIRES_NEW)
    public void cleanup(String schemaName) {
        try (Connection connection = dataSource.getConnection()) {
            boolean previousAutoCommit = connection.getAutoCommit();
            connection.setAutoCommit(false);
            boolean committed = false;
            try {
                disableConstraints(connection);
                truncateTables(connection, schemaName);
                resetSequences(connection, schemaName);
                enableConstraints(connection);
                connection.commit();
                committed = true;
            } finally {
                if (!committed) {
                    // Roll back before restoring autocommit: per JDBC, re-enabling autocommit
                    // mid-transaction would commit the partial cleanup.
                    rollbackQuietly(connection);
                }
                connection.setAutoCommit(previousAutoCommit);
            }
        }
    }

    /** Restarts every sequence in the schema at 1. */
    private void resetSequences(Connection con, String schemaName) throws SQLException {
        for (String sequenceName : queryNames(con, SELECT_SEQUENCES_SQL, schemaName)) {
            // skipcq: JAVA-A1041 -- identifiers come from INFORMATION_SCHEMA and are quoted
            executeStatement(con, "ALTER SEQUENCE " + qualifiedName(schemaName, sequenceName) + " RESTART WITH 1");
        }
    }

    /** Truncates (with CASCADE) every base table in the schema that is not in TABLES_TO_IGNORE. */
    private void truncateTables(Connection con, String schemaName) throws SQLException {
        for (String tableName : queryNames(con, SELECT_TABLES_SQL, schemaName)) {
            if (TABLES_TO_IGNORE.contains(tableName)) {
                continue;
            }
            log.info("Truncating table {}", tableName);
            // skipcq: JAVA-A1041 -- identifiers come from INFORMATION_SCHEMA and are quoted
            executeStatement(con, "TRUNCATE TABLE " + qualifiedName(schemaName, tableName) + " CASCADE");
        }
    }

    /** Makes deferrable constraints be checked immediately again (only meaningful inside a transaction). */
    private void enableConstraints(Connection con) throws SQLException {
        executeStatement(con, "set constraints all immediate");
    }

    /** Defers checking of deferrable constraints until commit (only meaningful inside a transaction). */
    private void disableConstraints(Connection con) throws SQLException {
        executeStatement(con, "set constraints all deferred");
    }

    /** Prepares, runs and closes a single statement that takes no parameters. */
    private static void executeStatement(Connection con, String sql) throws SQLException {
        try (PreparedStatement statement = con.prepareStatement(sql)) {
            statement.executeUpdate();
        }
    }

    /**
     * Runs a one-parameter query, binding {@code schemaName} to the single {@code ?}, and collects
     * the first column of every row.
     */
    private static Set<String> queryNames(Connection con, String sql, String schemaName) throws SQLException {
        try (PreparedStatement statement = con.prepareStatement(sql)) {
            statement.setString(1, schemaName);
            Set<String> names = new HashSet<>();
            try (ResultSet rs = statement.executeQuery()) {
                while (rs.next()) {
                    names.add(rs.getString(1));
                }
            }
            return names;
        }
    }

    /** Builds {@code "schema"."name"} so the object is found regardless of search_path or letter case. */
    private static String qualifiedName(String schemaName, String objectName) {
        return quoteIdentifier(schemaName) + "." + quoteIdentifier(objectName);
    }

    /** Quotes an SQL identifier with double quotes, doubling any embedded double quote. */
    private static String quoteIdentifier(String identifier) {
        return '"' + identifier.replace("\"", "\"\"") + '"';
    }

    /** Best-effort rollback used on the failure path; never masks the exception already propagating. */
    private static void rollbackQuietly(Connection connection) {
        try {
            connection.rollback();
        } catch (SQLException e) {
            log.warn("Rollback after failed database cleanup also failed", e);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment