Skip to content

Instantly share code, notes, and snippets.

@stellingsimon
Last active September 9, 2023 08:50
Show Gist options
  • Star 9 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save stellingsimon/f09005f666b01f2560d73f9e603b2a97 to your computer and use it in GitHub Desktop.
Save stellingsimon/f09005f666b01f2560d73f9e603b2a97 to your computer and use it in GitHub Desktop.
Recording row-level auditing information from an application's user session context in a Postgres DB (9.5+)
package example.postgres.audit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.sql.CallableStatement;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Savepoint;
import java.sql.Statement;
import java.sql.Timestamp;
import java.time.Instant;
import java.util.Objects;
import java.util.UUID;
import static java.lang.String.format;
/**
 * A {@link Connection} wrapper that, whenever a statement is about to be created, first sends the current
 * application user (as reported by {@code AuditContextProvider#currentUserId()}) and a timestamp to Postgres as
 * the transaction-local settings {@code audit.AUDIT_USER} and {@code audit.AUDIT_TIMESTAMP}. The
 * {@code trigger_audit} triggers read these settings to fill the auditing columns.
 *
 * <p>The settings only live for the current transaction, so they are not sent while the connection is in
 * auto-commit mode. This class performs no synchronization of its own (the cached user is a plain field).
 */
public class AuditingConnection extends DelegatingConnection {
    private static final Logger LOGGER = LoggerFactory.getLogger(AuditingConnection.class);

    // TODO: implement the static method `AuditContextProvider.currentUserId()`,
    // returning whatever value you'd like persisted in the columns `inserted_by`/`updated_by`
    // from a thread-local context
    private static final AuditContextProvider AUDIT_CONTEXT_PROVIDER = new AuditContextProvider();

    /**
     * Sets both audit variables for the current transaction in a single round-trip.
     * {@code set_config(name, value, true)} is the function form of {@code SET LOCAL}; unlike {@code SET}, it
     * accepts bind parameters, so no values are spliced into the SQL text.
     */
    private static final String SET_AUDIT_CONTEXT_SQL =
            "SELECT set_config('audit.AUDIT_USER', ?, true), set_config('audit.AUDIT_TIMESTAMP', ?, true)";

    /**
     * Remembers the audit.AUDIT_USER successfully set for the ongoing transaction, so that the settings are
     * only sent to the DB when necessary. {@code null} means "nothing sent in this transaction yet".
     */
    private UUID userCommunicatedToDB;

    public AuditingConnection(Connection wrappedConnection) {
        super(wrappedConnection);
    }

    /**
     * Sends the current audit user and a timestamp to the DB if the user differs from the one already sent
     * in this transaction. This is best-effort: failures are logged, not thrown.
     */
    private void setContext() {
        UUID user = AUDIT_CONTEXT_PROVIDER.currentUserId();
        // The user reported by AUDIT_CONTEXT_PROVIDER may change within the same ongoing transaction,
        // e.g. because the user session is initialized after the first DB statements.
        // Detect changes and propagate the information to the DB iff necessary:
        if (Objects.equals(user, userCommunicatedToDB)) {
            return;
        }
        if (isAutoCommit()) {
            // Transaction-local settings made outside a transaction have no effect (Postgres only warns),
            // so there is nothing useful to send.
            LOGGER.warn("Did not set AUDIT context variables because connection was in AUTO-COMMIT mode.");
            return;
        }
        Timestamp nowTimestamp = Timestamp.from(Instant.now());
        // A user that disappeared mid-transaction is sent as an empty string. The 'set_audit_fields' trigger
        // cannot cast that to UUID and therefore falls back to its UNKNOWN_USER_ID.
        String userValue = user == null ? "" : user.toString();
        // NOTE(review): Timestamp#toString renders the JVM's default time zone without an offset; the DB-side
        // fallback now() uses the session TimeZone. Confirm both agree.
        try (PreparedStatement setAuditContext = prepareStatementWithoutAuditing(SET_AUDIT_CONTEXT_SQL)) {
            setAuditContext.setString(1, userValue);
            setAuditContext.setString(2, nowTimestamp.toString());
            setAuditContext.execute();
            // Only remember the user once the DB has actually accepted it, so that a failure is retried
            // with the next statement instead of being silently skipped.
            this.userCommunicatedToDB = user;
        } catch (SQLException e) {
            LOGGER.warn("failed to set AUDIT information", e);
        }
    }

    /** Prepares a statement on the wrapped connection without triggering {@link #setContext()} (avoids recursion). */
    private PreparedStatement prepareStatementWithoutAuditing(String sql) throws SQLException {
        return super.prepareStatement(sql);
    }

    /** Forgets the cached user, so the audit settings are re-sent before the next statement. */
    private void onTransactionEnd() {
        // Values set with SET LOCAL / set_config(..., true) are purged automatically by Postgres after a
        // COMMIT/ROLLBACK [TO SAVEPOINT].
        this.userCommunicatedToDB = null;
    }

    /** Returns the auto-commit flag, or {@code false} if it cannot be read. */
    private boolean isAutoCommit() {
        try {
            return getAutoCommit();
        } catch (SQLException e) {
            LOGGER.warn("failed to read AUTO-COMMIT status", e);
            return false; // so we attempt to set the audit info anyway...
        }
    }

    /*
     DELEGATION WITH AUDITING:
     These methods either manipulate data in the DB or the transaction state of this Connection.
     All other methods don't and therefore use direct delegation.
     */

    @Override
    public Statement createStatement() throws SQLException {
        setContext();
        return super.createStatement();
    }

    @Override
    public PreparedStatement prepareStatement(String sql) throws SQLException {
        setContext();
        return super.prepareStatement(sql);
    }

    @Override
    public CallableStatement prepareCall(String sql) throws SQLException {
        setContext();
        return super.prepareCall(sql);
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException {
        setContext();
        return super.createStatement(resultSetType, resultSetConcurrency);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        setContext();
        return super.prepareStatement(sql, resultSetType, resultSetConcurrency);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        setContext();
        return super.prepareCall(sql, resultSetType, resultSetConcurrency);
    }

    @Override
    public void setAutoCommit(boolean autoCommit) throws SQLException {
        super.setAutoCommit(autoCommit);
        // Per JDBC, changing the auto-commit mode during a transaction commits that transaction.
        onTransactionEnd();
    }

    @Override
    public void commit() throws SQLException {
        super.commit();
        onTransactionEnd();
    }

    @Override
    public void rollback() throws SQLException {
        super.rollback();
        onTransactionEnd();
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        setContext();
        return super.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        setContext();
        return super.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        setContext();
        return super.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {
        setContext();
        return super.prepareStatement(sql, autoGeneratedKeys);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException {
        setContext();
        return super.prepareStatement(sql, columnIndexes);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException {
        setContext();
        return super.prepareStatement(sql, columnNames);
    }

    @Override
    public void rollback(Savepoint savepoint) throws SQLException {
        super.rollback(savepoint);
        // If the audit settings were sent AFTER 'setSavepoint(savepoint)', rolling back to it undid them.
        // We don't track when they were sent, so conservatively forget them and re-send on the next statement.
        onTransactionEnd();
    }

    @Override
    public void setTransactionIsolation(int level) throws SQLException {
        super.setTransactionIsolation(level);
        // Conservatively forget the cached user; it is re-sent with the next statement if still needed.
        onTransactionEnd();
    }
}
package example.postgres.audit;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
/**
 * A DataSource wrapper that provides {@link AuditingConnection}s: every connection obtained through it is
 * wrapped so that the current audit user is sent to Postgres before statements are created.
 */
public class AuditingDataSource extends DelegatingDataSource {

    public AuditingDataSource(DataSource wrappedDataSource) {
        super(wrappedDataSource);
    }

    /** Obtains a connection from the wrapped data source and wraps it in an {@link AuditingConnection}. */
    @Override
    public Connection getConnection() throws SQLException {
        return new AuditingConnection(super.getConnection());
    }

    /**
     * Obtains a connection for the given credentials from the wrapped data source and wraps it in an
     * {@link AuditingConnection}.
     */
    @Override
    public Connection getConnection(String username, String password) throws SQLException {
        return new AuditingConnection(super.getConnection(username, password));
    }
}
package example.postgres.audit;
import java.sql.Array;
import java.sql.Blob;
import java.sql.CallableStatement;
import java.sql.Clob;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.NClob;
import java.sql.PreparedStatement;
import java.sql.SQLClientInfoException;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.sql.SQLXML;
import java.sql.Savepoint;
import java.sql.Statement;
import java.sql.Struct;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.Executor;
/**
 * Template-style base class for wrappers of a JDBC Connection.
 *
 * <p>Every {@link Connection} method is forwarded unchanged to the wrapped connection, except the
 * {@link java.sql.Wrapper} methods: {@link #unwrap(Class)} and {@link #isWrapperFor(Class)} first consider
 * this wrapper itself, then the wrapped connection, then delegate further, as the {@code Wrapper} contract
 * requires. Subclasses override only the methods they want to intercept.
 */
public abstract class DelegatingConnection implements Connection {
    private final Connection delegate;

    /**
     * @param delegate the connection to forward to; must not be {@code null}
     * @throws NullPointerException if {@code delegate} is {@code null}
     */
    public DelegatingConnection(Connection delegate) {
        // Fail fast here rather than with an NPE on the first forwarded call.
        this.delegate = java.util.Objects.requireNonNull(delegate, "delegate");
    }

    @Override
    public Statement createStatement() throws SQLException {
        return delegate.createStatement();
    }

    @Override
    public PreparedStatement prepareStatement(String sql) throws SQLException {
        return delegate.prepareStatement(sql);
    }

    @Override
    public CallableStatement prepareCall(String sql) throws SQLException {
        return delegate.prepareCall(sql);
    }

    @Override
    public String nativeSQL(String sql) throws SQLException {
        return delegate.nativeSQL(sql);
    }

    @Override
    public void setAutoCommit(boolean autoCommit) throws SQLException {
        delegate.setAutoCommit(autoCommit);
    }

    @Override
    public boolean getAutoCommit() throws SQLException {
        return delegate.getAutoCommit();
    }

    @Override
    public void commit() throws SQLException {
        delegate.commit();
    }

    @Override
    public void rollback() throws SQLException {
        delegate.rollback();
    }

    @Override
    public void close() throws SQLException {
        delegate.close();
    }

    @Override
    public boolean isClosed() throws SQLException {
        return delegate.isClosed();
    }

    @Override
    public DatabaseMetaData getMetaData() throws SQLException {
        return delegate.getMetaData();
    }

    @Override
    public void setReadOnly(boolean readOnly) throws SQLException {
        delegate.setReadOnly(readOnly);
    }

    @Override
    public boolean isReadOnly() throws SQLException {
        return delegate.isReadOnly();
    }

    @Override
    public void setCatalog(String catalog) throws SQLException {
        delegate.setCatalog(catalog);
    }

    @Override
    public String getCatalog() throws SQLException {
        return delegate.getCatalog();
    }

    @Override
    public void setTransactionIsolation(int level) throws SQLException {
        delegate.setTransactionIsolation(level);
    }

    @Override
    public int getTransactionIsolation() throws SQLException {
        return delegate.getTransactionIsolation();
    }

    @Override
    public SQLWarning getWarnings() throws SQLException {
        return delegate.getWarnings();
    }

    @Override
    public void clearWarnings() throws SQLException {
        delegate.clearWarnings();
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException {
        return delegate.createStatement(resultSetType, resultSetConcurrency);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        return delegate.prepareStatement(sql, resultSetType, resultSetConcurrency);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        return delegate.prepareCall(sql, resultSetType, resultSetConcurrency);
    }

    @Override
    public Map<String, Class<?>> getTypeMap() throws SQLException {
        return delegate.getTypeMap();
    }

    @Override
    public void setTypeMap(Map<String, Class<?>> map) throws SQLException {
        delegate.setTypeMap(map);
    }

    @Override
    public void setHoldability(int holdability) throws SQLException {
        delegate.setHoldability(holdability);
    }

    @Override
    public int getHoldability() throws SQLException {
        return delegate.getHoldability();
    }

    @Override
    public Savepoint setSavepoint() throws SQLException {
        return delegate.setSavepoint();
    }

    @Override
    public Savepoint setSavepoint(String name) throws SQLException {
        return delegate.setSavepoint(name);
    }

    @Override
    public void rollback(Savepoint savepoint) throws SQLException {
        delegate.rollback(savepoint);
    }

    @Override
    public void releaseSavepoint(Savepoint savepoint) throws SQLException {
        delegate.releaseSavepoint(savepoint);
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return delegate.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return delegate.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return delegate.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {
        return delegate.prepareStatement(sql, autoGeneratedKeys);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException {
        return delegate.prepareStatement(sql, columnIndexes);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException {
        return delegate.prepareStatement(sql, columnNames);
    }

    @Override
    public Clob createClob() throws SQLException {
        return delegate.createClob();
    }

    @Override
    public Blob createBlob() throws SQLException {
        return delegate.createBlob();
    }

    @Override
    public NClob createNClob() throws SQLException {
        return delegate.createNClob();
    }

    @Override
    public SQLXML createSQLXML() throws SQLException {
        return delegate.createSQLXML();
    }

    @Override
    public boolean isValid(int timeout) throws SQLException {
        return delegate.isValid(timeout);
    }

    @Override
    public void setClientInfo(String name, String value) throws SQLClientInfoException {
        delegate.setClientInfo(name, value);
    }

    @Override
    public void setClientInfo(Properties properties) throws SQLClientInfoException {
        delegate.setClientInfo(properties);
    }

    @Override
    public String getClientInfo(String name) throws SQLException {
        return delegate.getClientInfo(name);
    }

    @Override
    public Properties getClientInfo() throws SQLException {
        return delegate.getClientInfo();
    }

    @Override
    public Array createArrayOf(String typeName, Object[] elements) throws SQLException {
        return delegate.createArrayOf(typeName, elements);
    }

    @Override
    public Struct createStruct(String typeName, Object[] attributes) throws SQLException {
        return delegate.createStruct(typeName, attributes);
    }

    @Override
    public void setSchema(String schema) throws SQLException {
        delegate.setSchema(schema);
    }

    @Override
    public String getSchema() throws SQLException {
        return delegate.getSchema();
    }

    @Override
    public void abort(Executor executor) throws SQLException {
        delegate.abort(executor);
    }

    @Override
    public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException {
        delegate.setNetworkTimeout(executor, milliseconds);
    }

    @Override
    public int getNetworkTimeout() throws SQLException {
        return delegate.getNetworkTimeout();
    }

    /**
     * Returns this wrapper if it implements {@code iface}, else the wrapped connection if it does, else
     * whatever the wrapped connection's own {@code unwrap} returns (as required by {@link java.sql.Wrapper}).
     * Checking {@code this} first keeps e.g. {@code unwrap(Connection.class)} from bypassing this wrapper.
     */
    @Override
    public <T> T unwrap(Class<T> iface) throws SQLException {
        if (iface.isInstance(this)) {
            return iface.cast(this);
        }
        if (iface.isInstance(delegate)) {
            return iface.cast(delegate);
        }
        return delegate.unwrap(iface);
    }

    /**
     * Returns {@code true} if this wrapper or the wrapped connection implements {@code iface}, or if the
     * wrapped connection reports that it wraps such an object.
     */
    @Override
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return iface.isInstance(this) || iface.isInstance(delegate) || delegate.isWrapperFor(iface);
    }
}
-- this script assumes:
-- * `my_schema_name.my_table_name` to have 4 auditing columns named `inserted_by`, `inserted_at`, `updated_by`, `updated_at`
-- * the values `audit.AUDIT_USER` and `audit.AUDIT_TIMESTAMP` to be provided by the application as
--   transaction-local settings after starting a transaction
--
-- set_audit_fields(): row trigger function that stamps the auditing columns.
-- * INSERT: sets inserted_by/inserted_at from the audit settings and clears updated_by/updated_at.
-- * UPDATE: restores inserted_by/inserted_at from OLD and sets updated_by/updated_at from the audit settings.
CREATE OR REPLACE FUNCTION set_audit_fields()
RETURNS TRIGGER AS $$
DECLARE
  audit_user UUID;
  audit_timestamp TIMESTAMP;
BEGIN
  -- Postgres 9.6 offers current_setting(..., [missing_ok]) which makes the exception handling obsolete.
  -- Postgres 9.5 requires the following however.
  -- Each handler fires on ANY error: a missing setting, or a value that cannot be cast (e.g. an empty
  -- string, which is what Postgres reports for a custom setting that was defined earlier in the session
  -- but not in the current transaction).
  -- NOTE: every BEGIN ... EXCEPTION block starts a subtransaction, which has a per-row cost.
  BEGIN
    audit_user := current_setting('audit.AUDIT_USER');
  EXCEPTION WHEN OTHERS THEN
    audit_user := '77777777-0000-7777-0000-777777777777'; -- UNKNOWN_USER_ID
  END;
  BEGIN
    audit_timestamp := current_setting('audit.AUDIT_TIMESTAMP');
  EXCEPTION WHEN OTHERS THEN
    -- now() is a timestamptz; assigning it to TIMESTAMP converts it using the session TimeZone.
    -- NOTE(review): confirm this matches the zone in which the application renders AUDIT_TIMESTAMP.
    audit_timestamp := now();
  END;
  IF TG_OP = 'INSERT'
  THEN
    NEW.inserted_by := audit_user;
    NEW.inserted_at := audit_timestamp;
    NEW.updated_by := NULL;
    NEW.updated_at := NULL;
  ELSE
    NEW.inserted_by := OLD.inserted_by; -- do not allow statement to (accidentally) set INSERTED_BY
    NEW.inserted_at := OLD.inserted_at; -- do not allow statement to (accidentally) set INSERTED_AT
    NEW.updated_by := audit_user;
    NEW.updated_at := audit_timestamp;
  END IF;
  RETURN NEW;
END;
$$ LANGUAGE plpgsql;

-- note: this must stay a BEFORE ... FOR EACH ROW trigger. Only row-level BEFORE triggers can modify the row
-- being written (NEW); statement-level triggers cannot, and the transition tables added in Postgres 10 are
-- read-only and only available to AFTER triggers.
-- EXECUTE PROCEDURE is kept for 9.5 compatibility (Postgres 11+ also accepts EXECUTE FUNCTION).
CREATE TRIGGER trigger_audit BEFORE INSERT OR UPDATE ON my_schema_name.my_table_name FOR EACH ROW EXECUTE PROCEDURE set_audit_fields();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment