Skip to content

Instantly share code, notes, and snippets.

@ikupenov
Created February 29, 2024 18:40
Show Gist options
  • Save ikupenov/10bc89d92d92eaba8cc5569013e04069 to your computer and use it in GitHub Desktop.
Save ikupenov/10bc89d92d92eaba8cc5569013e04069 to your computer and use it in GitHub Desktop.
Intercepting Drizzle db calls
import { and, type DBQueryConfig, eq, type SQLWrapper } from "drizzle-orm";
import { drizzle } from "drizzle-orm/postgres-js";
import postgres, { type Sql } from "postgres";
import { type AnyArgs } from "@/common";
import {
type DbClient,
type DbTable,
type DeleteArgs,
type DeleteFn,
type FindArgs,
type FindFn,
type FromArgs,
type FromFn,
type InsertArgs,
type JoinArgs,
type JoinFn,
type Owner,
type RlsDbClient,
type SetArgs,
type SetFn,
type UpdateArgs,
type ValuesArgs,
type ValuesFn,
type WhereArgs,
type WhereFn,
} from "./db-client.types";
import * as schema from "./schema";
/**
 * Opens a `postgres` (postgres.js) client for the given connection string.
 *
 * @param connectionString - Connection string handed to `postgres()` unchanged.
 * @returns The client returned by `postgres()`.
 */
export const connectDb = (connectionString: string) =>
  postgres(connectionString);
/**
 * Wraps a `postgres` client in a Drizzle client bound to the project schema.
 *
 * @param client - The `postgres` client to run queries through.
 * @returns The Drizzle client built by `drizzle(client, { schema })`.
 */
export const createDbClient = (client: Sql): DbClient =>
  drizzle(client, { schema });
/**
 * Creates a Drizzle client whose calls are transparently scoped to `owner`.
 *
 * The real client is wrapped in a recursive `Proxy`. Every property access is
 * recorded as a dotted path (`db.select.from.where`, ...), and every function
 * call is routed through `intercept`, which may replace it with an override:
 *
 * - `execute` is rejected with an error.
 * - `query.<table>.findMany` / `findFirst` get the owner filter AND-ed into
 *   `where` (SQL or callback form) and into the `where` of first-level `with`
 *   relations.
 * - `from`, `where` and the join methods get the owner filter AND-ed in.
 * - `insert(...).values(...)` stamps `ownerId` on every inserted row.
 * - `update(...).set(...)`, `delete(...)` and their `where` get the owner filter.
 *
 * Overrides only touch tables that expose an `ownerId` property; any other
 * table is passed through untouched.
 *
 * NOTE(review): `with` relations nested inside another `with` are not
 * rewritten here — confirm whether deeper relations need the filter too.
 *
 * @param client - The `postgres` client used to build the underlying Drizzle client.
 * @param owner - The owner whose `id` is compared against each table's `ownerId`.
 * @returns A proxied Drizzle client typed as `RlsDbClient`.
 */
export const createRlsDbClient = (client: Sql, owner: Owner): RlsDbClient => {
  const db = createDbClient(client);
  const ownerIdColumn = "ownerId" as const;

  // Looks a table up in the schema module by its export name.
  // eslint-disable-next-line import/namespace
  const getTable = (table: DbTable) => schema[table];

  // Builds the `ownerId = owner.id` condition for a table exposing `ownerId`.
  const getAccessPolicy = (
    table: {
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      [ownerIdColumn]: any;
    },
    owner: Owner,
  ) => eq(table[ownerIdColumn], owner.id);

  /** Where we are in the call chain when a function gets invoked. */
  interface InvokeContext {
    /** Property names accessed so far, starting with `"db"`. */
    path?: string[];
    /** Function calls made so far, with the arguments each received. */
    fnPath?: { name: string; args: unknown[] }[];
  }

  /** The call currently being intercepted. */
  interface InterceptFn {
    /** The original function, already bound to its un-proxied target. */
    invoke: (...args: unknown[]) => unknown;
    name: string;
    args: unknown[];
  }

  /** Replacement behaviour for calls whose path matches `pattern`. */
  interface OverrideFn {
    /** Dotted path pattern(s); `*` matches any run of characters, dots included. */
    pattern: string | string[];
    action: () => unknown;
  }

  type RelationalWhere = DBQueryConfig<"many">["where"];
  type RelationalWhereFn = Extract<
    RelationalWhere,
    (...args: never[]) => unknown
  >;

  // Combines the owner policy with a relational-query `where`, which may be
  // absent, an SQL expression, or a callback that produces one.
  const scopeRelationalWhere = (
    table: Parameters<typeof getAccessPolicy>[0],
    userWhere: RelationalWhere,
  ): RelationalWhere => {
    if (typeof userWhere === "function") {
      // A callback must be invoked to obtain its condition; passing the
      // function itself to `and` would not apply the caller's filter.
      return (...callbackArgs: Parameters<RelationalWhereFn>) =>
        and(getAccessPolicy(table, owner), userWhere(...callbackArgs));
    }
    return userWhere
      ? and(getAccessPolicy(table, owner), userWhere)
      : getAccessPolicy(table, owner);
  };

  const intercept = (fn: InterceptFn, context: InvokeContext = {}) => {
    const { path = [], fnPath = [] } = context;
    const pathAsString = path.join(".");

    // Tests the current path against a pattern: dots are literal, `*` is `.*`.
    const matchPath = (pattern: string) => {
      return new RegExp(
        `^${pattern.replace(/\./g, "\\.").replace(/\*/g, ".*")}$`,
      ).test(pathAsString);
    };

    // Returns the arguments of the most recent call named `name` in the chain.
    // Written as a plain loop so it does not depend on ES2023 `findLast`.
    const lastCallArgs = (name: string): unknown[] => {
      for (let index = fnPath.length - 1; index >= 0; index -= 1) {
        const call = fnPath[index];
        if (call !== undefined && call.name === name) {
          return call.args;
        }
      }
      throw new Error(
        `No '${name}' call found before '${pathAsString}' in rls DB`,
      );
    };

    const overrides: OverrideFn[] = [
      {
        pattern: ["db.execute", "db.*.execute"],
        action: () => {
          throw new Error("'execute' in rls DB is not allowed");
        },
      },
      {
        pattern: [
          "db.query.findMany",
          "db.query.*.findMany",
          "db.query.findFirst",
          "db.query.*.findFirst",
        ],
        action: () => {
          const findFn = fn.invoke as FindFn;
          const findArgs = fn.args as FindArgs;
          // The table name is the path segment right after "query".
          const tableIndex = path.findIndex((x) => x === "query") + 1;
          const tableName = path[tableIndex]! as keyof typeof db.query;
          const table = getTable(tableName as DbTable);
          if (ownerIdColumn in table) {
            const [userConfig] = findArgs;
            let config = {
              ...userConfig,
              where: scopeRelationalWhere(
                table,
                userConfig?.where as RelationalWhere,
              ),
            } as NonNullable<FindArgs[0]>;
            const relations = config.with;
            if (relations) {
              config = {
                ...config,
                with: (
                  Object.keys(relations) as (keyof typeof relations)[]
                ).reduce<DBQueryConfig["with"]>((acc, key) => {
                  const value = relations[key] as
                    | true
                    | null
                    | DBQueryConfig<"many">;
                  if (value === true) {
                    // `with: { rel: true }` becomes a config carrying only the policy.
                    return {
                      ...acc,
                      [key]: {
                        where: (table) =>
                          ownerIdColumn in table
                            ? // eslint-disable-next-line @typescript-eslint/no-unsafe-argument, @typescript-eslint/no-explicit-any
                              getAccessPolicy(table as any, owner)
                            : undefined,
                      },
                    };
                  }
                  if (typeof value === "object" && value !== null) {
                    // Keep the relation's own options and AND the policy into its `where`.
                    return {
                      ...acc,
                      [key]: {
                        ...value,
                        where: (table, other) =>
                          ownerIdColumn in table
                            ? and(
                                // eslint-disable-next-line @typescript-eslint/no-unsafe-argument, @typescript-eslint/no-explicit-any
                                getAccessPolicy(table as any, owner),
                                typeof value.where === "function"
                                  ? value.where(table, other)
                                  : value.where,
                              )
                            : typeof value.where === "function"
                              ? value.where(table, other)
                              : value.where,
                      },
                    };
                  }
                  return { ...acc, [key]: value };
                  // eslint-disable-next-line @typescript-eslint/no-unsafe-argument, @typescript-eslint/prefer-reduce-type-parameter, @typescript-eslint/no-explicit-any
                }, relations as any),
              };
            }
            return findFn(...([config] as FindArgs));
          }
          return findFn(...findArgs);
        },
      },
      {
        pattern: "db.*.from",
        action: () => {
          const fromFn = fn.invoke as FromFn;
          const fromArgs = fn.args as FromArgs;
          const [table] = fromArgs;
          if (ownerIdColumn in table) {
            // Apply the policy immediately so a query without `.where()` is still filtered.
            return fromFn(...fromArgs).where(getAccessPolicy(table, owner));
          }
          return fromFn(...fromArgs);
        },
      },
      {
        pattern: ["db.*.from.where", "db.*.from.*.where"],
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = lastCallArgs("from") as FromArgs;
          if (ownerIdColumn in table) {
            // Re-apply the policy together with the caller's filter
            // (presumably this `where` replaces the one set in `from` — confirm).
            const [whereFilter] = whereArgs;
            return whereFn(
              and(getAccessPolicy(table, owner), whereFilter as SQLWrapper),
            );
          }
          return whereFn(...whereArgs);
        },
      },
      {
        pattern: [
          "db.*.leftJoin",
          "db.*.rightJoin",
          "db.*.innerJoin",
          "db.*.fullJoin",
        ],
        action: () => {
          const joinFn = fn.invoke as JoinFn;
          const joinArgs = fn.args as JoinArgs;
          const [table, joinOptions] = joinArgs;
          if (ownerIdColumn in table) {
            return joinFn(
              table,
              and(getAccessPolicy(table, owner), joinOptions as SQLWrapper),
            );
          }
          return joinFn(...joinArgs);
        },
      },
      {
        pattern: "db.insert.values",
        action: () => {
          const valuesFn = fn.invoke as ValuesFn;
          const valuesArgs = fn.args as ValuesArgs;
          const [table] = lastCallArgs("insert") as InsertArgs;
          if (ownerIdColumn in table) {
            let [valuesToInsert] = valuesArgs;
            if (!Array.isArray(valuesToInsert)) {
              valuesToInsert = [valuesToInsert];
            }
            // `ownerId` is spread last so it overrides any caller-supplied value.
            const valuesToInsertWithOwner = valuesToInsert.map((value) => ({
              ...value,
              ownerId: owner.id,
            }));
            return valuesFn(valuesToInsertWithOwner);
          }
          return valuesFn(...valuesArgs);
        },
      },
      {
        pattern: "db.update.set",
        action: () => {
          const setFn = fn.invoke as SetFn;
          const setArgs = fn.args as SetArgs;
          const [table] = lastCallArgs("update") as UpdateArgs;
          if (ownerIdColumn in table) {
            return setFn(...setArgs).where(getAccessPolicy(table, owner));
          }
          return setFn(...setArgs);
        },
      },
      {
        pattern: ["db.update.where", "db.update.*.where"],
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = lastCallArgs("update") as UpdateArgs;
          if (ownerIdColumn in table) {
            const [whereFilter] = whereArgs;
            return whereFn(
              and(getAccessPolicy(table, owner), whereFilter as SQLWrapper),
            );
          }
          return whereFn(...whereArgs);
        },
      },
      {
        pattern: "db.delete",
        action: () => {
          const deleteFn = fn.invoke as DeleteFn;
          const deleteArgs = fn.args as DeleteArgs;
          const [table] = deleteArgs;
          if (ownerIdColumn in table) {
            return deleteFn(...deleteArgs).where(getAccessPolicy(table, owner));
          }
          return deleteFn(...deleteArgs);
        },
      },
      {
        pattern: ["db.delete.where", "db.delete.*.where"],
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = lastCallArgs("delete") as DeleteArgs;
          if (ownerIdColumn in table) {
            const [whereOptions] = whereArgs;
            return whereFn(
              and(getAccessPolicy(table, owner), whereOptions as SQLWrapper),
            );
          }
          return whereFn(...whereArgs);
        },
      },
    ];

    // The first override whose pattern matches the current path wins.
    const fnOverride = overrides.find(({ pattern }) =>
      Array.isArray(pattern) ? pattern.some(matchPath) : matchPath(pattern),
    )?.action;

    return fnOverride ? fnOverride() : fn.invoke(...fn.args);
  };

  // Wraps `target` so that property reads extend `path`, function calls go
  // through `intercept`, and object results/properties are wrapped in turn.
  const createProxy = <T extends object>(
    target: T,
    context: InvokeContext = {},
  ): T => {
    const { path = [], fnPath = [] } = context;
    return new Proxy<T>(target, {
      get: (innerTarget, innerTargetProp, innerTargetReceiver) => {
        const propName = innerTargetProp.toString();
        const currentPath = path.concat(propName);
        const innerTargetPropValue = Reflect.get(
          innerTarget,
          innerTargetProp,
          innerTargetReceiver,
        );
        if (typeof innerTargetPropValue === "function") {
          return (...args: AnyArgs) => {
            const currentFnPath = [...fnPath, { name: propName, args }];
            const result = intercept(
              {
                // Bound to the real target, so calls the method makes on
                // `this` are not routed back through the proxy.
                invoke: innerTargetPropValue.bind(
                  innerTarget,
                ) as InterceptFn["invoke"],
                name: propName,
                args,
              },
              { path: currentPath, fnPath: currentFnPath },
            );
            if (
              typeof result === "object" &&
              result !== null &&
              !Array.isArray(result)
            ) {
              // Keep intercepting further calls chained onto the result.
              return createProxy(result, {
                path: currentPath,
                fnPath: currentFnPath,
              });
            }
            return result;
          };
        } else if (
          typeof innerTargetPropValue === "object" &&
          innerTargetPropValue !== null &&
          !Array.isArray(innerTargetPropValue)
        ) {
          // wrap nested objects in a proxy as well
          return createProxy(innerTargetPropValue, {
            path: currentPath,
            fnPath,
          });
        }
        return innerTargetPropValue;
      },
    });
  };

  return createProxy(db, { path: ["db"] });
};
import { type drizzle } from "drizzle-orm/postgres-js";
import type * as schema from "./schema";
// Type-only stand-in for a Drizzle client built over the project schema; it is
// declared (never assigned) so the aliases below can be derived with `typeof`.
declare const db: ReturnType<typeof drizzle<typeof schema>>;
/** Identifies the owner whose rows a client is scoped to; `id` may be `null`. */
export interface Owner {
id: string | null;
}
/** The Drizzle client type for the project schema. */
export type DbClient = typeof db;
/** The type of the schema module's exports. */
export type DbSchema = typeof schema;
/** Name of any export of the schema module. */
export type DbTable = keyof DbSchema;
/** `DbClient` with `execute` removed from its type. */
export type RlsDbClient = Omit<DbClient, "execute">;
/**
 * A relational-query finder for table `K`: accepts the arguments of either
 * `findFirst` or `findMany` and returns the result of either.
 */
export type FindFn<K extends keyof typeof db.query = keyof typeof db.query> = (
...args:
| Parameters<(typeof db.query)[K]["findFirst"]>
| Parameters<(typeof db.query)[K]["findMany"]>
) =>
| ReturnType<(typeof db.query)[K]["findFirst"]>
| ReturnType<(typeof db.query)[K]["findMany"]>;
/** Argument tuple accepted by `FindFn<K>`. */
export type FindArgs<K extends keyof typeof db.query = keyof typeof db.query> =
Parameters<FindFn<K>>;
// Select builder chain: db.select(...) -> .from(...) -> .where(...) / .leftJoin(...)
export type SelectFn = typeof db.select;
export type SelectArgs = Parameters<SelectFn>;
export type FromFn = ReturnType<SelectFn>["from"];
export type FromArgs = Parameters<FromFn>;
export type WhereFn = ReturnType<FromFn>["where"];
export type WhereArgs = Parameters<WhereFn>;
// Derived from `leftJoin` only.
export type JoinFn = ReturnType<FromFn>["leftJoin"];
export type JoinArgs = Parameters<JoinFn>;
// Insert builder chain: db.insert(...) -> .values(...)
export type InsertFn = typeof db.insert;
export type InsertArgs = Parameters<InsertFn>;
export type ValuesFn = ReturnType<InsertFn>["values"];
export type ValuesArgs = Parameters<ValuesFn>;
// Update builder chain: db.update(...) -> .set(...)
export type UpdateFn = typeof db.update;
export type UpdateArgs = Parameters<UpdateFn>;
export type SetFn = ReturnType<UpdateFn>["set"];
export type SetArgs = Parameters<SetFn>;
// Delete builder: db.delete(...)
export type DeleteFn = typeof db.delete;
export type DeleteArgs = Parameters<DeleteFn>;
@ikupenov
Copy link
Author

ikupenov commented Mar 8, 2024

You have to list the functions you want to intercept in the `overrides` array (line 82, `const overrides: OverrideFn[]`). I've left the overrides we currently use in place as an example: we filter every row by `ownerId` and insert `ownerId` automatically when `db.insert` is called. That way, everywhere you use `rlsDbClient`, the rows are filtered by the user making the request, without you having to think about it.

@stefanosandes
Copy link

What's the code for AnyArgs?

@cheft
Copy link

cheft commented Mar 18, 2024

ERR TypeError:
db.select(...).from is not a function
db.insert(...).values is not a function

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment