Skip to content

Instantly share code, notes, and snippets.

@gunhaxxor
Last active May 28, 2023 08:56
Show Gist options
  • Save gunhaxxor/1e8d2593697e174e418d128e2319e4e4 to your computer and use it in GitHub Desktop.
Save gunhaxxor/1e8d2593697e174e418d128e2319e4e4 to your computer and use it in GitHub Desktop.
tRPC adapter for uWebSockets.js. The ws-adapter is a modified version of the original WebSocket adapter in the tRPC repository. The gist also includes some utility functions that make it easier to set up subscriptions that don't send data to the triggering client.
import {
createTRPCProxyClient,
createWSClient,
wsLink,
} from '@trpc/client';
import { Unsubscribable } from '@trpc/server/observable';
import AbortController from 'abort-controller';
import fetch from 'node-fetch';
import ws from 'ws';
import type { AppRouter } from './server';
// The tRPC client looks these up as globals, so install the polyfills on
// Node's global object before any client is created.
const globalAny = global as any;
Object.assign(globalAny, {
  AbortController,
  fetch,
  WebSocket: ws,
});

// Random suffix so concurrently running demo clients usually get distinct ids.
const randomInt = Math.floor(Math.random() * 1000);
// The server uses the raw query string (`user-<n>`) as this client's token.
const wsUrl = `ws://localhost:2022?user-${randomInt}`;
const wsClient = createWSClient({ url: wsUrl });
const trpc = createTRPCProxyClient<AppRouter>({
  links: [wsLink({ client: wsClient })],
});
/**
 * Demo flow exercising the `excludeSelf` subscription filter.
 *
 * 1. Subscribe with `excludeSelf: true`, then create/join 'coolRoom' and update
 *    the position. This subscription is expected NOT to log updates that this
 *    client triggered.
 * 2. Subscribe with `excludeSelf: false`, then create/join 'boringRoom'. This
 *    subscription is expected to log the resulting room state.
 *
 * Every subscription is stopped and the WebSocket client is closed in a
 * `finally`, so a failing step does not leave the socket open.
 */
async function main() {
  // Every subscription created below, so all of them can be torn down at the end.
  const subscriptions: Unsubscribable[] = [];

  // Subscribes to room updates and resolves once the server reports the
  // subscription as started. It rejects if the subscription errors first;
  // without that, a failed subscription would leave this promise pending forever.
  const subscribeToRoomUpdates = (excludeSelf: boolean) =>
    new Promise<Unsubscribable>((resolve, reject) => {
      const sub = trpc.room.onRoomUpdate.subscribe({ excludeSelf }, {
        onData: (data) => console.log('received subscribed roomState:', data),
        onStarted() {
          resolve(sub);
        },
        onError(err) {
          reject(err);
        },
      });
      subscriptions.push(sub);
    });

  try {
    const myToken = await trpc.room.getMyToken.query();
    console.log('MY TOKEN IS:', myToken);

    // Should stay silent for updates this client triggers itself.
    await subscribeToRoomUpdates(true);
    const createdRoom = await trpc.room.createAndJoinRoom.mutate('coolRoom');
    console.log('created room: ', createdRoom);
    await trpc.room.updateMyPosition.mutate({
      x: 1,
      y: 2,
      z: 3
    });

    // Should fire for the room state this client triggers next.
    await subscribeToRoomUpdates(false);
    const createdRoom2 = await trpc.room.createAndJoinRoom.mutate('boringRoom');
    console.log('created room 2:', createdRoom2);
  } finally {
    for (const sub of subscriptions) {
      sub.unsubscribe();
    }
    wsClient.close();
  }
}
// Run the demo. Report failures explicitly instead of leaving an unhandled rejection.
main().catch((err: unknown) => {
  console.error('client demo failed:', err);
  process.exitCode = 1;
});
import { initTRPC, TRPCError } from '@trpc/server';
import { applyWSHandler } from './ws-adapter';
import { z } from 'zod';
import uWebSockets from 'uWebSockets.js';
import { TypedEmitter } from "tiny-typed-emitter";
import { attachFilteredEmitter, FilteredEvents } from "./trpc-utils";
// Per-socket user data attached at upgrade time (`res.upgrade<UData>` below).
// A shallow copy of it is handed to the WS adapter as the tRPC context, so
// procedures identify the caller through `ctx.token`.
type UData = {
// the raw query string of the upgrade request (`req.getQuery()`)
token: string
}
// Events carried by each client's personal emitter. `FilteredEvents` appends a
// trailing filter argument, typed like the token, to every listener.
// Subscriptions compare it with their own filter value to decide whether to forward.
// NOTE(review): 'kickedFromRoom' is declared but never emitted in the code shown here.
type ClientEvents = FilteredEvents<{
'roomState': (room: RoomStateMessage) => void;
'kickedFromRoom': (roomId: string) => void;
}, UData['token']>;
// tRPC builder whose context type is UData.
const t = initTRPC.context<UData>().create();
const publicProcedure = t.procedure;
const router = t.router;
// Schema of a connected client as kept in server memory. It includes the
// client's event emitter, which is server-only and must not be sent to clients.
const clientInfo = z.object({
// the client's token (see the `open` handler, which sets id = token)
id: z.string(),
role: z.union([z.literal('admin'), z.literal('user'), z.literal('guest')]),
// [x, y, z]; unset until the client calls updateMyPosition
position: z.optional(z.tuple([z.number(), z.number(), z.number()])),
// id of the room the client most recently joined, if any
currentRoom: z.optional(z.string()),
// validated only by an `instanceof TypedEmitter` check
clientEmitter: z.custom<TypedEmitter<ClientEvents>>(d => d instanceof TypedEmitter)
})
type ClientInfo = z.infer<typeof clientInfo>
// Wire-safe subset of ClientInfo: every field except `clientEmitter`.
const clientInfoMessage = clientInfo.pick({
id: true,
role: true,
position: true,
currentRoom: true
})
type ClientInfoMessage = z.infer<typeof clientInfoMessage>
// Converts an in-memory client into its wire-safe form. Zod object schemas
// strip unknown keys on `.parse` by default, which drops `clientEmitter`.
// `.parse` throws if the data does not match the schema.
// NOTE(review): not called anywhere in the code shown here.
function getClientInfoMessage(clientInfo: ClientInfo): ClientInfoMessage {
return clientInfoMessage.parse(clientInfo);
}
// Schema of a room in server memory. `clients` maps client id (token) to ClientInfo.
const roomState = z.object({
roomId: z.string(),
clients: z.object({}).catchall(clientInfo)
})
type RoomState = z.infer<typeof roomState>
// Wire-safe room: same shape, but every member is validated as ClientInfoMessage.
const roomStateMessage = roomState.extend({
clients: z.object({}).catchall(clientInfoMessage)
});
type RoomStateMessage = z.infer<typeof roomStateMessage>
// Converts an in-memory room into the payload sent to clients, stripping each
// member's emitter. `.parse` throws if the room does not match the schema.
function getRoomStateMessage(roomState: RoomState): RoomStateMessage {
return roomStateMessage.parse(roomState);
}
// In-memory registries held by this process.
// Connected clients keyed by token; rooms keyed by room id.
const connectedClients = new Map<string, ClientInfo>();
const rooms = new Map<string, RoomState>();
/**
 * Adds a connected client to an existing room and records that room as the
 * client's current room.
 *
 * If the client was previously in a different room, it is first removed from
 * that room's member list. Otherwise the old room would keep the client as a
 * stale member and keep emitting to it.
 *
 * @param userId - token of the connected client (key in `connectedClients`)
 * @param roomId - id of the room to join (key in `rooms`)
 * @returns the joined room
 * @throws TRPCError NOT_FOUND if the room or the client does not exist
 */
function addUserToRoom(userId: string, roomId: string){
  const room = rooms.get(roomId);
  if (!room)
    throw new TRPCError({ code: 'NOT_FOUND', message: 'no room with that id found' });
  const client = connectedClients.get(userId);
  if (!client)
    throw new TRPCError({ code: 'NOT_FOUND', message: 'no client with that id found' });

  const previousRoomId = client.currentRoom;
  if (previousRoomId !== undefined && previousRoomId !== room.roomId) {
    const previousRoom = rooms.get(previousRoomId);
    if (previousRoom)
      delete previousRoom.clients[userId];
  }

  room.clients[userId] = client;
  client.currentRoom = room.roomId;
  return room;
}
/**
 * Emits the wire-safe state of `room` as a `roomState` event on every member's
 * emitter. Members without an emitter are skipped.
 *
 * The trailing emit argument identifies who caused the update.
 * `attachFilteredEmitter` compares it with a subscription's filter, so an
 * `excludeSelf` subscription drops exactly the updates its own client triggered.
 *
 * @param room - room whose state is broadcast
 * @param triggeringClient - token of the client whose action caused the update
 */
function broadcastRoomState(room: RoomState, triggeringClient: string){
  const roomMessage = getRoomStateMessage(room);
  for (const client of Object.values(room.clients)) {
    if (!client.clientEmitter)
      continue;
    // Tag with the *triggering* client, not the recipient. Tagging with the
    // recipient's own id made excludeSelf subscriptions drop every update.
    client.clientEmitter.emit('roomState', roomMessage, triggeringClient);
  }
}
/**
 * Looks up the connected client whose token is `userId`.
 *
 * @param userId - the caller's token (procedures pass `ctx.token`)
 * @returns the caller's entry in `connectedClients`
 * @throws TRPCError NOT_FOUND if no connected client has that token
 */
function getMe(userId: string){
  const me = connectedClients.get(userId);
  if (!me)
    throw new TRPCError({ code: 'NOT_FOUND', message: 'did not find self among backend clients' });
  return me;
}
/**
 * Room-related procedures. Each procedure identifies the caller by `ctx.token`,
 * the token stored on the socket when it was upgraded.
 */
const roomRouter = router({
  /** Returns the caller's own token. */
  getMyToken: publicProcedure
    .query(({ ctx }) => {
      return ctx.token
    }),
  /**
   * Stores the caller's position as an `[x, y, z]` tuple.
   * NOTE(review): nothing is broadcast here; room members only see the new
   * position with the next room-state broadcast.
   */
  updateMyPosition: publicProcedure
    .input(z.object({
      x: z.number(),
      y: z.number(),
      z: z.number(),
    }))
    .mutation(({input, ctx}) => {
      const me = getMe(ctx.token);
      me.position = [input.x, input.y, input.z];
    }),
  /** Returns the wire-safe state of the caller's current room; NOT_FOUND if there is none. */
  getMyRoom: publicProcedure
    .query(({ ctx }) => {
      const me = getMe(ctx.token);
      const err = new TRPCError({code: 'NOT_FOUND', message: 'you are not in a room'});
      if(!me.currentRoom)
        throw err;
      const room = rooms.get(me.currentRoom);
      if(!room)
        throw err;
      return getRoomStateMessage(room);
    }),
  /**
   * Creates room `input` unless it already exists, adds the caller to it, and
   * broadcasts the updated room state to its members.
   */
  createAndJoinRoom: publicProcedure
    .input(z.string())
    .mutation(({input, ctx})=> {
      // Never replace an existing room: a fresh entry would silently drop
      // every member already in it.
      if (!rooms.has(input)) {
        rooms.set(input, {
          roomId: input,
          clients: {}
        });
      }
      const room = addUserToRoom(ctx.token, input);
      broadcastRoomState(room, ctx.token)
      return getRoomStateMessage(room);
    }),
  /** Adds the caller to an existing room and broadcasts the updated room state. */
  joinRoom: publicProcedure
    .input(z.string())
    .mutation(({input: roomName, ctx})=> {
      const room = addUserToRoom(ctx.token, roomName);
      broadcastRoomState(room, ctx.token);
      return getRoomStateMessage(room);
    }),
  /**
   * Streams `roomState` events from the caller's own emitter. With
   * `excludeSelf: true`, events tagged with the caller's id are dropped.
   */
  onRoomUpdate: publicProcedure.input(z.object({excludeSelf: z.boolean()})).subscription(({input: {excludeSelf}, ctx}) => {
    // console.log('subscription request received:', ctx);
    const me = getMe(ctx.token);
    const filter = excludeSelf? me.id : undefined;
    return attachFilteredEmitter(me.clientEmitter, 'roomState', filter);
  }),
});
// Top-level router: room procedures are reachable as `room.*` (e.g. trpc.room.getMyToken).
const appRouter = router({ room: roomRouter });
export type AppRouter = typeof appRouter;

// ws server: adapter callbacks that the uWebSockets handlers below forward socket events to.
const {
  onSocketOpen,
  onSocketMessage,
  onSocketClose,
} = applyWSHandler<AppRouter, UData>({ router: appRouter });
/**
 * uWebSockets.js app serving the tRPC router over WebSocket on every path.
 * While a socket is open, its client has an entry (with a personal emitter)
 * in `connectedClients`. On close, the client is also removed from its room.
 */
const app = uWebSockets.App().ws<UData>('/*', {
  upgrade: (res, req, ctx) => {
    // console.log('upgrade request received:', req);
    // console.log('ws ctx:', ctx);
    // The raw query string becomes the client's token.
    // NOTE(review): no authentication; a duplicate token overwrites the
    // earlier entry in connectedClients.
    const token = req.getQuery();
    res.upgrade<UData>(
      {
        token
      },
      /* Spell these correctly */
      req.getHeader('sec-websocket-key'),
      req.getHeader('sec-websocket-protocol'),
      req.getHeader('sec-websocket-extensions'),
      ctx
    );
  },
  open: (ws) => {
    // A shallow copy of the socket's user data becomes the tRPC context.
    const uData = Object.assign({}, ws.getUserData());
    onSocketOpen(ws, uData);
    const e: ClientInfo['clientEmitter'] = new TypedEmitter();
    const newClient: ClientInfo = {
      id: uData.token,
      role: 'user',
      clientEmitter: e,
    }
    // newClient.clientEmitter.on('roomState', (roomState, filter) => console.log(newClient.id, ': my emitter was triggered:', filter));
    connectedClients.set(uData.token, newClient);
  },
  message: (ws, msg) => {
    // console.log('message received: ', msg);
    const msgStr = Buffer.from(msg).toString();
    // console.log('stringified msg:', msgStr);
    onSocketMessage(ws, msgStr);
  },
  close: (ws, code, msg) => {
    const msgStr = Buffer.from(msg).toString();
    // console.log('ws was closed:', code, msgStr);
    onSocketClose(ws, msgStr);
    const { token } = ws.getUserData();
    const client = connectedClients.get(token);
    // Remove the departing client from its room. Otherwise it stays listed as
    // a member forever, and remaining members keep seeing a stale entry.
    if (client?.currentRoom !== undefined) {
      const room = rooms.get(client.currentRoom);
      if (room) {
        delete room.clients[token];
        broadcastRoomState(room, token);
      }
    }
    connectedClients.delete(token);
  }
})
// Start listening. uWS passes the listen socket on success and `false` on
// failure (e.g. port already in use), so only report success when we got one.
app.listen(2022, (ls) => {
  if (ls) {
    console.log('listening on port 2022', ls);
  } else {
    console.error('failed to listen on port 2022');
    process.exitCode = 1;
  }
})
import {observable} from '@trpc/server/observable';
import { ListenerSignature, TypedEmitter } from 'tiny-typed-emitter';
//Internal utility types
// Listener type registered for event K of event map E.
type EmitterCallback<E extends ListenerSignature<E>, K extends keyof E> = E[K];
// Type of the first argument of event K's listener, i.e. the event payload.
type EventArgument<E extends ListenerSignature<E>, K extends keyof E> = Parameters<EmitterCallback<E,K>>[0]
// Same function type as FuncType, with one extra trailing `filter` parameter.
type AddFilterParam<FuncType extends (...args: any) => any, FilterType> = (...args: [...parameters: Parameters<FuncType>, filter: FilterType]) => ReturnType<FuncType>;
// Applies AddFilterParam to every event of an event map.
type AddFilterToEvents<IEvents extends ListenerSignature<IEvents>, FilterType> = {
[K in keyof IEvents]: AddFilterParam<IEvents[K], FilterType>
}
/**
 * Event map in which every listener receives a trailing filter argument of type
 * FilterType after its payload. Each event in E must take a single payload
 * argument (`(p: any) => void`), which attachFilteredEmitter forwards.
 */
export type FilteredEvents<E extends {[K in keyof E]: (p: any) => void}, FilterType> = AddFilterToEvents<E, FilterType>
// export function attachEmitter<E extends ListenerSignature<E>, K extends keyof E>(emitter: TypedEmitter<E>, event: K){
// return observable<EventArgument<E, typeof event>>(emit => {
// const onEvent = (data: EventArgument<E,typeof event>): void => {
// console.log('emitter triggered');
// emit.next(data);
// }
// emitter.on(event, onEvent as E[typeof event]);
// return () => {
// emitter.off(event, onEvent as E[typeof event]);
// }
// })
// }
/**
 * Exposes one event of a typed emitter as a tRPC observable, filtering by the
 * event's trailing argument.
 *
 * An emission is dropped when its trailing argument is strictly equal
 * (`===`) to `filter`; otherwise its payload (first argument) is forwarded.
 * The listener is removed from the emitter when the observable is unsubscribed.
 *
 * @param emitter - emitter whose event listeners take (payload, filterValue)
 * @param event - name of the event to observe
 * @param filter - trailing-argument value whose emissions are skipped
 */
export function attachFilteredEmitter<E extends ListenerSignature<E>, K extends keyof E, FilterType>(emitter: TypedEmitter<E>, event: K, filter: FilterType){
  return observable<EventArgument<E, typeof event>>((observer) => {
    const listener = (payload: EventArgument<E, typeof event>, triggerId: FilterType): void => {
      // Only forward emissions that were not tagged with the filtered value.
      if (triggerId !== filter) {
        observer.next(payload);
      }
    };
    // Cast once and reuse the same reference so `off` removes exactly what `on` added.
    const typedListener = listener as E[typeof event];
    emitter.on(event, typedListener);
    return () => {
      emitter.off(event, typedListener);
    };
  });
}
import {
AnyRouter,
ProcedureType,
callProcedure,
TRPCError,
} from '@trpc/server';
import { OnErrorFunction } from '@trpc/server/dist/internals/types';
import { Unsubscribable, isObservable } from '@trpc/server/observable';
import {
JSONRPC2,
TRPCClientOutgoingMessage,
TRPCReconnectNotification,
TRPCResponse,
TRPCResponseMessage,
} from '@trpc/server/rpc';
import { CombinedDataTransformer } from '@trpc/server/dist/transformer';
/**
 * Validates one decoded JSON-RPC message from a client and converts it into a
 * TRPCClientOutgoingMessage.
 *
 * `subscription.stop` messages carry only id/jsonrpc/method. Any other method
 * must be a procedure type, with `params.path` a string. `params.input` is
 * deserialized with the router's input transformer.
 *
 * @throws Error if any field has the wrong shape
 */
function parseMessage(
  obj: unknown,
  transformer: CombinedDataTransformer,
): TRPCClientOutgoingMessage {
  assertIsObject(obj);
  const { method, params, id, jsonrpc } = obj;
  assertIsRequestId(id);
  assertIsJSONRPC2OrUndefined(jsonrpc);

  if (method !== 'subscription.stop') {
    assertIsProcedureType(method);
    assertIsObject(params);
    const { input: rawInput, path } = params;
    assertIsString(path);
    return {
      id,
      jsonrpc,
      method,
      params: {
        input: transformer.input.deserialize(rawInput),
        path,
      },
    };
  }

  return { id, jsonrpc, method };
}
// Writes one text frame to a socket.
type BasicSendFunction = (message: string) => void;
// The only socket surface the adapter needs. The uWS sockets passed by the
// server satisfy it. Sockets are tracked by object identity, so the same
// object must be passed to every adapter callback for a given connection.
interface MinimalWSInterface {
send: BasicSendFunction
}
// tRPC's onError options without `req`; this adapter never supplies a request object.
type OnErrorWithoutRequest = (opts: Omit<Parameters<OnErrorFunction<AnyRouter, undefined>>[0], 'req'>) => void
// Options accepted by applyWSHandler.
export interface WSHandlerOptions<TRouter extends AnyRouter> {
// optional hook called with each procedure/subscription error before the error response is sent
onError?: OnErrorWithoutRequest
// router whose procedures are served
router: TRouter
}
/**
 * Builds a transport-agnostic tRPC WebSocket handler, adapted from tRPC's `ws`
 * adapter. A server such as uWebSockets.js calls it from its own socket events.
 *
 * Wire the returned callbacks into the server's socket lifecycle:
 * - `onSocketOpen(ws, ctx)`: once per connection, before any message. `ctx` is
 *   the tRPC context for every procedure called over that socket.
 * - `onSocketMessage(ws, text)`: for each text frame, which holds one JSON-RPC
 *   message or a JSON array (batch) of them.
 * - `onSocketClose(ws, msg)`: when the socket closes. Stops all its subscriptions.
 *
 * After `onSocketClose`, nothing more is written to that socket. Writing to a
 * closed uWS socket throws, and a response can still be produced later by a
 * procedure that was awaiting when the socket closed.
 *
 * @param opts - router to serve and optional `onError` hook
 * @returns the lifecycle callbacks plus `broadcastReconnectNotification`
 */
export function applyWSHandler<TRouter extends AnyRouter, Ctx>(opts: WSHandlerOptions<TRouter>) {
  const { router } = opts;
  const { transformer } = router._def._config;
  // State of every open socket, keyed by socket object identity.
  const websockets: Map<MinimalWSInterface, { subscriptions: Map<number | string, Unsubscribable>, ctx: Ctx }> = new Map();
  // Sockets for which onSocketClose has run; responses to them are dropped.
  const closedSockets = new WeakSet<MinimalWSInterface>();

  const onSocketOpen = (ws: MinimalWSInterface, ctx: Ctx) => {
    websockets.set(ws, { subscriptions: new Map(), ctx });
  };

  /**
   * Handles one text frame. All messages in a batch are parsed before any is
   * executed. If the frame is not valid JSON or any message is malformed,
   * nothing runs and a single PARSE_ERROR response with `id: null` is sent.
   */
  const onSocketMessage = async (ws: MinimalWSInterface, stringifiedMessage: string) => {
    try {
      const msgJSON: unknown = JSON.parse(stringifiedMessage);
      const msgs: unknown[] = Array.isArray(msgJSON) ? msgJSON : [msgJSON];
      const promises = msgs
        .map((raw) => parseMessage(raw, transformer))
        .map((msg) => handleRequest(ws, msg));
      await Promise.all(promises);
    } catch (cause) {
      // Also reached when handleRequest rejects before calling a procedure
      // (missing `id`, or socket never registered via onSocketOpen).
      const error = new TRPCError({
        code: 'PARSE_ERROR',
        cause: getCauseFromUnknown(cause),
      });
      respond(ws, {
        id: null,
        error: router.getErrorShape({
          error,
          type: 'unknown',
          path: undefined,
          input: undefined,
          ctx: undefined,
        }),
      });
    }
  };

  const onSocketClose = (ws: MinimalWSInterface, msg: string) => {
    // Mark the socket closed first, so any response produced during teardown,
    // or later by in-flight procedures, is dropped instead of written.
    closedSockets.add(ws);
    const wsData = websockets.get(ws);
    if (!wsData)
      return;
    websockets.delete(ws);
    for (const sub of wsData.subscriptions.values()) {
      sub.unsubscribe();
    }
    wsData.subscriptions.clear();
  };

  /** Serializes a response with the router's transformer and sends it, unless the socket has closed. */
  function respond(ws: MinimalWSInterface, untransformedJSON: TRPCResponseMessage) {
    if (closedSockets.has(ws))
      return;
    const response = JSON.stringify(transformTRPCResponse(router, untransformedJSON));
    ws.send(response);
  }

  /** Unsubscribes `subscription` and tells the client that request `id` has stopped. */
  function stopSubscription(
    ws: MinimalWSInterface,
    subscription: Unsubscribable,
    { id, jsonrpc }: { id: JSONRPC2.RequestId } & JSONRPC2.BaseEnvelope,
  ) {
    subscription.unsubscribe();
    respond(ws, {
      id,
      jsonrpc,
      result: {
        type: 'stopped',
      },
    });
  }

  /**
   * Executes one parsed message. Queries and mutations get a single `data`
   * response. Subscriptions get `started`, then `data` per emission, then
   * `stopped` on completion or an error response. Procedure errors are
   * reported to `opts.onError` and sent as error responses.
   */
  const handleRequest = async (ws: MinimalWSInterface, msg: TRPCClientOutgoingMessage) => {
    if (!ws) {
      throw new Error('handler was called with undefined websocket instance');
    }
    const { id, jsonrpc } = msg;
    if (id === null) {
      throw new TRPCError({
        code: 'BAD_REQUEST',
        message: '`id` is required',
      });
    }
    const wsData = websockets.get(ws);
    if (!wsData) {
      throw new TRPCError({
        code: 'INTERNAL_SERVER_ERROR',
        message: 'websocket instance not found in the adapter/handler',
      });
    }
    const { ctx, subscriptions } = wsData;

    if (msg.method === 'subscription.stop') {
      const sub = subscriptions.get(id);
      if (sub) {
        stopSubscription(ws, sub, { id, jsonrpc });
      }
      subscriptions.delete(id);
      return;
    }

    const { path, input } = msg.params;
    const type = msg.method;
    try {
      const result = await callProcedure({
        procedures: router._def.procedures,
        path,
        rawInput: input,
        ctx,
        type,
      });

      if (type !== 'subscription') {
        // Queries and mutations answer with a single data message.
        respond(ws, {
          id,
          jsonrpc,
          result: {
            type: 'data',
            data: result,
          },
        });
        return;
      }

      if (!isObservable(result)) {
        throw new TRPCError({
          message: `Subscription ${path} did not return an observable`,
          code: 'INTERNAL_SERVER_ERROR',
        });
      }

      if (closedSockets.has(ws)) {
        // The socket closed while the procedure ran. onSocketClose has already
        // cleaned up, so a subscription created now would never be stopped.
        return;
      }

      const sub = result.subscribe({
        next(data) {
          respond(ws, {
            id,
            jsonrpc,
            result: {
              type: 'data',
              data,
            },
          });
        },
        error(err) {
          const error = getTRPCErrorFromUnknown(err);
          opts.onError?.({ error, path, type, ctx, input });
          respond(ws, {
            id,
            jsonrpc,
            error: router.getErrorShape({
              error,
              type,
              path,
              input,
              ctx,
            }),
          });
        },
        complete() {
          respond(ws, {
            id,
            jsonrpc,
            result: {
              type: 'stopped',
            },
          });
        },
      });

      if (subscriptions.has(id)) {
        // duplicate request ids for client
        stopSubscription(ws, sub, { id, jsonrpc });
        throw new TRPCError({
          message: `Duplicate id ${id}`,
          code: 'BAD_REQUEST',
        });
      }
      subscriptions.set(id, sub);
      respond(ws, {
        id,
        jsonrpc,
        result: {
          type: 'started',
        },
      });
    } catch (cause) {
      // procedure threw an error
      const error = getTRPCErrorFromUnknown(cause);
      opts.onError?.({ error, path, type, ctx, input });
      respond(ws, {
        id,
        jsonrpc,
        error: router.getErrorShape({
          error,
          type,
          path,
          input,
          ctx,
        }),
      });
    }
  };

  return {
    onSocketOpen,
    onSocketMessage,
    onSocketClose,
    /** Sends a JSON-RPC `reconnect` notification to every open socket. */
    broadcastReconnectNotification: () => {
      const response: TRPCReconnectNotification = {
        id: null,
        method: 'reconnect',
      };
      const data = JSON.stringify(response);
      for (const client of websockets.keys()) {
        client.send(data);
      }
    },
  };
}
/** Throws `Error('Not an object')` unless `obj` is a non-null, non-array object. */
function assertIsObject(obj: unknown): asserts obj is Record<string, unknown> {
  const isNonArrayObject = obj !== null && typeof obj === 'object' && !Array.isArray(obj);
  if (!isNonArrayObject) {
    throw new Error('Not an object');
  }
}
/** Throws `Error('Invalid procedure type')` unless `obj` is 'query', 'subscription' or 'mutation'. */
function assertIsProcedureType(obj: unknown): asserts obj is ProcedureType {
  switch (obj) {
    case 'query':
    case 'subscription':
    case 'mutation':
      return;
    default:
      throw new Error('Invalid procedure type');
  }
}
/**
 * Asserts that `obj` is a valid JSON-RPC request id: `null`, a string, or a
 * non-NaN number.
 *
 * The previous condition AND-ed `typeof obj === 'number'` with
 * `typeof obj !== 'string'`, so it only ever rejected NaN. `undefined`,
 * objects and booleans slipped through and were later used as subscription
 * map keys.
 *
 * @throws Error('Invalid request id') for any other value
 */
function assertIsRequestId(
  obj: unknown,
): asserts obj is number | string | null {
  const isValid =
    obj === null ||
    typeof obj === 'string' ||
    (typeof obj === 'number' && !Number.isNaN(obj));
  if (!isValid) {
    throw new Error('Invalid request id');
  }
}
/** Throws `Error('Invalid string')` unless `obj` is a string primitive. */
function assertIsString(obj: unknown): asserts obj is string {
  if (typeof obj === 'string') {
    return;
  }
  throw new Error('Invalid string');
}
/** Throws `Error('Must be JSONRPC 2.0')` unless `obj` is absent (undefined) or exactly '2.0'. */
function assertIsJSONRPC2OrUndefined(
  obj: unknown,
): asserts obj is '2.0' | undefined {
  if (obj === undefined || obj === '2.0') {
    return;
  }
  throw new Error('Must be JSONRPC 2.0');
}
/**
 * Extracts a human-readable message from a thrown value.
 *
 * @param err - any thrown value
 * @param fallback - message used when `err` is neither a string nor an Error
 * @returns `err` itself if it is a string, the Error's message, or `fallback`
 */
function getMessageFromUnkownError(
  err: unknown,
  fallback: string,
): string {
  if (typeof err === 'string') {
    return err;
  }
  return err instanceof Error && typeof err.message === 'string'
    ? err.message
    : fallback;
}
/**
 * Normalizes any thrown value to an Error. Error instances are returned
 * unchanged. Anything else is wrapped in a new Error whose message is the
 * string itself, or 'Unknown error'.
 */
function getErrorFromUnknown(cause: unknown): Error {
  return cause instanceof Error
    ? cause
    : new Error(getMessageFromUnkownError(cause, 'Unknown error'));
}
/**
 * Normalizes any thrown value to a TRPCError.
 *
 * Values recognized as TRPCError (by `name`) are returned unchanged. Anything
 * else becomes an INTERNAL_SERVER_ERROR that keeps the original error as
 * `cause`, along with its message and stack.
 */
function getTRPCErrorFromUnknown(cause: unknown): TRPCError {
  const normalized = getErrorFromUnknown(cause);
  // Ideally `instanceof TRPCError`, but that check is unreliable
  // (ref https://github.com/trpc/trpc/issues/331), so match on the name.
  if (normalized.name === 'TRPCError') {
    return cause as TRPCError;
  }
  const wrapped = new TRPCError({
    code: 'INTERNAL_SERVER_ERROR',
    cause: normalized,
    message: normalized.message,
  });
  // Keep the original stack so logs point at the real failure site.
  wrapped.stack = normalized.stack;
  return wrapped;
}
/** Returns `cause` when it is an Error (usable as an error `cause`), otherwise undefined. */
function getCauseFromUnknown(cause: unknown) {
  return cause instanceof Error ? cause : undefined;
}
/**
 * Serializes the payload of one response with the router's output transformer.
 * For error responses that is the `error` shape; otherwise it is `result.data`
 * when present. Responses without data (e.g. `started` / `stopped`) are
 * returned unchanged.
 */
function transformTRPCResponseItem<
  TResponseItem extends TRPCResponse | TRPCResponseMessage,
>(router: AnyRouter, item: TResponseItem): TResponseItem {
  const { output } = router._def._config.transformer;
  if ('error' in item) {
    return { ...item, error: output.serialize(item.error) };
  }
  if ('data' in item.result) {
    return {
      ...item,
      result: { ...item.result, data: output.serialize(item.result.data) },
    };
  }
  return item;
}
/**
 * Serializes a single unserialized response, or a batch of them, with the
 * router's output transformer. Arrays are mapped item by item; a single
 * response is transformed directly.
 */
function transformTRPCResponse<
  TResponse extends
    | TRPCResponse
    | TRPCResponse[]
    | TRPCResponseMessage
    | TRPCResponseMessage[],
>(router: AnyRouter, itemOrItems: TResponse) {
  if (!Array.isArray(itemOrItems)) {
    return transformTRPCResponseItem(router, itemOrItems);
  }
  return itemOrItems.map((item) => transformTRPCResponseItem(router, item));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment