Skip to content

Instantly share code, notes, and snippets.

@b0o
Last active May 27, 2023 00:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save b0o/2ecbf884e671e2ada6ab2248a87e3191 to your computer and use it in GitHub Desktop.
Save b0o/2ecbf884e671e2ada6ab2248a87e3191 to your computer and use it in GitHub Desktop.
// Ported from https://github.com/trpc/trpc/blob/main/packages/server/src/adapters/ws.ts
import type { ServeOptions, Server, ServerWebSocket } from 'bun'
import {
AnyRouter,
CombinedDataTransformer,
ProcedureType,
TRPCError,
callProcedure,
getTRPCErrorFromUnknown,
inferRouterContext,
} from '@trpc/server'
import { Unsubscribable, isObservable } from '@trpc/server/observable'
import { JSONRPC2, TRPCClientOutgoingMessage, TRPCResponseMessage } from '@trpc/server/rpc'
import { transformTRPCResponse } from '@trpc/server/shared'
import { FetchCreateContextOption } from '@trpc/server/adapters/fetch'
/**
 * Narrows an arbitrary caught value to something usable as an error `cause`.
 *
 * @param cause - any value caught from a `try`/`catch`
 * @returns the value itself when it is an `Error` instance, otherwise `undefined`
 */
function getCauseFromUnknown(cause: unknown) {
  return cause instanceof Error ? cause : undefined
}
/**
 * Asserts that `obj` is a non-null, non-array object whose properties can be
 * read by key.
 *
 * @throws Error('Not an object') for primitives, functions, `null` and arrays.
 */
function assertIsObject(obj: unknown): asserts obj is Record<string, unknown> {
  const isRecord = obj !== null && typeof obj === 'object' && !Array.isArray(obj)
  if (!isRecord) {
    throw new Error('Not an object')
  }
}
/**
 * Asserts that `obj` names one of the three tRPC procedure kinds:
 * `'query'`, `'mutation'` or `'subscription'`.
 *
 * @throws Error('Invalid procedure type') for any other value.
 */
function assertIsProcedureType(obj: unknown): asserts obj is ProcedureType {
  switch (obj) {
    case 'query':
    case 'mutation':
    case 'subscription':
      return
    default:
      throw new Error('Invalid procedure type')
  }
}
/**
 * Asserts that `obj` is a valid JSON-RPC request id: a string, a non-NaN
 * number, or `null`.
 *
 * The previous guard only rejected `NaN`, which `JSON.parse` never produces.
 * It therefore let objects, booleans, arrays and `undefined` through,
 * despite the assertion signature.
 *
 * @throws Error('Invalid request id') for every other value.
 */
function assertIsRequestId(obj: unknown): asserts obj is number | string | null {
  const isValid =
    obj === null || typeof obj === 'string' || (typeof obj === 'number' && !Number.isNaN(obj))
  if (!isValid) {
    throw new Error('Invalid request id')
  }
}
/**
 * Asserts that `obj` is a string primitive.
 *
 * @throws Error('Invalid string') otherwise, including for `String` wrapper objects.
 */
function assertIsString(obj: unknown): asserts obj is string {
  if (typeof obj === 'string') return
  throw new Error('Invalid string')
}
/**
 * Asserts that an optional `jsonrpc` envelope field is either absent or
 * exactly `'2.0'`.
 *
 * @throws Error('Must be JSONRPC 2.0') for any other value.
 */
function assertIsJSONRPC2OrUndefined(obj: unknown): asserts obj is '2.0' | undefined {
  const isAllowed = obj === undefined || obj === '2.0'
  if (!isAllowed) throw new Error('Must be JSONRPC 2.0')
}
/**
 * Validates one decoded client message and converts it into a
 * `TRPCClientOutgoingMessage`.
 *
 * A `subscription.stop` message carries only its envelope (`id`, `jsonrpc`,
 * `method`). Every other message must name a procedure type and carry a
 * `params` object with a string `path`. Its `params.input` is passed through
 * the transformer's input deserializer.
 *
 * @param obj - one decoded JSON value received from the socket
 * @param transformer - the router's combined data transformer
 * @throws Error when the envelope, method, params or path fail validation
 */
export function parseMessage(obj: unknown, transformer: CombinedDataTransformer): TRPCClientOutgoingMessage {
  assertIsObject(obj)
  const { method, params, id, jsonrpc } = obj
  assertIsRequestId(id)
  assertIsJSONRPC2OrUndefined(jsonrpc)

  if (method === 'subscription.stop') {
    return { id, jsonrpc, method }
  }

  assertIsProcedureType(method)
  assertIsObject(params)
  const { input: rawInput, path } = params
  assertIsString(path)

  return {
    id,
    jsonrpc,
    method,
    params: { input: transformer.input.deserialize(rawInput), path },
  }
}
/**
 * Per-connection data attached to each Bun `ServerWebSocket` when the HTTP
 * request is upgraded.
 */
interface BunWSSHandlerData {
  /** The HTTP request that was upgraded to this WebSocket connection. */
  req: Request
  /** Connection id assigned by the handler's `open`; unset before that. */
  id?: number
}
/** Payload shapes the websocket `message` handler can receive. */
type WSSMessage = string | ArrayBuffer | Uint8Array

/**
 * Converts a websocket payload to text. Strings pass through unchanged and
 * binary payloads are decoded as UTF-8.
 */
function wssMessageToString(data: WSSMessage): string {
  if (typeof data !== 'string') {
    // Two separate calls so each Buffer.from sees a narrowed argument type;
    // a union argument may not match any single overload.
    const bytes = data instanceof ArrayBuffer ? Buffer.from(data) : Buffer.from(data)
    return bytes.toString('utf-8')
  }
  return data
}

/**
 * Decodes a websocket payload to text and parses it as JSON.
 *
 * @throws SyntaxError when the payload is not valid JSON.
 */
function parseWssMessage(data: WSSMessage): unknown {
  const text = wssMessageToString(data)
  return JSON.parse(text)
}
/**
 * Options for `BunWSSHandler`: the router to serve, optional lifecycle hooks,
 * and the fetch adapter's `createContext` option.
 *
 * Every hook is invoked with `this` bound to the handler instance.
 */
export type BunWSSHandlerOptions<TData extends BunWSSHandlerData, TRouter extends AnyRouter> = {
  /** Router whose procedures are exposed over the socket. */
  router: TRouter
  /** Runs when a socket connects, before its tRPC context is created. Throwing or rejecting closes the socket. */
  onOpen?: (this: BunWSSHandler<TData, TRouter>, client: ServerWebSocket<TData>) => void | Promise<void>
  /** Runs for every raw incoming message before it is parsed. Throwing or rejecting closes the socket. */
  onMessage?: (
    this: BunWSSHandler<TData, TRouter>,
    client: ServerWebSocket<TData>,
    message: WSSMessage,
  ) => void | Promise<void>
  /** Forwarded from Bun's websocket `drain` event. */
  onDrain?: (this: BunWSSHandler<TData, TRouter>, client: ServerWebSocket<TData>) => void | Promise<void>
  /** Runs when a socket closes, with the close code and reason. */
  onClose?: (
    this: BunWSSHandler<TData, TRouter>,
    client: ServerWebSocket<TData>,
    code: number,
    reason: string,
  ) => void | Promise<void>
} & FetchCreateContextOption<TRouter>
/**
 * Per-socket state kept by `BunWSSHandler` for each connected client.
 *
 * @typeParam TData - data attached to the Bun socket
 * @typeParam TContext - the router's tRPC context type
 */
export interface BunWSSClient<TData, TContext> {
  /** The underlying Bun socket. */
  ws: ServerWebSocket<TData>
  /** Context produced by `createContext`; unset until it resolves. */
  ctx?: TContext
  /** Active subscriptions keyed by the request id that started them. */
  subscriptions: Map<number | string, Unsubscribable>
}
/**
 * WebSocket `readyState` values (0-3, the same numbering as the standard
 * WebSocket API). Used to check whether a socket is still open before
 * registering a subscription.
 *
 * Declared as an `as const` object instead of a numeric `enum`, so no
 * reverse-mapping object is emitted. The same-named type below keeps
 * `ReadyState` usable as a type.
 */
const ReadyState = {
  CONNECTING: 0,
  OPEN: 1,
  CLOSING: 2,
  CLOSED: 3,
} as const
type ReadyState = (typeof ReadyState)[keyof typeof ReadyState]
/**
 * Serves a tRPC router over Bun's native WebSocket server (ported from
 * tRPC's `ws` adapter).
 *
 * - `open` registers each socket as a client, runs `onOpen`, and creates the
 *   client's tRPC context.
 * - `message` decodes JSON-RPC messages and dispatches them to the router's
 *   procedures.
 * - `close` unsubscribes the client's subscriptions and forgets the client.
 */
export class BunWSSHandler<TData extends BunWSSHandlerData, TRouter extends AnyRouter> {
  private opts: BunWSSHandlerOptions<TData, TRouter>
  private router: TRouter
  private clients: Map<number, BunWSSClient<TData, inferRouterContext<TRouter>>>
  /**
   * Setup result per client (see `setupClient`). `open` awaits user code
   * (onOpen, createContext). A request that arrives meanwhile awaits this
   * promise, so it never runs a procedure before the client's context is set.
   */
  private readonly clientReady = new WeakMap<BunWSSClient<TData, inferRouterContext<TRouter>>, Promise<boolean>>()
  private nextId = 0

  constructor(opts: BunWSSHandlerOptions<TData, TRouter>) {
    this.opts = opts
    this.router = opts.router
    this.clients = new Map()
  }

  /** Number of sockets currently registered with this handler. */
  get connectionCount() {
    return this.clients.size
  }

  /**
   * Closes `ws` because a user hook threw.
   *
   * A numeric `code` on the error is used as the close code only when a
   * server may legally send it. Otherwise 1008 (policy violation) is used.
   * The reason is the error message, trimmed to the 123-byte limit for close
   * frames.
   */
  private handleError(ws: ServerWebSocket<TData>, error: unknown) {
    const reason = error instanceof Error ? error.message : 'Unknown error'
    let code = 1008
    if (
      error instanceof Error &&
      'code' in error &&
      typeof error.code === 'number' &&
      BunWSSHandler.isSendableCloseCode(error.code)
    ) {
      code = error.code
    }
    ws.close(code, BunWSSHandler.truncateUtf8(reason, 123))
  }

  /**
   * Whether a server may send `code` in a close frame (RFC 6455 §7.4).
   * 1004 is reserved, and 1005, 1006 and 1015 must never be sent.
   */
  private static isSendableCloseCode(code: number): boolean {
    return (
      Number.isInteger(code) &&
      ((code >= 1000 && code <= 1003) || (code >= 1007 && code <= 1014) || (code >= 3000 && code <= 4999))
    )
  }

  /**
   * Returns the longest prefix of `text` whose UTF-8 encoding fits in
   * `maxBytes` bytes, without splitting a code point.
   */
  private static truncateUtf8(text: string, maxBytes: number): string {
    if (Buffer.byteLength(text, 'utf-8') <= maxBytes) {
      return text
    }
    let used = 0
    let out = ''
    for (const char of text) {
      const size = Buffer.byteLength(char, 'utf-8')
      if (used + size > maxBytes) break
      used += size
      out += char
    }
    return out
  }

  /** Serializes a response through the router's transformer and sends it on `ws`. */
  private respond(ws: ServerWebSocket<TData>, untransformedJSON: TRPCResponseMessage) {
    ws.send(JSON.stringify(transformTRPCResponse(this.router, untransformedJSON)))
  }

  /**
   * Dispatches one parsed client message for socket `ws`.
   *
   * - `subscription.stop` unsubscribes the matching subscription, if any.
   * - Queries and mutations receive a single `data` response.
   * - Subscriptions receive `started`, then `data` messages, and a final
   *   `stopped` (or an error).
   *
   * Procedure failures are sent back as the router's error shape for the
   * request id.
   *
   * @throws TRPCError when the socket has no registered client or the message
   *   id is `null`.
   */
  async handleRequest(ws: ServerWebSocket<TData>, msg: TRPCClientOutgoingMessage) {
    if (ws.data.id === undefined) {
      throw new TRPCError({ code: 'BAD_REQUEST', message: 'Missing client id' })
    }
    const client = this.clients.get(ws.data.id)
    if (!client) {
      throw new TRPCError({ code: 'BAD_REQUEST', message: 'Unknown client id' })
    }
    const { id, jsonrpc } = msg
    /* istanbul ignore next -- @preserve */
    if (id === null) {
      throw new TRPCError({
        code: 'BAD_REQUEST',
        message: 'missing message id',
      })
    }
    const stopSubscription = (
      subscription: Unsubscribable,
      { id, jsonrpc }: { id: JSONRPC2.RequestId } & JSONRPC2.BaseEnvelope,
    ) => {
      subscription.unsubscribe()
      this.respond(ws, {
        id,
        jsonrpc,
        result: {
          type: 'stopped',
        },
      })
    }
    if (msg.method === 'subscription.stop') {
      const sub = client.subscriptions.get(id)
      if (sub) {
        stopSubscription(sub, { id, jsonrpc })
      }
      client.subscriptions.delete(id)
      return
    }
    const { path, input } = msg.params
    const type = msg.method
    try {
      // Wait for onOpen/createContext so the procedure sees this client's context.
      const ready = await this.clientReady.get(client)
      if (ready === false) {
        throw new TRPCError({
          code: 'INTERNAL_SERVER_ERROR',
          message: 'Connection setup failed',
        })
      }
      const result = await callProcedure({
        procedures: this.opts.router._def.procedures,
        path,
        rawInput: input,
        ctx: client.ctx,
        type,
      })
      if (type !== 'subscription') {
        // queries and mutations answer exactly once with the procedure's value
        this.respond(ws, {
          id,
          jsonrpc,
          result: {
            type: 'data',
            data: result,
          },
        })
        return
      }
      if (!isObservable(result)) {
        throw new TRPCError({
          message: `Subscription ${path} did not return an observable`,
          code: 'INTERNAL_SERVER_ERROR',
        })
      }
      // The observer may fire synchronously inside subscribe(), before
      // `registered` is assigned; `finished` records that case. The identity
      // check keeps a finished subscription from deleting another request's
      // entry under the same id.
      let registered: Unsubscribable | undefined
      let finished = false
      const forget = () => {
        finished = true
        if (registered !== undefined && client.subscriptions.get(id) === registered) {
          client.subscriptions.delete(id)
        }
      }
      const sub: Unsubscribable = result.subscribe({
        next: data => {
          this.respond(ws, {
            id,
            jsonrpc,
            result: {
              type: 'data',
              data,
            },
          })
        },
        error: err => {
          forget()
          const error = getTRPCErrorFromUnknown(err)
          // TODO: forward to an onError hook once one is supported
          this.respond(ws, {
            id,
            jsonrpc,
            error: this.opts.router.getErrorShape({
              error,
              type,
              path,
              input,
              ctx: client.ctx,
            }),
          })
        },
        complete: () => {
          forget()
          this.respond(ws, {
            id,
            jsonrpc,
            result: {
              type: 'stopped',
            },
          })
        },
      })
      registered = sub
      if (ws.readyState !== ReadyState.OPEN) {
        // the client disconnected while the subscription was being set up;
        // no need to send a `stopped` message to a closed socket
        sub.unsubscribe()
        return
      }
      if (client.subscriptions.has(id)) {
        // duplicate request id for this client
        stopSubscription(sub, { id, jsonrpc })
        throw new TRPCError({
          message: `Duplicate id ${id}`,
          code: 'BAD_REQUEST',
        })
      }
      if (!finished) {
        client.subscriptions.set(id, sub)
      }
      this.respond(ws, {
        id,
        jsonrpc,
        result: {
          type: 'started',
        },
      })
    } catch (cause) /* istanbul ignore next -- @preserve */ {
      // procedure threw an error
      const error = getTRPCErrorFromUnknown(cause)
      // TODO: forward to an onError hook once one is supported
      this.respond(ws, {
        id,
        jsonrpc,
        error: this.opts.router.getErrorShape({
          error,
          type,
          path,
          input,
          ctx: client.ctx,
        }),
      })
    }
  }

  /**
   * Registers a newly connected socket and prepares it (`onOpen`, then
   * `createContext`). Requests that arrive before preparation finishes wait
   * for it in `handleRequest`.
   */
  async open(ws: ServerWebSocket<TData>) {
    const clientId = this.nextId++
    ws.data.id = clientId
    const client: BunWSSClient<TData, inferRouterContext<TRouter>> = {
      ws,
      subscriptions: new Map(),
    }
    this.clients.set(clientId, client)
    const ready = this.setupClient(ws, client)
    this.clientReady.set(client, ready)
    await ready
  }

  /**
   * Runs `onOpen` and then `createContext` for a new client.
   *
   * @returns `true` when the client's context is set. Returns `false` when a
   *   step failed and the socket is being closed: an `onOpen` failure goes
   *   through `handleError`; a context failure sends an error response and
   *   closes on the next tick.
   */
  private async setupClient(
    ws: ServerWebSocket<TData>,
    client: BunWSSClient<TData, inferRouterContext<TRouter>>,
  ): Promise<boolean> {
    if (this.opts.onOpen) {
      try {
        await this.opts.onOpen.call(this, ws)
      } catch (error) {
        this.handleError(ws, error)
        return false
      }
    }
    try {
      // TODO: headers written to resHeaders by createContext are currently discarded
      const resHeaders = new Headers()
      // Called inside the try so a synchronous throw is handled like a rejection.
      client.ctx = await this.opts.createContext?.({
        req: ws.data.req,
        resHeaders,
      })
      return true
    } catch (cause) {
      const error = getTRPCErrorFromUnknown(cause)
      // TODO: forward to an onError hook once one is supported
      this.respond(ws, {
        id: null,
        error: this.opts.router.getErrorShape({
          error,
          type: 'unknown',
          path: undefined,
          input: undefined,
          ctx: client.ctx,
        }),
      })
      // close on the next tick, after the error response has been queued
      global.setImmediate(() => {
        ws.close()
      })
      return false
    }
  }

  /**
   * Handles one raw websocket message:
   *
   * 1. Runs `onMessage`. If it throws, the socket is closed.
   * 2. Decodes the payload as a single JSON-RPC message or a batch (array).
   * 3. Dispatches each message concurrently.
   *
   * Any failure escaping steps 2-3 is reported as one PARSE_ERROR response
   * with `id: null`.
   */
  async message(ws: ServerWebSocket<TData>, message: WSSMessage) {
    if (this.opts.onMessage) {
      try {
        await this.opts.onMessage.call(this, ws, message)
      } catch (error) {
        this.handleError(ws, error)
        return
      }
    }
    try {
      const decoded: unknown = parseWssMessage(message)
      const rawMessages: unknown[] = Array.isArray(decoded) ? decoded : [decoded]
      const transformer = this.opts.router._def._config.transformer
      // validate the whole batch before dispatching any of it
      const parsed = rawMessages.map(raw => parseMessage(raw, transformer))
      await Promise.all(parsed.map(msg => this.handleRequest(ws, msg)))
    } catch (cause) {
      const error = new TRPCError({
        code: 'PARSE_ERROR',
        cause: getCauseFromUnknown(cause),
      })
      this.respond(ws, {
        id: null,
        error: this.opts.router.getErrorShape({
          error,
          type: 'unknown',
          path: undefined,
          input: undefined,
          ctx: undefined,
        }),
      })
    }
  }

  /**
   * Handles a closed socket: runs `onClose`, then unsubscribes the client's
   * subscriptions and forgets the client. Cleanup runs even if `onClose`
   * throws or rejects. Such errors are logged, since the socket is already
   * closed and cannot be closed with an error code.
   */
  close(ws: ServerWebSocket<TData>, code: number, reason: string) {
    try {
      const pending = this.opts.onClose?.call(this, ws, code, reason)
      void Promise.resolve(pending).catch(error => {
        console.error('BunWSSHandler: onClose hook failed', error)
      })
    } catch (error) {
      console.error('BunWSSHandler: onClose hook failed', error)
    } finally {
      this.removeClient(ws)
    }
  }

  /** Unsubscribes every subscription of the client on `ws` and unregisters the client. */
  private removeClient(ws: ServerWebSocket<TData>) {
    const clientId = ws.data.id
    if (clientId === undefined) {
      return
    }
    const client = this.clients.get(clientId)
    if (client) {
      for (const sub of client.subscriptions.values()) {
        try {
          sub.unsubscribe()
        } catch (error) {
          console.error('BunWSSHandler: failed to unsubscribe on close', error)
        }
      }
      client.subscriptions.clear()
    }
    this.clients.delete(clientId)
  }

  /** Forwards Bun's `drain` event to `onDrain`. A hook that throws or rejects closes the socket. */
  drain(ws: ServerWebSocket<TData>) {
    if (!this.opts.onDrain) {
      return
    }
    try {
      const pending = this.opts.onDrain.call(this, ws)
      void Promise.resolve(pending).catch(error => this.handleError(ws, error))
    } catch (error) {
      this.handleError(ws, error)
    }
  }

  /**
   * Starts a Bun server that upgrades every request to a WebSocket handled by
   * this instance. Requests that cannot be upgraded receive
   * `426 Upgrade Required`.
   */
  serve(opts: Omit<ServeOptions, 'fetch' | 'websocket'>) {
    return Bun.serve({
      ...opts,
      fetch: (req: Request, server: Server) => {
        const data: BunWSSHandlerData = {
          req,
        }
        if (server.upgrade(req, { data })) {
          // Bun sends the upgrade response itself; no Response may be returned.
          return undefined
        }
        return new Response('Upgrade Required', {
          status: 426,
          headers: { Upgrade: 'websocket' },
        })
      },
      websocket: {
        open: (ws: ServerWebSocket<TData>) => this.open(ws),
        message: (ws: ServerWebSocket<TData>, message: WSSMessage) => this.message(ws, message),
        close: (ws: ServerWebSocket<TData>, code: number, reason: string) => this.close(ws, code, reason),
        drain: (ws: ServerWebSocket<TData>) => this.drain(ws),
      },
    })
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment