@chimame
Created March 17, 2024 11:53
Automatic rollback in Vitest tests using Drizzle
// Context module (imported as "../context" by the test setup below).
import { Client } from "pg";
import { drizzle } from "drizzle-orm/node-postgres";
import * as schema from "../drizzle/schema";

// Opens one dedicated connection and wraps it in a Drizzle instance.
export async function createContext() {
  const client = new Client({
    connectionString: "your database connection string",
  });
  await client.connect();

  const db = drizzle(client, {
    schema,
    // Log queries only in development.
    logger: process.env.NODE_ENV === "development",
  });

  return {
    db,
  };
}
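
// Test setup: wraps the test run in a transaction that is rolled back at the end.
// The hooks are imported explicitly so this also works without `globals: true`.
import { afterAll, afterEach, beforeAll, beforeEach, vi } from "vitest";
import { createContext } from "../context";
import { sql } from "drizzle-orm";
import * as crypto from "crypto";
import { NodePgTransaction } from "drizzle-orm/node-postgres";

type Context = Awaited<ReturnType<typeof createContext>>;

// Expose the shared context on Vitest's per-test context object.
declare module "vitest" {
  export interface TestContext {
    ctx: Context;
  }
}

let ctx: Context;
// Stack of UUIDs identifying the savepoints that are currently open.
const savePoints: string[] = [];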

// Open one outer transaction before any test runs.
beforeAll(async () => {
  ctx = await createContext();
  await ctx.db.execute(sql`BEGIN`);
});
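
// Give each test its own savepoint inside the outer transaction.
beforeEach(async (context) => {
  context.ctx = ctx;

  const uuid = crypto.randomUUID();
  savePoints.push(uuid);
  // savepoint name must begin with a letter
  await ctx.db.execute(sql.raw(`SAVEPOINT A${uuid.replace(/-/g, "")}`));

  // Stub NodePgTransaction#execute to return an empty result instead of
  // reaching the database. Presumably this keeps db.transaction() calls in the
  // code under test from sending their own BEGIN/COMMIT/ROLLBACK, which would
  // end the outer transaction. Note that any raw tx.execute() call made inside
  // such a transaction also receives this empty result.
  const nodePgSessionSpy = vi.spyOn(NodePgTransaction.prototype, "execute");
  nodePgSessionSpy.mockImplementation(async (_query) => {
    return {
      rows: [],
      rowCount: 0,
      command: "SELECT",
      oid: 0,
      fields: [],
    };
  });
});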

// Undo everything the test wrote by rolling back to its savepoint.
afterEach(async (context) => {
  const uuid = savePoints.pop();
  await context.ctx.db.execute(
    sql.raw(`ROLLBACK TO SAVEPOINT A${uuid!.replace(/-/g, "")}`),
  );
});

// Discard the outer transaction so nothing is ever committed.
afterAll(async () => {
  await ctx.db.execute(sql`ROLLBACK`);
});
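
The savepoint name is built the same way in two places (`beforeEach` and `afterEach`). One way to keep the two in sync is a small helper. This is only a sketch: `savepointName` is a hypothetical name that does not appear in the gist, and it assumes the same naming scheme as above (an `A` prefix followed by the UUID with its hyphens removed).

```typescript
import { randomUUID } from "node:crypto";

// Hypothetical helper (not part of the original gist).
// Builds a savepoint name from a UUID: hyphens are stripped so the name is a
// valid unquoted SQL identifier, and "A" is prepended because an identifier
// must begin with a letter, while a UUID may begin with a digit.
function savepointName(uuid: string = randomUUID()): string {
  return `A${uuid.replace(/-/g, "")}`;
}
```

With a helper like this, the hooks could store the finished name on the stack and pass it straight to `sql.raw`, instead of repeating the `replace` call and the non-null assertion.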