Last active
September 9, 2023 08:50
-
-
Save stellingsimon/f09005f666b01f2560d73f9e603b2a97 to your computer and use it in GitHub Desktop.
Recording row-level auditing information from an application's user session context in a Postgres DB (9.5+)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package example.postgres.audit; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import java.sql.CallableStatement; | |
import java.sql.Connection; | |
import java.sql.PreparedStatement; | |
import java.sql.SQLException; | |
import java.sql.Savepoint; | |
import java.sql.Statement; | |
import java.sql.Timestamp; | |
import java.time.Instant; | |
import java.util.Objects; | |
import java.util.UUID; | |
import static java.lang.String.format; | |
/** | |
* Whenever the DB is manipulated through this connection, Postgres receives the context variables used in the 'trigger_audit' triggers beforehand. | |
*/ | |
public class AuditingConnection extends DelegatingConnection { | |
private static final Logger LOGGER = LoggerFactory.getLogger(AuditingConnection.class); | |
// TODO: implement the static method `AuditContextProvider.currentUserId()`, | |
// returning whatever value you'd like persisted in the columns `inserted_by`/`updated_by` | |
// from a thread-local context | |
private static final AuditContextProvider AUDIT_CONTEXT_PROVIDER = new AuditContextProvider(); | |
/** | |
* Remembers the audit.AUDIT_USER set for the ongoing transaction, such that we only send SET LOCAL commands when necessary. | |
*/ | |
private UUID userCommunicatedToDB; | |
public AuditingConnection(Connection wrappedConnection) { | |
super(wrappedConnection); | |
} | |
private void setContext() { | |
UUID user = AUDIT_CONTEXT_PROVIDER.currentUserId(); | |
// The user reported by AUDIT_CONTEXT_PROVIDER may change within the same ongoing transaction, | |
// e.g. because the user session is initialized after the first DB statements. | |
// Detect changes and propagate the information to the DB iff necessary: | |
if (!Objects.equals(user, userCommunicatedToDB)) { | |
Timestamp nowTimestamp = Timestamp.from(Instant.now()); | |
if (isAutoCommit()) { | |
// If SET LOCAL is called outside a transaction, it has no effect besides Postgres issuing a warning. | |
// BUT WHY THE HELL ARE WE USING AUTO-COMMIT??? | |
LOGGER.warn("Did not set AUDIT context variables because connection was in AUTO-COMMIT mode."); | |
return; | |
} | |
try (Statement auditUserStmt = createStatementWithoutAuditing(); | |
Statement auditTimestampStmt = createStatementWithoutAuditing()) { | |
auditUserStmt.execute(format("SET LOCAL audit.AUDIT_USER = '%s'", user.toString())); | |
auditTimestampStmt.execute(format("SET LOCAL audit.AUDIT_TIMESTAMP = '%s'", nowTimestamp.toString())); | |
} catch (SQLException e) { | |
LOGGER.warn("failed to set AUDIT information", e); | |
} | |
this.userCommunicatedToDB = user; | |
} | |
} | |
private Statement createStatementWithoutAuditing() throws SQLException { | |
return super.createStatement(); | |
} | |
private void onTransactionEnd() { | |
// Values set with SET LOCAL are purged automatically by Postgres after a COMMIT/ROLLBACK [TO SAVEPOINT]. | |
this.userCommunicatedToDB = null; | |
} | |
private boolean isAutoCommit() { | |
try { | |
return getAutoCommit(); | |
} catch (SQLException e) { | |
LOGGER.warn("failed to read AUTO-COMMIT status", e); | |
return false; // so we attempt to set the audit info anyway... | |
} | |
} | |
/* | |
DELEGATION WITH AUDITING: | |
These methods either manipulate data in the DB or the transaction state of this Connection. | |
All other methods don't and therefore use direct delegation. | |
*/ | |
@Override | |
public Statement createStatement() throws SQLException { | |
setContext(); | |
return super.createStatement(); | |
} | |
@Override | |
public PreparedStatement prepareStatement(String sql) throws SQLException { | |
setContext(); | |
return super.prepareStatement(sql); | |
} | |
@Override | |
public CallableStatement prepareCall(String sql) throws SQLException { | |
setContext(); | |
return super.prepareCall(sql); | |
} | |
@Override | |
public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException { | |
setContext(); | |
return super.createStatement(resultSetType, resultSetConcurrency); | |
} | |
@Override | |
public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) throws SQLException { | |
setContext(); | |
return super.prepareStatement(sql, resultSetType, resultSetConcurrency); | |
} | |
@Override | |
public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException { | |
setContext(); | |
return super.prepareCall(sql, resultSetType, resultSetConcurrency); | |
} | |
@Override | |
public void setAutoCommit(boolean autoCommit) throws SQLException { | |
super.setAutoCommit(autoCommit); | |
onTransactionEnd(); | |
} | |
@Override | |
public void commit() throws SQLException { | |
super.commit(); | |
onTransactionEnd(); | |
} | |
@Override | |
public void rollback() throws SQLException { | |
super.rollback(); | |
onTransactionEnd(); | |
} | |
@Override | |
public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException { | |
setContext(); | |
return super.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability); | |
} | |
@Override | |
public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException { | |
setContext(); | |
return super.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability); | |
} | |
@Override | |
public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException { | |
setContext(); | |
return super.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability); | |
} | |
@Override | |
public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException { | |
setContext(); | |
return super.prepareStatement(sql, autoGeneratedKeys); | |
} | |
@Override | |
public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException { | |
setContext(); | |
return super.prepareStatement(sql, columnIndexes); | |
} | |
@Override | |
public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException { | |
setContext(); | |
return super.prepareStatement(sql, columnNames); | |
} | |
@Override | |
public void rollback(Savepoint savepoint) throws SQLException { | |
super.rollback(savepoint); | |
// if the SET LOCAL happened before the 'setSavepoint(savepoint)' call, we might have lost our context values: | |
onTransactionEnd(); | |
} | |
@Override | |
public void setTransactionIsolation(int level) throws SQLException { | |
super.setTransactionIsolation(level); | |
onTransactionEnd(); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package example.postgres.audit; | |
import javax.sql.DataSource; | |
import java.sql.Connection; | |
import java.sql.SQLException; | |
/** | |
* A DataSource wrapper that provides {@link AuditingConnection}s. | |
*/ | |
public class AuditingDataSource extends DelegatingDataSource { | |
private final static AuditContextProvider AUDIT_CONTEXT_PROVIDER = new AuditContextProvider(); | |
public AuditingDataSource(DataSource wrappedDataSource) { | |
super(wrappedDataSource); | |
} | |
@Override | |
public Connection getConnection() throws SQLException { | |
return new AuditingConnection(super.getConnection()); | |
} | |
@Override | |
public Connection getConnection(String username, String password) throws SQLException { | |
return new AuditingConnection(super.getConnection(username, password)); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package example.postgres.audit; | |
import java.sql.Array; | |
import java.sql.Blob; | |
import java.sql.CallableStatement; | |
import java.sql.Clob; | |
import java.sql.Connection; | |
import java.sql.DatabaseMetaData; | |
import java.sql.NClob; | |
import java.sql.PreparedStatement; | |
import java.sql.SQLClientInfoException; | |
import java.sql.SQLException; | |
import java.sql.SQLWarning; | |
import java.sql.SQLXML; | |
import java.sql.Savepoint; | |
import java.sql.Statement; | |
import java.sql.Struct; | |
import java.util.Map; | |
import java.util.Properties; | |
import java.util.concurrent.Executor; | |
/**
 * Template-style base class for wrappers of a JDBC Connection.
 *
 * <p>Every method forwards to the wrapped connection; subclasses override the methods they want to intercept.
 */
public abstract class DelegatingConnection implements Connection {

    private final Connection delegate;

    /**
     * @param delegate the connection all calls are forwarded to; must not be {@code null}
     * @throws NullPointerException if {@code delegate} is {@code null}
     */
    public DelegatingConnection(Connection delegate) {
        // fail fast instead of on the first forwarded call
        this.delegate = java.util.Objects.requireNonNull(delegate, "delegate");
    }

    @Override
    public Statement createStatement() throws SQLException {
        return delegate.createStatement();
    }

    @Override
    public PreparedStatement prepareStatement(String sql) throws SQLException {
        return delegate.prepareStatement(sql);
    }

    @Override
    public CallableStatement prepareCall(String sql) throws SQLException {
        return delegate.prepareCall(sql);
    }

    @Override
    public String nativeSQL(String sql) throws SQLException {
        return delegate.nativeSQL(sql);
    }

    @Override
    public void setAutoCommit(boolean autoCommit) throws SQLException {
        delegate.setAutoCommit(autoCommit);
    }

    @Override
    public boolean getAutoCommit() throws SQLException {
        return delegate.getAutoCommit();
    }

    @Override
    public void commit() throws SQLException {
        delegate.commit();
    }

    @Override
    public void rollback() throws SQLException {
        delegate.rollback();
    }

    @Override
    public void close() throws SQLException {
        delegate.close();
    }

    @Override
    public boolean isClosed() throws SQLException {
        return delegate.isClosed();
    }

    @Override
    public DatabaseMetaData getMetaData() throws SQLException {
        return delegate.getMetaData();
    }

    @Override
    public void setReadOnly(boolean readOnly) throws SQLException {
        delegate.setReadOnly(readOnly);
    }

    @Override
    public boolean isReadOnly() throws SQLException {
        return delegate.isReadOnly();
    }

    @Override
    public void setCatalog(String catalog) throws SQLException {
        delegate.setCatalog(catalog);
    }

    @Override
    public String getCatalog() throws SQLException {
        return delegate.getCatalog();
    }

    @Override
    public void setTransactionIsolation(int level) throws SQLException {
        delegate.setTransactionIsolation(level);
    }

    @Override
    public int getTransactionIsolation() throws SQLException {
        return delegate.getTransactionIsolation();
    }

    @Override
    public SQLWarning getWarnings() throws SQLException {
        return delegate.getWarnings();
    }

    @Override
    public void clearWarnings() throws SQLException {
        delegate.clearWarnings();
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException {
        return delegate.createStatement(resultSetType, resultSetConcurrency);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        return delegate.prepareStatement(sql, resultSetType, resultSetConcurrency);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        return delegate.prepareCall(sql, resultSetType, resultSetConcurrency);
    }

    @Override
    public Map<String, Class<?>> getTypeMap() throws SQLException {
        return delegate.getTypeMap();
    }

    @Override
    public void setTypeMap(Map<String, Class<?>> map) throws SQLException {
        delegate.setTypeMap(map);
    }

    @Override
    public void setHoldability(int holdability) throws SQLException {
        delegate.setHoldability(holdability);
    }

    @Override
    public int getHoldability() throws SQLException {
        return delegate.getHoldability();
    }

    @Override
    public Savepoint setSavepoint() throws SQLException {
        return delegate.setSavepoint();
    }

    @Override
    public Savepoint setSavepoint(String name) throws SQLException {
        return delegate.setSavepoint(name);
    }

    @Override
    public void rollback(Savepoint savepoint) throws SQLException {
        delegate.rollback(savepoint);
    }

    @Override
    public void releaseSavepoint(Savepoint savepoint) throws SQLException {
        delegate.releaseSavepoint(savepoint);
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return delegate.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return delegate.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return delegate.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {
        return delegate.prepareStatement(sql, autoGeneratedKeys);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException {
        return delegate.prepareStatement(sql, columnIndexes);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException {
        return delegate.prepareStatement(sql, columnNames);
    }

    @Override
    public Clob createClob() throws SQLException {
        return delegate.createClob();
    }

    @Override
    public Blob createBlob() throws SQLException {
        return delegate.createBlob();
    }

    @Override
    public NClob createNClob() throws SQLException {
        return delegate.createNClob();
    }

    @Override
    public SQLXML createSQLXML() throws SQLException {
        return delegate.createSQLXML();
    }

    @Override
    public boolean isValid(int timeout) throws SQLException {
        return delegate.isValid(timeout);
    }

    @Override
    public void setClientInfo(String name, String value) throws SQLClientInfoException {
        delegate.setClientInfo(name, value);
    }

    @Override
    public void setClientInfo(Properties properties) throws SQLClientInfoException {
        delegate.setClientInfo(properties);
    }

    @Override
    public String getClientInfo(String name) throws SQLException {
        return delegate.getClientInfo(name);
    }

    @Override
    public Properties getClientInfo() throws SQLException {
        return delegate.getClientInfo();
    }

    @Override
    public Array createArrayOf(String typeName, Object[] elements) throws SQLException {
        return delegate.createArrayOf(typeName, elements);
    }

    @Override
    public Struct createStruct(String typeName, Object[] attributes) throws SQLException {
        return delegate.createStruct(typeName, attributes);
    }

    @Override
    public void setSchema(String schema) throws SQLException {
        delegate.setSchema(schema);
    }

    @Override
    public String getSchema() throws SQLException {
        return delegate.getSchema();
    }

    @Override
    public void abort(Executor executor) throws SQLException {
        delegate.abort(executor);
    }

    @Override
    public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException {
        delegate.setNetworkTimeout(executor, milliseconds);
    }

    @Override
    public int getNetworkTimeout() throws SQLException {
        return delegate.getNetworkTimeout();
    }

    /**
     * Returns an object implementing {@code iface}. The delegate's answer is preferred (so existing callers keep
     * getting the driver's object); if the delegate cannot provide it but this wrapper implements {@code iface},
     * this wrapper is returned, as required by the {@link java.sql.Wrapper} contract.
     */
    @Override
    public <T> T unwrap(Class<T> iface) throws SQLException {
        if (!delegate.isWrapperFor(iface) && iface.isInstance(this)) {
            return iface.cast(this);
        }
        return delegate.unwrap(iface);
    }

    /** Returns {@code true} if the delegate wraps {@code iface} or this wrapper itself implements it. */
    @Override
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return delegate.isWrapperFor(iface) || iface.isInstance(this);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- This script assumes:
--  * `my_schema_name.my_table_name` to have 4 auditing columns named `inserted_by`, `inserted_at`, `updated_by`, `updated_at`
--  * the values `audit.AUDIT_USER` and `audit.AUDIT_TIMESTAMP` to be provided by the application (via SET LOCAL)
--    after starting a transaction. If a value is missing or cannot be converted, the function falls back to
--    UNKNOWN_USER_ID and now() respectively.
CREATE OR REPLACE FUNCTION set_audit_fields()
  RETURNS TRIGGER AS $$
DECLARE
  audit_user      UUID;
  audit_timestamp TIMESTAMP;
BEGIN
  -- Postgres 9.6 offers current_setting(..., [missing_ok]) which makes the exception handling obsolete.
  -- Postgres 9.5 requires the following however.
  -- The handlers also catch failures of the implicit cast on assignment (e.g. an empty or malformed value),
  -- so those fall back to the defaults as well.
  BEGIN
    audit_user := current_setting('audit.AUDIT_USER');
  EXCEPTION WHEN OTHERS THEN
    audit_user := '77777777-0000-7777-0000-777777777777'; -- UNKNOWN_USER_ID
  END;
  BEGIN
    audit_timestamp := current_setting('audit.AUDIT_TIMESTAMP');
  EXCEPTION WHEN OTHERS THEN
    audit_timestamp := now();
  END;

  IF TG_OP = 'INSERT'
  THEN
    NEW.inserted_by := audit_user;
    NEW.inserted_at := audit_timestamp;
    NEW.updated_by := NULL;
    NEW.updated_at := NULL;
  ELSE
    NEW.inserted_by := OLD.inserted_by; -- do not allow statement to (accidentally) set INSERTED_BY
    NEW.inserted_at := OLD.inserted_at; -- do not allow statement to (accidentally) set INSERTED_AT
    NEW.updated_by := audit_user;
    NEW.updated_at := audit_timestamp;
  END IF;
  RETURN NEW;
END;
$$ LANGUAGE plpgsql;

-- note: this must be a row-level BEFORE trigger, because only those can modify NEW. Statement-level triggers
-- (even with the transition tables added in Postgres 10, which are AFTER-only and read-only) cannot change the
-- rows being written.
-- EXECUTE PROCEDURE keeps 9.5 compatibility; Postgres 11+ also accepts EXECUTE FUNCTION.
CREATE TRIGGER trigger_audit BEFORE INSERT OR UPDATE ON my_schema_name.my_table_name FOR EACH ROW EXECUTE PROCEDURE set_audit_fields();
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment