Last active
December 12, 2020 00:57
-
-
Save capnmidnight/0c9b708325530f0996864cafebfcd750 to your computer and use it in GitHub Desktop.
Workerizing Async code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { WorkerClient } from "./WorkerClient"; | |
import { fibonacci as _fibonacci } from "./ExampleFunction"; | |
// We'll have a pool of 10 workers from which to run.
const worker = new WorkerClient("ExampleServer.bundle.js", "ExampleServer.bundle.min.js", 10);

// Derive the progress-callback type from the local implementation, so the
// worker-backed call and the UI-thread fallback cannot drift apart (and so this
// module does not need its own import of the callback type).
type FibonacciProgress = Parameters<typeof _fibonacci>[1];

/**
 * Computes a Fibonacci number on a pooled worker, falling back to the UI thread
 * when the worker client is not enabled.
 *
 * The name matches the local `fibonacci` in "./ExampleFunction" so that callers
 * can swap the import between the two modules.
 */
export const fibonacci = worker.createExecutor<number>(
    // Worker path: only forward the progress callback when one was supplied.
    (n: number, onProgress?: FibonacciProgress) => onProgress
        ? worker.execute<number>("fibonacci", [n], onProgress)
        : worker.execute<number>("fibonacci", [n]),
    // UI-thread fallback: wrapped in `async` so it yields a Promise like the worker path.
    async (n: number, onProgress?: FibonacciProgress) => _fibonacci(n, onProgress));

/**
 * @deprecated Misspelled alias kept so existing importers keep working; use `fibonacci`.
 */
export const fibonnaci = fibonacci;
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { progressCallback } from "./WorkerServer"; | |
/**
 * Computes the n-th Fibonacci number iteratively (F(0) = 0, F(1) = 1).
 *
 * Results are exact up to n = 78; beyond that the value exceeds
 * `Number.MAX_SAFE_INTEGER` and is subject to floating-point rounding.
 *
 * @param n - the zero-based index into the sequence; must be a non-negative integer.
 * @param onProgress - optional callback invoked as `onProgress(i, n)` after each
 * loop step, for i = 2..n. It is not invoked when n < 2.
 * @returns the n-th Fibonacci number.
 * @throws Error when n is negative, fractional, NaN, or infinite.
 */
export function fibonacci(n: number, onProgress?: progressCallback): number {
    if (n < 0) {
        throw new Error("Fibonacci sequence is not defined for negative numbers");
    }

    // NaN compares false against everything and would fall through to return 1;
    // Infinity would keep the loop below running forever. Reject both, along
    // with fractional indices, up front.
    if (!Number.isInteger(n)) {
        throw new Error("Fibonacci sequence is only defined for integers");
    }

    if (n < 2) {
        return n;
    }

    let previous = 0;
    let current = 1;
    for (let i = 2; i <= n; ++i) {
        const next = previous + current;
        previous = current;
        current = next;

        if (onProgress) {
            onProgress(i, n);
        }
    }

    return current;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// If we're careful, the workerized function can have the same, exact interface | |
// as the local version. We could swap these imports and have the same results, | |
// just not workerized. | |
import { fibonacci } from "./ExampleClient"; | |
// import { fibonacci } from "./ExampleFunction"; | |
/**
 * Runs `fibonacci(n)` and logs whether the result matches the expected value.
 *
 * A failed invocation is logged rather than propagated, so one failing case
 * does not surface as an unhandled promise rejection.
 *
 * @param n - the input to pass to `fibonacci`.
 * @param expected - the value `fibonacci(n)` should produce.
 */
async function test(n: number, expected: number): Promise<void> {
    try {
        const result = await fibonacci(n);
        const isGood = result === expected;
        console.log(`fibonacci(${n}) = ${result}. Expected: ${expected}. ${isGood ? "Yay!" : "Oh no!"}`);
    }
    catch (exp) {
        const reason = exp instanceof Error ? exp.message : String(exp);
        console.error(`fibonacci(${n}) failed. Expected: ${expected}. Reason: ${reason}`);
    }
}
// Expected results for fibonacci(0) through fibonacci(10); each index is the
// input and each element is the value that input should produce.
const expectedValues = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55];

// Start every check at once; the calls are independent of each other.
Promise.all(expectedValues.map((expected, n) => test(n, expected)));
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { WorkerServer } from "./WorkerServer"; | |
import { fibonacci } from "./ExampleFunction"; | |
// Inside a worker script, `globalThis` is the worker's own global scope; name
// it with the type WorkerServer's constructor expects before handing it over.
const workerScope = globalThis as unknown as DedicatedWorkerGlobalScope;
const server = new WorkerServer(workerScope);

// Register `fibonacci` under the method name "fibonacci" for cross-thread invocation.
server.add("fibonacci", fibonacci);
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { progressCallback, WorkerMethodMessages, WorkerMethodMessageType } from "./WorkerServer"; | |
/**
 * Type guard: true when `obj` is callable, judged either by `typeof` or by
 * being an instance of `Function`.
 */
function isFunction(obj: any): obj is Function {
    if (typeof obj === "function") {
        return true;
    }

    return obj instanceof Function;
}
/**
 * Type guard: true when `obj` is a primitive number (NaN included) or a boxed
 * `Number` object.
 */
function isNumber(obj: any): obj is number {
    const isPrimitive = typeof obj === "number";
    return isPrimitive ? true : obj instanceof Number;
}
/**
 * The shape of the functions handled by `WorkerClient.createExecutor`:
 * any parameters in, a promise of `T` out.
 */
export type workerClientCallback<T> = (...params: any[]) => Promise<T>;

/**
 * Dispatches named method invocations to a pool of lazily-created workers and
 * resolves each invocation from the response messages the worker posts back.
 */
export class WorkerClient {
    /** Whether a `Worker` constructor exists on the global scope. */
    static isSupported = "Worker" in globalThis;

    // Incremented once per invocation; doubles as the task ID and the round-robin index.
    private taskCounter = 0;
    // Sparse until used: a slot is only filled the first time a task lands on it.
    private workers: Worker[];
    private script: string;
    private _enabled = true;

    /** Whether invocations are sent to workers (true) or expected to use a fallback (false). */
    get enabled(): boolean {
        return this._enabled;
    }

    /** Enabling has no effect on systems where workers are not supported. */
    set enabled(v: boolean) {
        this._enabled = v && WorkerClient.isSupported;
    }

    /**
     * Creates a new pooled worker method executor.
     * @param scriptPath - the path to the unminified script to use for the worker
     * @param minScriptPath - the path to the minified script to use for the worker (optional)
     * @param workerPoolSize - the number of worker threads to create for the pool (defaults to 1)
     * @param disabled - optionally disable a worker to test UI-thread fallbacks (defaults to false)
     * @throws Error when the pool size is not a positive integer.
     */
    constructor(scriptPath: string, minScriptPath?: string | null, workerPoolSize?: number, disabled?: boolean);
    /**
     * Creates a new pooled worker method executor that has no minified script.
     * @param scriptPath - the path to the script to use for the worker
     * @param workerPoolSize - the number of worker threads to create for the pool
     */
    constructor(scriptPath: string, workerPoolSize: number);
    constructor(scriptPath: string, minScriptPath: string | number | null = null, workerPoolSize: number | null = null, disabled: boolean = false) {
        if (!WorkerClient.isSupported) {
            console.warn("Workers are not supported on this system.");
        }

        // Normalize constructor parameters: a number in the second position is
        // the pool size, not a script path. The pool size default is applied
        // below (rather than in the signature) so this shorthand can be detected.
        let minPath: string | null = null;
        let poolSize = workerPoolSize;
        if (isNumber(minScriptPath)) {
            if (!poolSize) {
                poolSize = minScriptPath;
            }
        }
        else {
            minPath = minScriptPath;
        }

        if (!poolSize) {
            poolSize = 1;
        }

        if (poolSize < 1 || !Number.isInteger(poolSize)) {
            throw new Error("Worker pool size must be a postive integer greater than 0");
        }

        // Workers can optionally be disabled, to exercise UI-thread fallbacks.
        this._enabled = WorkerClient.isSupported && !disabled;

        // Choose which version of the script we're going to load.
        if (process.env.NODE_ENV === "development") {
            this.script = scriptPath;
        }
        else {
            this.script = minPath || scriptPath;
        }

        this.workers = new Array<Worker>(poolSize);
    }

    /**
     * Execute a method on the worker thread.
     * @param methodName - the name of the method to execute.
     * @param params - the parameters to pass to the method.
     */
    execute<T>(methodName: string, params: any[]): Promise<T>;
    /**
     * Execute a method on the worker thread.
     * @param methodName - the name of the method to execute.
     * @param params - the parameters to pass to the method.
     * @param transferables - any values in any of the parameters that should be transfered instead of copied to the worker thread.
     */
    execute<T>(methodName: string, params: any[], transferables: Transferable[]): Promise<T>;
    /**
     * Execute a method on the worker thread.
     * @param methodName - the name of the method to execute.
     * @param params - the parameters to pass to the method.
     * @param onProgress - a callback for receiving progress reports on long-running invocations.
     */
    execute<T>(methodName: string, params: any[], onProgress: progressCallback): Promise<T>;
    /**
     * Execute a method on the worker thread.
     * @param methodName - the name of the method to execute.
     * @param params - the parameters to pass to the method.
     * @param transferables - any values in any of the parameters that should be transfered instead of copied to the worker thread.
     * @param onProgress - a callback for receiving progress reports on long-running invocations.
     */
    execute<T>(methodName: string, params: any[], transferables: Transferable[], onProgress: progressCallback): Promise<T>;
    execute<T>(methodName: string, params: any[], transferables: any = null, onProgress: any = null): Promise<T> {
        if (!WorkerClient.isSupported) {
            return Promise.reject(new Error("Workers are not supported on this system."));
        }

        if (!this.enabled) {
            console.warn("Workers invocations have been disabled.");
            return Promise.resolve(undefined);
        }

        // Normalize method parameters: a function in the third position is the
        // progress callback, not a list of transferables.
        if (isFunction(transferables)
            && !onProgress) {
            onProgress = transferables;
            transferables = null;
        }

        // taskIDs help us keep track of return values.
        const taskID = this.taskCounter++;

        // Workers are pooled, so the modulus selects them in a round-robin fashion.
        const workerID = taskID % this.workers.length;

        // Workers are lazily created.
        if (!this.workers[workerID]) {
            this.workers[workerID] = new Worker(this.script);
        }

        const worker = this.workers[workerID];

        // NOTE(review): only "message" events are observed here. A worker that
        // never posts a response for this task leaves the promise pending.
        return new Promise((resolve, reject) => {
            // When the invocation is complete, we want to stop listening to the worker
            // message channel, so we don't spend time processing messages that have no
            // chance of pertaining to this invocation.
            const cleanup = () => {
                worker.removeEventListener("message", dispatchMessageResponse);
            };

            const dispatchMessageResponse = (evt: MessageEvent<WorkerMethodMessages>) => {
                const data = evt.data;

                // Did this response message match the current invocation?
                if (data.taskID === taskID) {
                    switch (data.methodName) {
                        case WorkerMethodMessageType.Progress:
                            // Progress does not end the invocation, so keep listening.
                            if (isFunction(onProgress)) {
                                onProgress(data.soFar, data.total, data.msg);
                            }
                            break;
                        case WorkerMethodMessageType.Return:
                            cleanup();
                            resolve(undefined);
                            break;
                        case WorkerMethodMessageType.ReturnValue:
                            cleanup();
                            resolve(data.returnValue);
                            break;
                        case WorkerMethodMessageType.Error:
                            cleanup();
                            reject(new Error(`${methodName} failed. Reason: ${data.errorMessage}`));
                            break;
                        default:
                            cleanup();
                            reject(new Error(`${methodName} failed. Reason: unknown response message type.`));
                    }
                }
            };

            worker.addEventListener("message", dispatchMessageResponse);

            const message = {
                taskID,
                methodName,
                params
            };

            if (transferables) {
                worker.postMessage(message, transferables);
            }
            else {
                worker.postMessage(message);
            }
        });
    }

    /**
     * Creates a function that can optionally choose to invoke either the provided
     * worker method, or a UI-thread fallback, if this worker dispatcher is not enabled.
     * The choice is made on every call, so toggling `enabled` affects later calls.
     * @param workerCall - the function to call while this client is enabled.
     * @param localCall - the function to call otherwise; it may return its value
     * directly or as a promise.
     * @returns a function that forwards its parameters to whichever call applies.
     */
    createExecutor<T>(workerCall: workerClientCallback<T>, localCall: (...params: any[]) => T | Promise<T>): workerClientCallback<T> {
        return async (...params: any[]) => {
            if (this.enabled) {
                return await workerCall(...params);
            }
            else {
                return await localCall(...params);
            }
        };
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * A callback for reporting progress through long procedures.
 *
 * `est` is never supplied by the code in this file; presumably it is an
 * estimate of some kind (TODO confirm against other callers).
 */
export type progressCallback = (
    soFar: number,
    total: number,
    message?: string,
    est?: number
) => void;

/**
 * A method as stored by the worker server: it receives the invocation's task
 * ID ahead of the caller's own parameters.
 */
export type workerServerMethod = (
    taskID: number,
    ...params: any[]
) => Promise<void>;

/**
 * Given a method's return value, lists the values to pass as transferables
 * when that return value is posted.
 */
export type workerServerCreateTransferableCallback = (
    returnValue: any
) => Transferable[];
/**
 * The kinds of message posted in response to a method invocation. The value is
 * carried in the response's `methodName` field.
 */
export enum WorkerMethodMessageType {
    /** The invocation failed. */
    Error = "error",

    /** A progress report; the invocation is still running. */
    Progress = "progress",

    /** The invocation completed without a value. */
    Return = "return",

    /** The invocation completed with a value. */
    ReturnValue = "returnValue"
}
/**
 * Fields shared by every response message. `methodName` holds the message
 * kind and is the discriminant of the `WorkerMethodMessages` union.
 */
interface WorkerMethodMessage<T extends WorkerMethodMessageType> {
    /** The ID of the invocation this message responds to. */
    taskID: number;
    methodName: T;
}

/** Reports that an invocation failed. */
export interface WorkerMethodErrorMessage
    extends WorkerMethodMessage<WorkerMethodMessageType.Error> {
    errorMessage: string;
}

/** Reports progress through an invocation that is still running. */
export interface WorkerMethodProgressMessage
    extends WorkerMethodMessage<WorkerMethodMessageType.Progress> {
    soFar: number;
    total: number;
    // The server posts an optional string here, and the client forwards it as
    // the progress callback's message, so it must not be typed as a number.
    msg?: string;
}

/** Reports that an invocation completed without a value. */
export interface WorkerMethodReturnMessage
    extends WorkerMethodMessage<WorkerMethodMessageType.Return> {
}

/** Reports that an invocation completed, carrying its value. */
export interface WorkerMethodReturnValueMessage
    extends WorkerMethodMessage<WorkerMethodMessageType.ReturnValue> {
    returnValue: any
}

/** Every response message, discriminated by `methodName`. */
export type WorkerMethodMessages = WorkerMethodErrorMessage
    | WorkerMethodProgressMessage
    | WorkerMethodReturnMessage
    | WorkerMethodReturnValueMessage;

/** The request message that asks for a registered method to be invoked. */
export interface WorkerMethodCallMessage {
    /** Echoed back on every response so the caller can match it to this request. */
    taskID: number;
    /** The name the method was registered under. */
    methodName: string;
    /** The parameters to pass to the method. */
    params: any[];
}
/**
 * Listens for method-call messages on a worker scope, runs the matching
 * registered function, and posts progress, results, and errors back.
 */
export class WorkerServer {
    private methods = new Map<string, workerServerMethod>();

    /**
     * Creates a new worker thread method call listener.
     * Note that this assigns `onmessage` on the scope, replacing any handler already set there.
     * @param self - the worker scope in which to listen.
     */
    constructor(private self: DedicatedWorkerGlobalScope) {
        this.self.onmessage = (evt: MessageEvent<WorkerMethodCallMessage>): void => {
            const data = evt.data;
            const method = this.methods.get(data.methodName);
            if (method) {
                // Tolerate a request that omits `params`, rather than throwing
                // out of the handler without ever answering the caller.
                method(data.taskID, ...(data.params || []));
            }
            else {
                this.onError(data.taskID, "method not found: " + data.methodName);
            }
        };
    }

    /**
     * Report an error back to the calling thread.
     * @param taskID - the invocation ID of the method that errored.
     * @param errorMessage - what happened?
     */
    private onError(taskID: number, errorMessage: string): void {
        this.self.postMessage({
            taskID,
            methodName: WorkerMethodMessageType.Error,
            errorMessage
        });
    }

    /**
     * Report progress through long-running invocations.
     * @param taskID - the invocation ID of the method that is updating.
     * @param soFar - how much of the process we've gone through.
     * @param total - the total amount we need to go through.
     * @param msg - an optional message to include as part of the progress update.
     */
    private onProgress(taskID: number, soFar: number, total: number, msg?: string): void {
        this.self.postMessage({
            taskID,
            methodName: WorkerMethodMessageType.Progress,
            soFar,
            total,
            msg
        });
    }

    /**
     * Return the results back to the invoker.
     * @param taskID - the invocation ID of the method that has completed.
     * @param returnValue - the (optional) value that is being returned.
     * @param transferables - an (optional) array of values that appear in the return value that should be transfered back to the calling thread, rather than copied.
     */
    private onReturn(taskID: number, returnValue?: any, transferables?: Transferable[]): void {
        if (returnValue === undefined) {
            this.self.postMessage({
                taskID,
                methodName: WorkerMethodMessageType.Return
            });
        }
        else if (transferables === undefined) {
            this.self.postMessage({
                taskID,
                methodName: WorkerMethodMessageType.ReturnValue,
                returnValue
            });
        }
        else {
            this.self.postMessage({
                taskID,
                methodName: WorkerMethodMessageType.ReturnValue,
                returnValue
            }, transferables);
        }
    }

    /**
     * Registers a function call for cross-thread invocation.
     * Registering a name a second time replaces the earlier function.
     * @param methodName - the name of the method to use during invocations.
     * @param asyncFunc - the function to execute when the method is invoked. It is called with the
     * invocation's parameters followed by a progress callback as one extra, final argument.
     * @param transferReturnValue - an (optional) function that reports on which values in the `returnValue` should be transfered instead of copied.
     */
    add(methodName: string, asyncFunc: Function, transferReturnValue: workerServerCreateTransferableCallback | null = null): void {
        this.methods.set(methodName, async (taskID: number, ...params: any[]): Promise<void> => {
            // If your invocable functions don't report progress, this can be safely ignored.
            const onProgress: progressCallback = (soFar: number, total: number, msg?: string) => {
                this.onProgress(taskID, soFar, total, msg);
            };

            try {
                // Even functions returning void and functions returning bare, unPromised values, can be awaited.
                // This creates a convenient fallback where we don't have to consider the exact return type of the function.
                const returnValue = await asyncFunc(...params, onProgress);

                // `onReturn` itself picks the value-less message when `returnValue` is
                // undefined, so only the transferable lookup needs a guard here.
                if (returnValue !== undefined && transferReturnValue) {
                    this.onReturn(taskID, returnValue, transferReturnValue(returnValue));
                }
                else {
                    this.onReturn(taskID, returnValue);
                }
            }
            catch (exp) {
                // Anything can be thrown, not just Error instances; make sure the
                // caller always receives a usable reason.
                const reason = exp instanceof Error ? exp.message : String(exp);
                this.onError(taskID, reason);
            }
        });
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment