|
// CUSTOM CHECKPOINTER (LangGraph Integration) |
|
// Extends BaseCheckpointSaver to support Multi-Tenancy. |
|
// Every checkpoint read (getTuple, list) and write (put, putWrites, deleteThread) is strictly scoped to a single tenant ID.
|
|
|
import { BaseCheckpointSaver } from '@langchain/langgraph-checkpoint' |
|
import type { Checkpoint, CheckpointMetadata, CheckpointTuple } from '@langchain/langgraph-checkpoint' |
|
import type { RunnableConfig } from '@langchain/core/runnables' |
|
... |
|
|
|
/**
 * Multi-tenant LangGraph checkpointer backed by the `checkpoints` and `checkpointWrites` tables.
 *
 * Each instance is bound to exactly one tenant. Every read filters on `tenantId` and every write
 * stamps it, so a saver created for tenant A neither sees nor overwrites tenant B's rows.
 * All database access goes through the injected `safeQuery` runner.
 *
 * NOTE(review): checkpoints are stored as-is rather than through `this.serde`, so message objects
 * presumably come back as plain JSON. `checkpoint_ns` is not persisted either, so checkpoints from
 * different namespaces (e.g. subgraphs) of one thread are not distinguished. Confirm both are intended.
 */
export class SupabaseSaver extends BaseCheckpointSaver {
  /** Matches one inline `<think>…</think>` block; the `s` flag lets `.` span newlines. */
  private static readonly thinkBlockPattern = /<think>(.*?)<\/think>/s

  private readonly tenantId: string

  private readonly safeQuery: (callback: (tx: any) => Promise<any>) => Promise<any>

  /**
   * @param tenantId - Tenant that every query of this saver is scoped to; must be a non-blank string.
   * @param safeQuery - Runs a callback with a query handle (`tx`) and resolves with its result.
   * @throws BadRequestError when `tenantId` is missing, not a string, or blank.
   */
  constructor(tenantId: string, safeQuery: (callback: (tx: any) => Promise<any>) => Promise<any>) {
    super()

    // Defensive validation: a blank tenant id would scope every query to a meaningless tenant.
    // The typeof check keeps untyped callers from crashing on `.trim()` instead of getting a clear error.
    if (typeof tenantId !== 'string' || tenantId.trim() === '') {
      throw new BadRequestError(`SupabaseSaver requires a valid tenantId, got: "${tenantId}"`)
    }

    this.tenantId = tenantId
    this.safeQuery = safeQuery
  }

  /**
   * Factory that creates a saver scoped to a specific request/tenant.
   *
   * @throws BadRequestError when `tenantId` is invalid (see the constructor).
   */
  static forTenant(tenantId: string, safeQuery: (callback: (tx: any) => Promise<any>) => Promise<any>) {
    return new SupabaseSaver(tenantId, safeQuery)
  }

  /**
   * Loads one checkpoint for the thread in `config`, scoped to this tenant.
   *
   * When `config` carries a `checkpoint_id`, that exact checkpoint is loaded. Otherwise the latest
   * one (highest `checkpoint_id`) is loaded. Soft-deleted checkpoints (see `deleteThread`) count as missing.
   *
   * @returns The checkpoint tuple, or `undefined` when `config` has no `thread_id` or nothing matches.
   *   The tuple's config carries the resolved `checkpoint_id`, and it includes the parent config and
   *   this tenant's pending writes for that checkpoint.
   */
  async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
    const threadId = config.configurable?.thread_id
    const checkpointId = config.configurable?.checkpoint_id

    if (!threadId) return undefined

    return this.safeQuery(async (tx) => {
      const conditions = [
        eq(checkpoints.threadId, threadId),
        eq(checkpoints.tenantId, this.tenantId), // <--- Security Enforcement
      ]

      if (checkpointId) {
        conditions.push(eq(checkpoints.checkpointId, checkpointId))
      }

      // (threadId, checkpointId) is the upsert target in `put`, so an explicit id matches at most
      // one row. Without an id, the ordering picks the latest checkpoint.
      const [row] = await tx
        .select()
        .from(checkpoints)
        .where(and(...conditions))
        .orderBy(desc(checkpoints.checkpointId))
        .limit(1)

      // Soft-deleted checkpoints (see deleteThread) must not be resumed.
      if (!row || row.deletedAt != null) return undefined

      // Pending writes saved by `putWrites` for this checkpoint; LangGraph needs them to resume
      // a partially completed step.
      const writeRows = await tx
        .select()
        .from(checkpointWrites)
        .where(
          and(
            eq(checkpointWrites.threadId, threadId),
            eq(checkpointWrites.checkpointId, row.checkpointId),
            eq(checkpointWrites.tenantId, this.tenantId), // <--- Security Enforcement
          ),
        )
        .orderBy(checkpointWrites.taskId, checkpointWrites.idx)

      const pendingWrites = writeRows.map((write: any) => [write.taskId, write.channel, write.value])

      return this._rowToTuple(row, config, pendingWrites)
    })
  }

  /**
   * Persists `checkpoint` for the thread in `config`, stamped with this saver's tenant.
   *
   * Reasoning text found on the last message (see `_extractReasoning`) is merged into the stored
   * metadata. The incoming `config.configurable.checkpoint_id`, if any, is recorded as the parent.
   *
   * @param newVersions - Channel versions supplied by LangGraph; not used by this saver.
   * @returns A config pointing at the newly stored checkpoint.
   * @throws BadRequestError when `config` has no `thread_id`.
   */
  async put(
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    newVersions: any, // Channel versions
  ): Promise<RunnableConfig> {
    const threadId = config.configurable?.thread_id

    if (!threadId) {
      throw new BadRequestError('Missing thread_id in config')
    }

    // Extraction never throws; on failure it contributes no extra metadata.
    const finalMetadata = {
      ...metadata,
      ...this._extractReasoning(checkpoint),
    }

    return this.safeQuery(async (tx) => {
      await tx
        .insert(checkpoints)
        .values({
          threadId: threadId,
          checkpointId: checkpoint.id,
          parentCheckpointId: config.configurable?.checkpoint_id || null,
          type: 'checkpoint',
          checkpoint: checkpoint,
          metadata: finalMetadata,
          tenantId: this.tenantId, // <--- Security Enforcement
        })
        .onConflictDoUpdate({
          target: [checkpoints.threadId, checkpoints.checkpointId],
          // Re-saving an existing checkpoint id replaces its payload (rare).
          set: { checkpoint, metadata: finalMetadata },
          // Only this tenant's row may be overwritten; a conflicting row owned by another
          // tenant is left untouched (the upsert becomes a no-op).
          where: eq(checkpoints.tenantId, this.tenantId), // <--- Security Enforcement
        })

      return {
        configurable: {
          thread_id: threadId,
          checkpoint_id: checkpoint.id,
        },
      }
    })
  }

  /**
   * Saves pending channel writes produced by `taskId` against the checkpoint in `config`.
   *
   * Each write is stored with its position in `writes` as `idx` and stamped with this tenant.
   * Nothing is stored when `config` lacks `thread_id` or `checkpoint_id`, or when `writes` is empty.
   */
  async putWrites(config: RunnableConfig, writes: [string, any][], taskId: string): Promise<void> {
    const threadId = config.configurable?.thread_id
    const checkpointId = config.configurable?.checkpoint_id

    // Writes are keyed by checkpoint; without both ids there is nothing to attach them to.
    // NOTE(review): this drops the writes silently — confirm callers never reach this path.
    if (!threadId || !checkpointId) return

    // Drizzle's `values()` throws on an empty array, so an empty batch is treated as a no-op.
    if (writes.length === 0) return

    return this.safeQuery(async (tx) => {
      await tx.insert(checkpointWrites).values(
        writes.map(([channel, value], idx) => ({
          threadId,
          checkpointId,
          taskId,
          channel,
          value,
          idx,
          tenantId: this.tenantId, // <--- Security Enforcement
        })),
      )
    })
  }

  /**
   * Yields this tenant's checkpoints for the thread in `config`, latest first.
   *
   * Soft-deleted checkpoints (see `deleteThread`) are skipped. Yields nothing when `config` has no `thread_id`.
   *
   * @param options.limit - Maximum number of tuples to yield (default 10).
   * @param options.before - When it carries a `checkpoint_id`, only checkpoints with a strictly
   *   smaller id are yielded.
   */
  async *list(
    config: RunnableConfig,
    options?: { limit?: number; before?: RunnableConfig },
  ): AsyncGenerator<CheckpointTuple, any, any> {
    const threadId = config.configurable?.thread_id
    if (!threadId) return

    const limit = options?.limit ?? 10
    const beforeId: string | undefined = options?.before?.configurable?.checkpoint_id

    const rows: any[] = await this.safeQuery(async (tx) => {
      const query = tx
        .select()
        .from(checkpoints)
        .where(
          and(
            eq(checkpoints.threadId, threadId),
            eq(checkpoints.tenantId, this.tenantId), // <--- Security Enforcement
          ),
        )
        .orderBy(desc(checkpoints.checkpointId))

      // `before` is applied in the loop below, so the SQL limit is only safe without it.
      // TODO(review): push `before` into SQL (e.g. drizzle `lt`) to avoid loading the whole thread.
      return beforeId ? query : query.limit(limit)
    })

    let yielded = 0
    for (const row of rows) {
      if (yielded >= limit) break
      if (row.deletedAt != null) continue
      // NOTE(review): assumes checkpoint ids compare the same in JS as in the SQL ordering
      // (expected for LangGraph's lowercase uuid ids) — confirm.
      if (beforeId !== undefined && !(row.checkpointId < beforeId)) continue

      yield this._rowToTuple(row)
      yielded += 1
    }
  }

  /**
   * Soft-deletes every checkpoint of `threadId` owned by this tenant by stamping `deletedAt`.
   * Rows of other tenants are untouched, and rows in `checkpointWrites` are not modified.
   */
  async deleteThread(threadId: string): Promise<void> {
    return this.safeQuery(async (tx) => {
      await tx
        .update(checkpoints)
        .set({ deletedAt: new Date() })
        .where(
          and(
            eq(checkpoints.threadId, threadId),
            eq(checkpoints.tenantId, this.tenantId), // <--- Security Enforcement
          ),
        )
    })
  }

  // ----------------------------------------------------------------------
  // HELPERS
  // ----------------------------------------------------------------------

  /**
   * Converts a row selected from `checkpoints` into a `CheckpointTuple`.
   *
   * @param row - A `checkpoints` row.
   * @param baseConfig - Config whose other fields are kept; its `thread_id`/`checkpoint_id` are
   *   replaced by the row's.
   * @param pendingWrites - `[taskId, channel, value]` triples to attach, when the caller loaded them.
   */
  private _rowToTuple(
    row: any,
    baseConfig: RunnableConfig = {},
    pendingWrites?: NonNullable<CheckpointTuple['pendingWrites']>,
  ): CheckpointTuple {
    const tuple: CheckpointTuple = {
      config: {
        ...baseConfig,
        configurable: {
          ...baseConfig.configurable,
          thread_id: row.threadId,
          checkpoint_id: row.checkpointId,
        },
      },
      checkpoint: row.checkpoint as Checkpoint,
      metadata: row.metadata as CheckpointMetadata,
    }

    if (row.parentCheckpointId) {
      tuple.parentConfig = {
        configurable: { thread_id: row.threadId, checkpoint_id: row.parentCheckpointId },
      }
    }

    if (pendingWrites) {
      tuple.pendingWrites = pendingWrites
    }

    return tuple
  }

  /**
   * Best-effort extraction of model reasoning from the last entry of the `messages` channel.
   *
   * @returns `{ reasoning_content }` when found, otherwise `{}`. Never throws; unexpected errors are
   *   logged so that saving a checkpoint is never blocked by this step.
   */
  private _extractReasoning(checkpoint: Checkpoint): Record<string, any> {
    try {
      const messages = checkpoint.channel_values?.['messages']

      if (!Array.isArray(messages) || messages.length === 0) {
        return {}
      }

      const lastMessage = messages[messages.length - 1]

      // Fields are looked up both directly and under `kwargs`, which presumably covers LangChain's
      // serialized message form as well as live message objects.
      const additionalKwargs = lastMessage?.kwargs?.additional_kwargs ?? lastMessage?.additional_kwargs ?? {}

      // DeepSeek V3/R1 (e.g. via LiteLLM) reasoning field.
      if (additionalKwargs.reasoning_content) {
        return { reasoning_content: additionalKwargs.reasoning_content }
      }

      // Fallback: reasoning interleaved in the text content as <think>…</think>.
      const content = lastMessage?.content ?? lastMessage?.kwargs?.content
      if (typeof content === 'string') {
        const thought = content.match(SupabaseSaver.thinkBlockPattern)?.[1]?.trim()
        if (thought) {
          return { reasoning_content: thought }
        }
      }

      return {}
    } catch (e: unknown) {
      // Fail softly so that saving never breaks because of extraction logic.
      createLogger().error(
        `Error extracting reasoning content from checkpoint: ${e instanceof Error ? e.message : 'Unknown error'}`,
      )
      return {}
    }
  }
}