Last active June 11, 2024 22:51
Gist: cayter/49d5c256a885d90c399ca6c1eca19f51
Drizzle ORM Type-Safe Repository With PgTable
// repository.ts
import { startSpan } from "@sentry/remix";
import type { StartSpanOptions } from "@sentry/types";
import {
  type AnyColumn,
  type AnyTable,
  type BuildQueryResult,
  type DBQueryConfig,
  type DrizzleTypeError,
  type Equal,
  type ExtractTablesWithRelations,
  type GetColumnData,
  type InferSelectModel,
  type KnownKeysOnly,
  type Relation,
  type SQL,
  asc,
  createTableRelationsHelpers,
  desc,
  getOperators,
  getTableColumns,
  sql,
} from "drizzle-orm";
import type {
  IndexColumn,
  PgInsertBase,
  PgInsertOnConflictDoUpdateConfig,
  PgInsertValue,
  PgTableWithColumns,
  PgTransaction,
  PgUpdateSetSource,
  SelectedFieldsFlat,
} from "drizzle-orm/pg-core";
import {
  type PostgresJsDatabase,
  type PostgresJsQueryResultHKT,
  drizzle,
} from "drizzle-orm/postgres-js";
import { camelCase } from "moderndash";
import postgres, { type Sql } from "postgres";

/**
 * Returns the object's own enumerable string keys, typed as `keyof O`.
 * The result is a type assertion: `Object.keys` itself returns `string[]`.
 * @template O - The object type from which keys are extracted.
 * @param {O} obj - The object whose keys are to be retrieved.
 * @returns {(keyof O)[]} An array of keys of the object `O`.
 */
export function objectKeys<O extends object>(obj: O): (keyof O)[] {
  return Object.keys(obj) as (keyof O)[];
}

/**
 * The database instance.
 */
export interface AppDb<T extends Record<string, unknown>> {
  /**
   * The database's Drizzle ORM instance.
   */
  orm: PostgresJsDatabase<T> & {
    session: {
      client: Sql<T>;
    };
  };

  /**
   * The database's schemas.
   */
  schema: T;

  /**
   * The database's URL.
   */
  url: string;
}

/**
 * The database error class.
 */
export class DatabaseError extends Error {
  fieldErrors!: Record<string, string[] | undefined>;

  constructor(
    message: string,
    fieldErrors?: Record<string, string[] | undefined>,
  ) {
    super(message);

    if (fieldErrors) {
      this.fieldErrors = fieldErrors;
    }
  }
}

/**
 * The options for finding the first record.
 */
export type FindFirstOpts<T extends Record<string, unknown>> = KnownKeysOnly<
  T,
  FindFirstQueryConfig<T, keyof ExtractTablesWithRelations<T>>
>;

/**
 * The options for finding many records.
 */
export type FindManyOpts<T extends Record<string, unknown>> = KnownKeysOnly<
  T,
  FindManyQueryConfig<T, keyof ExtractTablesWithRelations<T>>
>;

/**
 * The options for paginating the records by offset.
 */
export type PaginateByOffsetOpts<T extends Record<string, unknown>> =
  KnownKeysOnly<
    T,
    PaginateByOffsetQueryConfig<T, keyof ExtractTablesWithRelations<T>>
  >;

/**
 * The find first query builder config.
 */
export type FindFirstQueryConfig<
  T extends Record<string, unknown>,
  U extends keyof ExtractTablesWithRelations<T>,
> = Omit<
  DBQueryConfig<
    "many",
    true,
    ExtractTablesWithRelations<T>,
    ExtractTablesWithRelations<T>[U]
  >,
  "limit"
> & {
  tx?: Transaction<T>;
};

/**
 * The find many query builder config.
 */
export type FindManyQueryConfig<
  T extends Record<string, unknown>,
  U extends keyof ExtractTablesWithRelations<T>,
> = DBQueryConfig<
  "many",
  true,
  ExtractTablesWithRelations<T>,
  ExtractTablesWithRelations<T>[U]
> & {
  tx?: Transaction<T>;
};

/**
 * The paginate by offset query builder config.
 */
export type PaginateByOffsetQueryConfig<
  T extends Record<string, unknown>,
  U extends keyof ExtractTablesWithRelations<T>,
> = Omit<
  DBQueryConfig<
    "many",
    true,
    ExtractTablesWithRelations<T>,
    ExtractTablesWithRelations<T>[U]
  > & {
    page?: number;
    perPage?: number;
    sortBy?: keyof ExtractTablesWithRelations<T>[U]["columns"];
    sortDirection?: "asc" | "desc";
    tx?: Transaction<T>;
  },
  "limit" | "offset"
>;

/**
 * The generic transaction session.
 */
export type Transaction<T extends Record<string, unknown>> = PgTransaction<
  PostgresJsQueryResultHKT,
  T,
  ExtractTablesWithRelations<T>
>;

type SimplifyShallow<T> = {
  [K in keyof T]: T[K];
} & {};

type SelectResultField<
  T,
  TDeep extends boolean = true,
> = T extends DrizzleTypeError<any>
  ? T
  : T extends AnyTable<any>
    ? Equal<TDeep, true> extends true
      ? SelectResultField<T["_"]["columns"], false>
      : never
    : T extends AnyColumn
      ? GetColumnData<T>
      : T extends SQL | SQL.Aliased
        ? T["_"]["type"]
        : T extends Record<string, any>
          ? SelectResultFields<T, true>
          : never;

type SelectResultFields<
  TSelectedFields,
  TDeep extends boolean = true,
> = SimplifyShallow<{
  [Key in keyof TSelectedFields & string]: SelectResultField<
    TSelectedFields[Key],
    TDeep
  >;
}>;

export abstract class Repository<
  T extends Record<string, unknown>,
  U extends PgTableWithColumns<any>,
  V extends keyof ExtractTablesWithRelations<T>,
> {
  /**
   * The DB instance.
   */
  db: AppDb<T>;

  /**
   * The DB table.
   */
  table: U;

  /**
   * The DB model name, i.e. the table's key in the relational query API
   * (`db.query[modelName]`).
   */
  #modelName!: keyof ExtractTablesWithRelations<T>;

  /**
   * The DB table relations.
   */
  #relations: Record<string, Relation>;

  constructor(db: AppDb<T>, table: U) {
    this.db = db;
    this.table = table;

    // Drizzle stores the SQL table name under the `drizzle:Name` symbol.
    for (const k of Object.getOwnPropertySymbols(table)) {
      if (k.toString() === "Symbol(drizzle:Name)") {
        // Strip graphile-worker's `_private_` table prefix.
        this.#modelName = camelCase(
          table[k as unknown as string].replace("_private_", ""),
        ) as keyof ExtractTablesWithRelations<T>;
      }
    }

    // Relations are looked up by naming convention: `<modelName>Relations`.
    // @ts-expect-error
    this.#relations = this.db.schema[`${this.#modelName}Relations`].config(
      createTableRelationsHelpers(this.table),
    );
  }

  /**
   * The table's column keys (TypeScript property names, not SQL names).
   */
  get columns() {
    return objectKeys(getTableColumns(this.table));
  }

  /**
   * Asynchronously invalidates the storage cache/object for the provided rows.
   * @param {Array<InferSelectModel<U>>} rows The rows to be checked for
   * cache/object invalidation.
   * @returns {Promise<void>} A promise that resolves once all relevant
   * cache/object entries have been invalidated.
   * @private
   */
  async #cleanUpStorage(rows: Array<InferSelectModel<U>>) {
    const promises: Promise<any>[] = [];

    for (const row of rows) {
      for (const value of Object.values(row)) {
        // A column value shaped like a stored file: { key, name, url }.
        if (
          value &&
          typeof value === "object" &&
          "key" in value &&
          "name" in value &&
          "url" in value
        ) {
          // promises.push(cache.expire(...));
          // promises.push(storage.delete(value.key));
        }
      }
    }

    await Promise.all(promises);
  }

  /**
   * Get the Sentry tracing span attributes.
   * @returns {Partial<StartSpanOptions>} The Sentry options to start a tracing span.
   */
  #getSentryAttributes(): Partial<StartSpanOptions> {
    return {
      attributes: {
        "db.system": "postgresql",
      },
      op: "db.query",
    };
  }

  /**
   * Convert the unknown error to the DatabaseError class on a best-effort basis.
   * @param {unknown} err The unknown error.
   * @returns {unknown | DatabaseError}
   */
  #toDatabaseError(err: unknown) {
    /**
     * Refer to the PostgreSQL error codes (SQLSTATE) list; 23505 is
     * `unique_violation`.
     */
    if (err instanceof postgres.PostgresError) {
      switch (err.code) {
        case "23505": {
          // err.detail looks like: Key (email)=(ada@example.com) already exists.
          const keyRegex = /Key \(([^=]+)\)=/;
          const valueRegex = /=\(([^)]+)\)/;
          const keyMatch = err.detail?.match(keyRegex);
          const valueMatch = err.detail?.match(valueRegex);

          if (keyMatch && valueMatch) {
            const keys = keyMatch[1].split(", ").map((key) => key.trim());
            const values = valueMatch[1]
              .split(", ")
              .map((value) => value.trim());
            const fieldErrors: Record<string, string[]> = {};
            const isComposite = keys.length > 1;

            // TODO: Finish up composite key error handling.
            // keys.forEach((key, _idx) => {
            //   fieldErrors[
            //     `${pluralize.singular(this.#tableName)}.${camel(key)}`
            //   ] = [
            //     isComposite
            //       ? "app:errors.dbUniqueCompositeConstraint"
            //       : "app:errors.dbUniqueConstraint",
            //   ] satisfies I18nKeys[];
            // });

            return new DatabaseError(err.message, fieldErrors);
          }
          break;
        }
      }
    }

    return err;
  }

  /**
   * A hook that is invoked right before a row is inserted.
   * @param {PgInsertValue<U>} row
   * @returns {Promise<void>}
   */
  async beforeCreate(row: PgInsertValue<U>) {}

  /**
   * A hook that is invoked after a row is inserted and right before returning
   * to the caller.
   * @param {InferSelectModel<U>} row
   * @returns {Promise<void>}
   */
  async afterCreate(row: InferSelectModel<U>) {}

  /**
   * A hook that is invoked after a row is deleted and right before returning
   * to the caller.
   * @param {InferSelectModel<U>} row
   * @returns {Promise<void>}
   */
  async afterDelete(row: InferSelectModel<U>) {}

  /**
   * A hook that is invoked right before returning to the caller, for:
   * - findFirst()
   * - findMany()
   * - paginateByOffset()
   *
   * Common use cases:
   * - post-process an S3 storage path into a private S3 URL and cache it
   * @param {InferSelectModel<U>} row
   * @returns {Promise<void>}
   */
  async afterFind(row: InferSelectModel<U>) {}

  /**
   * A hook that is invoked right before a row is updated.
   * @param {PgUpdateSetSource<U>} row
   * @returns {Promise<void>}
   */
  async beforeUpdate(row: PgUpdateSetSource<U>) {}

  /**
   * A hook that is invoked after a row is updated and right before returning
   * to the caller.
   * @param {InferSelectModel<U>} row
   * @returns {Promise<void>}
   */
  async afterUpdate(row: InferSelectModel<U>) {}

  /**
   * Insert 1 value into the database.
   * @param {PgInsertValue<U>} value The value to insert.
   * @param {object} [opts] The insert options.
   * @param {object} [opts.columns] The fields to return.
   * @param {object} [opts.onConflictDoNothing] Skip the insert on a conflict.
   * @param {object} [opts.onConflictDoUpdate] Update the conflicting row instead.
   * @param {Transaction<T>} [opts.tx] The SQL transaction.
   * @returns The inserted row (or its selected fields), or `null` if no row
   * was inserted.
   */
  create<TSelectedFields extends SelectedFieldsFlat>(
    value: PgInsertValue<U>,
    opts: {
      columns: TSelectedFields;
      onConflictDoNothing?: {
        target?: IndexColumn | IndexColumn[];
      };
      onConflictDoUpdate?: PgInsertOnConflictDoUpdateConfig<
        PgInsertBase<U, PostgresJsQueryResultHKT, undefined, false, never>
      >;
      tx?: Transaction<T>;
    },
  ): Promise<SelectResultFields<TSelectedFields> | null>;
  create(
    value: PgInsertValue<U>,
    opts?: {
      onConflictDoNothing?: {
        target?: IndexColumn | IndexColumn[];
      };
      onConflictDoUpdate?: PgInsertOnConflictDoUpdateConfig<
        PgInsertBase<U, PostgresJsQueryResultHKT, undefined, false, never>
      >;
      tx?: Transaction<T>;
    },
  ): Promise<InferSelectModel<U> | null>;
  async create<TSelectedFields extends SelectedFieldsFlat | undefined>(
    value: PgInsertValue<U>,
    opts?: {
      columns?: TSelectedFields;
      onConflictDoNothing?: {
        target?: IndexColumn | IndexColumn[];
      };
      onConflictDoUpdate?: PgInsertOnConflictDoUpdateConfig<
        PgInsertBase<U, PostgresJsQueryResultHKT, undefined, false, never>
      >;
      tx?: Transaction<T>;
    },
  ) {
    try {
      await this.beforeCreate(value);

      const qb = (opts?.tx || this.db.orm).insert(this.table).values(value);
      if (opts?.onConflictDoUpdate) {
        // Also bump `updatedAt` on upsert when the table has that column.
        qb.onConflictDoUpdate({
          ...opts.onConflictDoUpdate,
          ...(this.table.updatedAt
            ? {
                set: {
                  ...opts.onConflictDoUpdate.set,
                  updatedAt: sql`NOW()`,
                },
              }
            : {}),
        });
      } else if (opts?.onConflictDoNothing) {
        qb.onConflictDoNothing(opts.onConflictDoNothing);
      }

      // Drizzle's builder methods update the builder and return it, so the
      // return values can be ignored here.
      if (opts && "columns" in opts) {
        qb.returning(opts.columns as SelectedFieldsFlat);
      } else {
        qb.returning();
      }

      const rows = await startSpan(
        {
          ...this.#getSentryAttributes(),
          name: qb.toSQL().sql,
        },
        async () => qb,
      );

      if (rows.length < 1) {
        return null;
      }

      await this.afterCreate(rows[0]);
      return rows[0];
    } catch (err) {
      throw this.#toDatabaseError(err);
    }
  }

  /**
   * Insert many values into the database.
   * @param {PgInsertValue<U>[]} values The values to insert.
   * @param {object} [opts] The insert options.
   * @param {object} [opts.columns] The columns to return.
   * @param {Transaction<T>} [opts.tx] The SQL transaction.
   * @returns The inserted rows (or their selected columns).
   */
  createMany<TSelectedFields extends SelectedFieldsFlat>(
    values: PgInsertValue<U>[],
    opts: {
      columns: TSelectedFields;
      tx?: Transaction<T>;
    },
  ): Promise<SelectResultFields<TSelectedFields>[]>;
  createMany(
    values: PgInsertValue<U>[],
    opts?: { tx?: Transaction<T> },
  ): Promise<InferSelectModel<U>[]>;
  async createMany<TSelectedFields extends SelectedFieldsFlat>(
    values: PgInsertValue<U>[],
    opts?: {
      columns?: TSelectedFields;
      tx?: Transaction<T>;
    },
  ) {
    try {
      // Run every beforeCreate() hook (concurrently) before building the insert.
      const qb = (opts?.tx || this.db.orm).insert(this.table).values(
        await Promise.all(
          values.map(async (value) => {
            await this.beforeCreate(value);
            return value;
          }),
        ),
      );
      if (opts?.columns) {
        qb.returning(opts.columns);
      } else {
        qb.returning();
      }

      const rows = await startSpan(
        {
          ...this.#getSentryAttributes(),
          name: qb.toSQL().sql,
        },
        async () => qb,
      );

      if (rows.length < 1) {
        return [];
      }

      await Promise.all(rows.map((row) => this.afterCreate(row)));
      return rows;
    } catch (err) {
      throw this.#toDatabaseError(err);
    }
  }

  /**
   * Delete the data rows in the database based on the where condition.
   * @param {object} [opts] The delete options.
   * @param {object} [opts.columns] The columns to return.
   * @param {SQL<unknown>} [opts.where] The SQL where filter.
   * @param {Transaction<T>} [opts.tx] The SQL transaction.
   * @returns The deleted rows (or their selected columns).
   */
  delete<
    TSelectedFields extends SelectedFieldsFlat,
    QConfig extends FindManyQueryConfig<T, V>,
  >(opts: {
    columns: TSelectedFields;
    where?: QConfig["where"];
    tx?: Transaction<T>;
  }): Promise<SelectResultFields<TSelectedFields>[] | null>;
  delete<QConfig extends FindManyQueryConfig<T, V>>(opts?: {
    where?: QConfig["where"];
    tx?: Transaction<T>;
  }): Promise<InferSelectModel<U>[] | null>;
  async delete<
    TSelectedFields extends SelectedFieldsFlat,
    QConfig extends FindManyQueryConfig<T, V>,
  >(opts?: {
    columns?: TSelectedFields;
    where?: QConfig["where"];
    tx?: Transaction<T>;
  }) {
    // Accept either a SQL filter or a relational-query `where` callback.
    let where: SQL | undefined;
    if (opts?.where) {
      if ("queryChunks" in opts.where) {
        where = opts.where;
      } else if (typeof opts.where === "function") {
        where = opts.where(getTableColumns(this.table), getOperators());
      }
    }

    // Clean up stored files referenced by the rows before they are deleted.
    const deletingRows = await this.findMany({ where, tx: opts?.tx });
    if (deletingRows.length > 0) {
      await this.#cleanUpStorage(deletingRows);
    }

    const qb = (opts?.tx || this.db.orm).delete(this.table).where(where);
    if (opts?.columns) {
      qb.returning(opts.columns);
    } else {
      qb.returning();
    }

    const rows = await startSpan(
      {
        ...this.#getSentryAttributes(),
        name: qb.toSQL().sql,
      },
      async () => qb,
    );

    if (rows.length < 1) {
      return [];
    }

    await Promise.all(rows.map((row) => this.afterDelete(row)));
    return rows;
  }

  /**
   * Return the 1st record based on the config.
   * @param {FindFirstOpts<QConfig>} [opts] The find first options.
   * @param {object} [opts.columns] The columns to select.
   * @param {object} [opts.extras] The extra columns to return.
   * @param {object} [opts.offset] The offset of the returned rows.
   * @param {object} [opts.orderBy] The sorting order.
   * @param {SQL<unknown>} [opts.where] The where filter.
   * @param {object} [opts.with] The relations to include in the query.
   * @param {Transaction<T>} [opts.tx] The SQL transaction.
   * @returns The first matching row, or `null`.
   */
  async findFirst<QConfig extends FindFirstQueryConfig<T, V>>(
    opts?: FindFirstOpts<QConfig>,
  ) {
    const { tx, ...config } = opts || {};
    const qb = tx || this.db.orm;
    const row = await startSpan(
      {
        ...this.#getSentryAttributes(),
        // @ts-expect-error
        name: qb.query[this.#modelName].findFirst(config || {}).toSQL().sql,
      },
      // @ts-expect-error
      async () => qb.query[this.#modelName].findFirst(config || {}),
    );

    if (!row) {
      return null;
    }

    await this.afterFind(row);
    return row as BuildQueryResult<
      ExtractTablesWithRelations<T>,
      ExtractTablesWithRelations<T>[V],
      QConfig
    >;
  }

  /**
   * Return all the records based on the config.
   * @param {FindManyOpts<QConfig>} [opts] The find many options.
   * @param {object} [opts.columns] The columns to select.
   * @param {object} [opts.extras] The extra columns to return.
   * @param {object} [opts.limit] The maximum number of returned rows.
   * @param {object} [opts.offset] The offset of the returned rows.
   * @param {object} [opts.orderBy] The sorting order.
   * @param {SQL<unknown>} [opts.where] The where filter.
   * @param {object} [opts.with] The relations to include in the query.
   * @param {Transaction<T>} [opts.tx] The SQL transaction.
   * @returns The matching rows, or an empty array.
   */
  async findMany<QConfig extends FindManyQueryConfig<T, V>>(
    opts?: FindManyOpts<QConfig>,
  ) {
    const { tx, ...config } = opts || {};
    const qb = tx || this.db.orm;
    const rows = await startSpan(
      {
        ...this.#getSentryAttributes(),
        // @ts-expect-error
        name: qb.query[this.#modelName].findMany(config || {}).toSQL().sql,
      },
      // @ts-expect-error
      async () => qb.query[this.#modelName].findMany(config || {}),
    );

    if (!rows || rows.length < 1) {
      return [];
    }

    await Promise.all(
      rows.map((row: InferSelectModel<U>) => this.afterFind(row)),
    );

    return rows as unknown as BuildQueryResult<
      ExtractTablesWithRelations<T>,
      ExtractTablesWithRelations<T>[V],
      QConfig
    >[];
  }

  /**
   * Return the paginated records based on the config.
   * @param {PaginateByOffsetOpts<QConfig>} [opts] The find many options with pagination.
   * @param {object} [opts.columns] The columns to select.
   * @param {object} [opts.extras] The extra columns to return.
   * @param {SQL<unknown>} [opts.orderBy] The order by SQL. Overridden by `sortBy` when it is set.
   * @param {object} [opts.sortBy] The sorting column.
   * @param {"asc" | "desc"} [opts.sortDirection="asc"] The sorting direction.
   * @param {SQL<unknown>} [opts.where] The where filter.
   * @param {object} [opts.with] The relations to include in the query.
   * @param {number} [opts.page=1] The current page (1-based).
   * @param {number} [opts.perPage=10] The current page size.
   * @param {Transaction<T>} [opts.tx] The SQL transaction.
   * @returns The page rows plus `next`, `previous`, and `totalPages`.
   */
  async paginateByOffset<QConfig extends PaginateByOffsetQueryConfig<T, V>>(
    opts?: PaginateByOffsetOpts<QConfig>,
  ) {
    const {
      page = 1,
      perPage = 10,
      sortBy,
      sortDirection = "asc",
      ...config
    } = opts || {
      columns: undefined,
      extras: undefined,
      orderBy: undefined,
      tx: undefined,
      where: undefined,
      with: undefined,
    };
    const qb = config.tx || this.db.orm;

    // Resolve the where filter (SQL or callback) for the count query.
    let countWhere: SQL<unknown> | undefined;
    if (config.where) {
      if ("queryChunks" in config.where) {
        countWhere = config.where;
      } else if (typeof config.where === "function") {
        countWhere = config.where(getTableColumns(this.table), getOperators());
      }
    }

    if (sortBy) {
      const column =
        this.table[sortBy as keyof (typeof this.table)["_"]["columns"]];
      config.orderBy = sortDirection === "asc" ? [asc(column)] : [desc(column)];
    }

    // Fetch one extra row to detect a next page, and count matching rows in
    // parallel.
    const [rows, totals] = await startSpan(
      {
        ...this.#getSentryAttributes(),
        name: qb
          .select({ count: sql<number>`count(*)`.mapWith(Number) })
          .from(this.table)
          .where(countWhere)
          .toSQL().sql,
      },
      async () =>
        Promise.all([
          this.findMany({
            ...config,
            offset: (page - 1) * perPage,
            limit: perPage + 1,
          }),
          qb
            .select({ count: sql<number>`count(*)`.mapWith(Number) })
            .from(this.table)
            .where(countWhere),
        ]),
    );

    const totalRows = totals?.[0]?.count ?? 0;
    const next = rows.length > perPage;
    if (next) {
      rows.pop();
    }

    return {
      rows,
      next,
      previous: page > 1,
      totalPages: Math.ceil(totalRows / perPage),
    };
  }

  /**
   * Update the data rows in the database based on the where condition.
   * @param {PgUpdateSetSource<U>} value The values to update to.
   * @param {object} [opts] The update options.
   * @param {object} [opts.columns] The fields to return.
   * @param {SQL<unknown>} [opts.where] The SQL where filter.
   * @param {Transaction<T>} [opts.tx] The SQL transaction.
   * @returns The updated rows (or their selected fields).
   */
  update<
    TSelectedFields extends SelectedFieldsFlat,
    QConfig extends FindManyQueryConfig<T, V>,
  >(
    value: PgUpdateSetSource<U>,
    opts: {
      columns: TSelectedFields;
      where?: QConfig["where"];
      tx?: Transaction<T>;
    },
  ): Promise<SelectResultFields<TSelectedFields>[]>;
  update<QConfig extends FindManyQueryConfig<T, V>>(
    value: PgUpdateSetSource<U>,
    opts?: {
      where?: QConfig["where"];
      tx?: Transaction<T>;
    },
  ): Promise<InferSelectModel<U>[]>;
  async update<
    TSelectedFields extends SelectedFieldsFlat,
    QConfig extends FindManyQueryConfig<T, V>,
  >(
    value: PgUpdateSetSource<U>,
    opts?: {
      columns?: TSelectedFields;
      where?: QConfig["where"];
      tx?: Transaction<T>;
    },
  ) {
    try {
      let where: SQL | undefined;
      if (opts?.where) {
        if ("queryChunks" in opts.where) {
          where = opts.where;
        } else if (typeof opts.where === "function") {
          where = opts.where(getTableColumns(this.table), getOperators());
        }
      }

      await this.beforeUpdate(value);

      // Also bump `updatedAt` when the table has that column.
      const qb = (opts?.tx || this.db.orm)
        .update(this.table)
        .set({
          ...value,
          ...(this.table.updatedAt ? { updatedAt: sql`NOW()` } : {}),
        })
        .where(where);
      if (opts?.columns) {
        qb.returning(opts.columns);
      } else {
        qb.returning();
      }

      const rows = await startSpan(
        {
          ...this.#getSentryAttributes(),
          name: qb.toSQL().sql,
        },
        async () => qb,
      );

      if (rows.length < 1) {
        return [];
      }

      await Promise.all(rows.map((row) => this.afterUpdate(row)));
      return rows;
    } catch (err) {
      throw this.#toDatabaseError(err);
    }
  }
}
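
`paginateByOffset()` fetches `perPage + 1` rows starting at `(page - 1) * perPage`, treats the extra row as the signal that a next page exists, and derives `totalPages` from a separate `count(*)` query. The same arithmetic, pulled out of Drizzle into two hypothetical pure helpers (`pageWindow` and `toPage` are not part of the gist) so it can be checked in isolation, might look like this sketch:

```typescript
// Sketch of the offset-pagination bookkeeping in paginateByOffset().
// `PageResult`, `pageWindow`, and `toPage` are hypothetical names.

interface PageResult<Row> {
  rows: Row[];
  next: boolean;
  previous: boolean;
  totalPages: number;
}

// Offset/limit for a 1-based page. The limit over-fetches by one row so the
// presence of a next page can be detected from the same query.
function pageWindow(page: number, perPage: number) {
  return { offset: (page - 1) * perPage, limit: perPage + 1 };
}

// Drop the over-fetched row (if any) and derive the navigation fields.
function toPage<Row>(
  fetched: Row[],
  totalRows: number,
  page: number,
  perPage: number,
): PageResult<Row> {
  const next = fetched.length > perPage;
  return {
    // slice() copies, so the caller's array is left untouched.
    rows: fetched.slice(0, perPage),
    next,
    previous: page > 1,
    totalPages: Math.ceil(totalRows / perPage),
  };
}

// Usage: 25 rows in total, 10 per page, requesting page 3.
const table = Array.from({ length: 25 }, (_, i) => i + 1);
const { offset, limit } = pageWindow(3, 10);
const page3 = toPage(table.slice(offset, offset + limit), table.length, 3, 10);
console.log(page3.rows, page3.next, page3.previous, page3.totalPages);
```

Over-fetching by one row answers "is there a next page?" without extra work; the `count(*)` query is still needed for `totalPages`, which comes out as 0 for an empty table.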
// Example: a concrete repository built on the base class.
import type {
  PrimarySchema,
  User,
} from "./databases/primary/schemas/index";
import { Repository } from "./repository";

export class UserRepository extends Repository<
  PrimarySchema,
  // Hypothetical type arguments: the users table type and its relational
  // query key.
  PrimarySchema["users"],
  "users"
> {
  async beforeCreate(row: User) {}
  async afterCreate(row: User) {}
  async afterDelete(row: User) {}
  async afterFind(row: User) {}
  async beforeUpdate(row: User) {}
  async afterUpdate(row: User) {}
}
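
A users table usually carries a unique constraint (an email column, say), so a duplicate insert through `UserRepository.create()` surfaces as PostgreSQL error 23505, whose DETAIL reads like `Key (email)=(ada@example.com) already exists.`. `#toDatabaseError` parses that string but leaves the field-error construction as a TODO. Below is a hedged, framework-free sketch of one way to finish it, reusing the gist's two regexes and the message keys from its commented-out code; `parseUniqueViolation`, `snakeToCamel`, and the singular `model` argument (standing in for the commented-out `pluralize.singular(...)` call) are hypothetical.

```typescript
// Sketch: turn a unique_violation (SQLSTATE 23505) DETAIL string into field
// errors keyed as `<model>.<camelCaseColumn>`. All names are hypothetical.

type FieldErrors = Record<string, string[]>;

// Minimal snake_case -> camelCase conversion for column names.
function snakeToCamel(s: string): string {
  return s.replace(/_([a-z0-9])/g, (_, c: string) => c.toUpperCase());
}

function parseUniqueViolation(
  detail: string | undefined,
  model: string,
): FieldErrors | null {
  // Same patterns as #toDatabaseError: "Key (cols)=(vals) already exists."
  const keyMatch = detail?.match(/Key \(([^=]+)\)=/);
  const valueMatch = detail?.match(/=\(([^)]+)\)/);
  if (!keyMatch || !valueMatch) {
    return null;
  }

  const keys = keyMatch[1].split(", ").map((key) => key.trim());
  // A composite constraint gets a different message key than a single column.
  const messageKey =
    keys.length > 1
      ? "app:errors.dbUniqueCompositeConstraint"
      : "app:errors.dbUniqueConstraint";

  const fieldErrors: FieldErrors = {};
  for (const key of keys) {
    fieldErrors[`${model}.${snakeToCamel(key)}`] = [messageKey];
  }
  return fieldErrors;
}

// Usage: a composite (org_id, email) unique constraint on a users table.
console.log(
  parseUniqueViolation(
    "Key (org_id, email)=(1, ada@example.com) already exists.",
    "user",
  ),
);
```

Splitting on `", "` mirrors the gist; it is safe for column names, but the same split would misparse a value that itself contains a comma, which is why the sketch only uses the values match as a format check.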

Hi, thanks for this. I'm trying to replicate this code in my project, but I'm having trouble understanding how `AppDb` is typed.

I use `pgTable` for defining the tables/migrations, but I see Drizzle has a way of defining tables through schema.

Do you know if it's possible to satisfy this type even when using `pgTable`?

cayter commented Apr 21, 2024

Just realised I didn't include `AppDb<T>`. I've just updated the gist to include it.

I used `pgTable` as in the Quick start.

But this is not getting me the schema back. Do we need to work with … instead of `pgTable`, or is it also possible with `pgTable`? I don't see how I can get the schema type in TypeScript.

andresgutgon commented Apr 22, 2024

Oh OK, it's a custom type. Got it, thanks.
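
For reference, `AppDb<T>` is a custom interface in this gist rather than a Drizzle type: `T` is simply the type of the schema object handed to Drizzle, and tables declared with `pgTable` work. A common setup is `import * as schema from "./schema"` followed by `drizzle(client, { schema })`, which gives `AppDb<typeof schema>`. The `Repository` constructor adds one convention on top: it camelCases the table's SQL name (after removing a `_private_` prefix) to get the model name, then reads the `<modelName>Relations` export from the schema, so every table needs a matching relations export. The sketch below checks that convention against a schema object; all helper names are hypothetical, and `toCamelCase` is only a rough stand-in for moderndash's `camelCase` that may differ from it on edge cases.

```typescript
// Hypothetical helpers mirroring the Repository constructor's naming rules.

// Rough stand-in for moderndash's camelCase on snake_case table names.
function toCamelCase(name: string): string {
  return name.replace(/[_-]+([a-zA-Z0-9])/g, (_, c: string) => c.toUpperCase());
}

// The schema keys the constructor expects for a given SQL table name.
function expectedSchemaKeys(tableName: string) {
  const model = toCamelCase(tableName.replace("_private_", ""));
  return { model, relations: `${model}Relations` };
}

// List the expected keys that a schema object does not export.
function missingSchemaKeys(
  schema: Record<string, unknown>,
  tableNames: string[],
): string[] {
  const missing: string[] = [];
  for (const name of tableNames) {
    const { model, relations } = expectedSchemaKeys(name);
    if (!(model in schema)) missing.push(model);
    if (!(relations in schema)) missing.push(relations);
  }
  return missing;
}

// Usage: a plain object standing in for `import * as schema from "./schema"`.
const exampleSchema = { users: {}, usersRelations: {}, userProfiles: {} };
console.log(missingSchemaKeys(exampleSchema, ["users", "user_profiles"]));
// ["userProfilesRelations"]
```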
