-
-
Save crossworth/8ef7e50e681ccf3f9d2ac79691d9c9d1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package app | |
// AuditLog is the payload that the audit hook builds for each audited
// mutation and passes on to RegisterAuditLog.
type AuditLog struct {
	// ID is the ID of the mutated record. It is nil, and omitted from the
	// JSON encoding, when the mutation did not expose an ID.
	ID *int `json:"id,omitempty"`

	// Changes maps field names to the values written by the mutation. It is
	// omitted from the JSON encoding when empty.
	Changes map[string]any `json:"changes,omitempty"`
}
func getID(m ent.Mutation) *int { | |
if idr, ok := m.(interface{ ID() (int, bool) }); ok { | |
if id, found := idr.ID(); found { | |
return &id | |
} | |
} | |
return nil | |
} | |
// AuditRecorder is implemented by types that can record audit events.
//
// RegisterAuditLog records one event named event, attributed to userID, with
// data as its payload. A non-nil error means the event was not recorded.
// NOTE(review): how and where events are stored depends on the implementation,
// which is not visible here.
type AuditRecorder interface {
	RegisterAuditLog(ctx context.Context, userID int, event string, data any) error
}
// skipTables is a list of tables that we should avoid creating audit events. | |
var skipTables = []string{ | |
ent.TypeAudit, | |
ent.TypeUserSession, | |
ent.TypeStatefulAction, | |
ent.TypeIdempotency, | |
ent.TypeChannelSubscriptionMessage, | |
ent.TypeFingerprint, | |
} | |
// auditHook is the audit hook used for most of the events. | |
func (tt *TT) auditHook() ent.Hook { | |
shouldSkip := func(name string) bool { | |
for _, n := range skipTables { | |
if strings.EqualFold(n, name) { | |
return true | |
} | |
} | |
return false | |
} | |
return func(next ent.Mutator) ent.Mutator { | |
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { | |
if shouldSkip(m.Type()) { | |
return next.Mutate(ctx, m) | |
} | |
recordID := getID(m) | |
// we run the mutation first. | |
val, err := next.Mutate(ctx, m) | |
if err != nil { | |
return nil, err | |
} | |
// maybe we are creating a record, check | |
// if we have a record id again | |
if recordID == nil { | |
recordID = getID(m) | |
} | |
var userID *int | |
uID, found := httputil.UserIDFromContext(ctx) | |
if found { | |
userID = &uID | |
} | |
al := AuditLog{ | |
ID: recordID, | |
Changes: map[string]any{}, | |
} | |
if !m.Op().Is(ent.OpDelete | ent.OpDeleteOne) { | |
for _, f := range m.Fields() { | |
v, _ := m.Field(f) | |
al.Changes[f] = v | |
} | |
} | |
event := fmt.Sprintf("%s%s", m.Type(), m.Op().String()) | |
err = tt.RegisterAuditLog(ctx, userID, event, al) | |
if err != nil { | |
// in case of an error, we log the information to leave an audit trail. | |
applog.LogCtx(ctx).Err(err). | |
Interface("audit_log", al). | |
Str("event", event). | |
Interface("user_id", userID). | |
Msg("error saving an audit log") | |
return nil, err | |
} | |
return val, nil | |
}) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment