Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
AWS AppSync subscription support for subscriptions-transport-ws
import { WebSocketLink } from 'apollo-link-ws';
import { Auth } from 'aws-amplify';
import { print } from 'graphql/language/printer';
import * as url from 'url';
import { AWSSubscriptionClient } from './AWSSubscriptionClient';
/**
 * Create an Apollo `WebSocketLink` that subscribes through the AWS AppSync
 * realtime endpoint using `subscriptions-transport-ws`.
 *
 * Too many issues were found when using the AWS-provided
 * realtime-subscription-handshake-link, and that library is poorly
 * maintained. So this uses the community-driven `WebSocketLink` and adds
 * middleware that authenticates against the AppSync realtime endpoint. The
 * header and payload query string are generated the way AWS does it:
 * https://github.com/awslabs/aws-mobile-appsync-sdk-js/blob/master/packages/aws-appsync-subscription-link/src/realtime-subscription-handshake-link.ts
 *
 * @param customRealtimeEndpoint - Endpoint converted into the WebSocket URL the client connects to.
 * @param defaultGraphqlEndpoint - AppSync GraphQL endpoint. Only its host is used; it is sent in the auth header.
 * @returns A link whose client re-signs its URL and stored operations and reconnects after a 401.
 */
const createWebsocketLink = async (
  customRealtimeEndpoint: string,
  defaultGraphqlEndpoint: string
): Promise<WebSocketLink> => {
  const { host } = url.parse(defaultGraphqlEndpoint);

  // Runs before each operation is sent. It serializes the GraphQL request into
  // `data` and attaches the auth header under `extensions.authorization`.
  const middleware = {
    applyMiddleware: async (options: any, next: (error?: any) => void) => {
      try {
        if (options.query) {
          const header = await generateSubscriptionHeader({ host });
          options.data = JSON.stringify({
            query:
              typeof options.query === 'string'
                ? options.query
                : print(options.query),
            variables: options.variables,
          });
          options.extensions = {
            authorization: {
              ...header,
            },
          };
        }
      } catch (error) {
        // Middleware completion is signalled only through `next`. A failure
        // (e.g. no valid auth session) must be handed to it explicitly.
        // Otherwise this async function's rejection is dropped and the
        // operation never proceeds.
        next(error);
        return;
      }
      // Kept outside the try so that an exception from a later middleware is
      // not reported a second time through next(error).
      next();
    },
  };

  const websocketUrl: string = await getAppsyncWebSocketUrl(
    host as string,
    customRealtimeEndpoint
  );

  let subscription: AWSSubscriptionClient;

  // Receives connection-level results from the client. On a 401, the token
  // embedded in the socket URL (and possibly in the stored operation options)
  // is stale. Re-sign both, then drop the socket so the client reconnects
  // with the new values.
  const connectionCallback = async (message: any) => {
    const errors = message ? message.errors : undefined;
    const error = errors && errors.length > 0 ? errors[0] : undefined;
    if (!error || error.errorCode !== 401 || !subscription) {
      return;
    }

    subscription.setUrl(
      await getAppsyncWebSocketUrl(host as string, customRealtimeEndpoint)
    );

    // Reapply the middleware to every stored operation's options, since they
    // could carry an invalid token. Use `for...of` over the ids:
    // `for...in` over the array from Object.keys yields its indices
    // ("0", "1", ...), which are not operation ids.
    for (const id of Object.keys(subscription.operations)) {
      const operation = subscription.operations[id];
      if (operation) {
        operation.options = await subscription.applyMiddlewares(
          operation.options
        );
      }
    }

    // Force close after a 401. This auto-reconnects because `reconnect: true`
    // is set on the client options below.
    subscription.close(false, false);
  };

  subscription = new AWSSubscriptionClient(websocketUrl, {
    reconnect: true,
    timeout: 5 * 60 * 1000, // 5 minutes (ms)
    connectionCallback,
  });

  const wsLink = new WebSocketLink(subscription);
  // `subscriptionClient` is presumably not exposed by WebSocketLink's typings
  // (hence the suppression). The AppSync middleware is registered on it directly.
  // @ts-ignore
  wsLink.subscriptionClient.use([middleware]);
  return wsLink;
};
/**
 * Build the authorization object attached to each subscription operation.
 * It pairs the current Amplify Auth session's access-token JWT with the
 * given AppSync GraphQL host.
 *
 * @param host - Host of the AppSync GraphQL endpoint.
 * @returns `{ Authorization, host }`.
 */
const generateSubscriptionHeader = async ({ host }): Promise<any> => {
  const session = await Auth.currentSession();
  const jwtToken = session.getAccessToken().getJwtToken();
  return { Authorization: jwtToken, host };
};
/**
 * Build the signed WebSocket URL for the AppSync realtime endpoint.
 *
 * The handshake credentials travel as two base64-encoded query parameters:
 * - `header`: JSON containing the current session's access-token JWT and the
 *   GraphQL host.
 * - `payload`: an empty JSON object.
 *
 * @param internalGraphqlHost - Host of the AppSync GraphQL endpoint, placed in the header.
 * @param realtimeEndpoint - Endpoint converted into the base WebSocket URL.
 * @returns `<websocket url>?header=<base64>&payload=<base64>`.
 */
const getAppsyncWebSocketUrl = async (
  internalGraphqlHost: string,
  realtimeEndpoint: string
): Promise<string> => {
  const toBase64 = (text: string): string =>
    Buffer.from(text).toString('base64');

  const session = await Auth.currentSession();
  const encodedHeader = toBase64(
    JSON.stringify({
      Authorization: session.getAccessToken().getJwtToken(),
      host: internalGraphqlHost,
    })
  );
  const encodedPayload = toBase64('{}');

  const baseUrl = await convertRealtimeEndpoint(realtimeEndpoint);
  return `${baseUrl}?header=${encodedHeader}&payload=${encodedPayload}`;
};
/**
 * Convert an AppSync endpoint URL into its realtime WebSocket form.
 *
 * - A leading `https://` (any case) becomes `wss://`; a leading `http://`
 *   becomes `ws://`. Previously only `https://` was rewritten, leaving plain
 *   HTTP endpoints with a non-WebSocket scheme.
 * - The first `appsync-api` becomes `appsync-realtime-api`.
 * - The first `gogi-beta` becomes `grt-beta`.
 *
 * Kept `async` so existing callers that `await` it are unaffected.
 *
 * @param endpoint - AppSync endpoint URL.
 * @returns The WebSocket URL, without a query string.
 */
const convertRealtimeEndpoint = async (endpoint: string): Promise<string> => {
  return endpoint
    .replace(/^https:\/\//i, 'wss://')
    .replace(/^http:\/\//i, 'ws://')
    .replace('appsync-api', 'appsync-realtime-api')
    .replace('gogi-beta', 'grt-beta');
};
import { uniqBy } from 'lodash';
import { ClientOptions, SubscriptionClient } from 'subscriptions-transport-ws';
/**
 * `SubscriptionClient` adapted to the AWS AppSync realtime protocol.
 *
 * - `setUrl` replaces the URL used for the next (re)connection, so a freshly
 *   signed URL can be installed before reconnecting.
 * - Queued (unsent) messages are not discarded when flushed. Each one is
 *   removed only when AppSync answers with a `start_ack` for its id.
 * - `start_ack` frames are consumed here instead of being passed to the base
 *   message handler.
 */
export class AWSSubscriptionClient extends SubscriptionClient {
  constructor(
    url: string,
    options?: ClientOptions,
    webSocketImpl?: any,
    webSocketProtocols?: string | string[]
  ) {
    super(url, options, webSocketImpl, webSocketProtocols);
    // Since we are in TS and these base functions are private, we cannot
    // declare overrides in this child class. Instead we shadow them with
    // instance properties. This is not safe: it relies on the base class
    // calling them through `this` under exactly these names.
    this['flushUnsentMessagesQueue'] = this.flush;
    this['processReceivedData'] = this.process;
  }

  /**
   * Set the URL used by the next connection attempt.
   *
   * Assigns through `this` rather than `super`. When compiled to ES5,
   * `super['url'] = url` is emitted as an assignment on the parent class's
   * prototype. The instance's own `url` property shadows that assignment, so
   * the new URL would be ignored.
   */
  public setUrl(url: string): void {
    this['url'] = url;
  }

  /** The base class's unsent-message queue, or an empty array if it is unset. */
  public getUnsentMessagesQueue(): any[] {
    return this['unsentMessagesQueue'] || [];
  }

  /** Replace the base class's unsent-message queue. */
  public setUnsentMessagesQueue(queue: any[]): void {
    this['unsentMessagesQueue'] = queue;
  }

  /**
   * Stand-in for the base `flushUnsentMessagesQueue`. Sends each queued
   * message (one per `id`) but leaves the queue itself untouched. Entries
   * stay queued until `process` sees their `start_ack`.
   */
  private flush() {
    // NOTE(review): uniqBy keeps the FIRST message for each id, so a later
    // message queued under the same id is not sent by this flush. Confirm
    // that this is intended.
    const messages = uniqBy(this.getUnsentMessagesQueue(), 'id');
    messages.forEach(message => {
      super['sendMessageRaw'](message);
    });
  }

  /**
   * Stand-in for the base `processReceivedData`. On AppSync's `start_ack`,
   * which isn't treated as a valid GQL message type, the acknowledged message
   * is removed from the unsent queue. Every other frame goes to the base
   * implementation.
   */
  private process(receivedData: any) {
    try {
      const message = JSON.parse(receivedData);
      if (message.type === 'start_ack') {
        this.setUnsentMessagesQueue(
          this.getUnsentMessagesQueue().filter(el => el.id !== message.id)
        );
        return;
      }
      super['processReceivedData'](receivedData);
    } catch (err) {
      // Deliberately ignored: unparseable frames and anything thrown by the
      // base handler are dropped.
      // NOTE(review): this presumably also hides errors raised while results
      // are delivered to subscribers. Consider at least logging them.
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment