Skip to content

Instantly share code, notes, and snippets.

@jelmervdl
Created March 25, 2022 17:52
Show Gist options
  • Save jelmervdl/7dc651fc53889d016261eaa0b3f30db8 to your computer and use it in GitHub Desktop.
Save jelmervdl/7dc651fc53889d016261eaa0b3f30db8 to your computer and use it in GitHub Desktop.
const {Blob} = require('buffer');
const fs = require('fs');
const https = require('https');
// Directory holding the bergamot-translator-worker .wasm/.js build output.
// It is a relative path, so fs resolves it against the process's current
// working directory (not this file's location). The value already ends in '/',
// and the template strings that use it add another '/', so the resulting
// paths contain a doubled separator.
const wasmRoot = '../build-wasm/';
// Read the compiled wasm synchronously so it can be handed over via `Module`.
const wasmBinary = fs.readFileSync(`${wasmRoot}/bergamot-translator-worker.wasm`);
// Pre-populated Module object for the worker script: the wasm bytes plus the
// `onRuntimeInitialized` callback (a hoisted function declaration defined
// further down in this file).
const Module = {wasmBinary, onRuntimeInitialized};
// Publish it as a global property: the worker script is evaluated as global
// code and therefore cannot see this file's local `Module` binding.
global.Module = Module;
/**
* Dirty bit of code that hooks into bergamot-translator-worker to monkey-patch
* fallbackGemm just before createWasm() is called in there.
*/
// Chronological record of every GEMM call that goes through the logging
// wrappers; each entry has the shape {key, args}.
const log = [];

// One-shot trap for the first assignment to Module.setDelayFunction.
// The accessor is configurable so that its own setter can replace it with a
// plain data property holding the assigned value (only `value` is supplied,
// so that data property ends up non-writable) before calling patch().
const setDelayFunctionTrap = {
  configurable: true,
  set(value) {
    Object.defineProperty(Module, 'setDelayFunction', {value});
    patch();
  },
};
Object.defineProperty(Module, 'setDelayFunction', setDelayFunctionTrap);
/**
 * Installs `global.createWasmGemm`, a factory that returns one logging wrapper
 * per intgemm entry point.
 *
 * Each wrapper appends `{key, args}` to the file-level `log` array and then
 * forwards the call to the matching fallback function on `Module.asm`,
 * returning whatever that fallback returns. The fallback is looked up at call
 * time, so `Module.asm` does not need to exist yet when patch() runs.
 */
function patch() {
  const FALLBACK_GEMM = "asm";

  // intgemm import name -> name of the fallback implementation on Module.asm
  const GEMM_TO_FALLBACK_FUNCTIONS_MAP = {
    "int8_prepare_a": "int8PrepareAFallback",
    "int8_prepare_b": "int8PrepareBFallback",
    "int8_prepare_b_from_transposed": "int8PrepareBFromTransposedFallback",
    "int8_prepare_b_from_quantized_transposed": "int8PrepareBFromQuantizedTransposedFallback",
    "int8_prepare_bias": "int8PrepareBiasFallback",
    "int8_multiply_and_add_bias": "int8MultiplyAndAddBiasFallback",
    "int8_select_columns_of_b": "int8SelectColumnsOfBFallback"
  };

  // Builds the wrapper for a single entry point: record first, then delegate.
  const makeLoggingWrapper = (key, fallback) => (...args) => {
    log.push({key, args});
    return Module[FALLBACK_GEMM][fallback](...args);
  };

  global.createWasmGemm = () => {
    console.log('Using logging fallback GEMM');
    const gemm = {};
    for (const [key, fallback] of Object.entries(GEMM_TO_FALLBACK_FUNCTIONS_MAP)) {
      gemm[key] = makeLoggingWrapper(key, fallback);
    }
    return gemm;
  };
}
/**
* End monkey patch
*/
// Execute bergamot-translator-worker.js in the global scope.
// Invoking eval through `.call` makes this an *indirect* eval: the script is
// evaluated as global code, not inside this file's local scope. It therefore
// cannot see the local `Module` constant, which is why `Module` is also
// published on `global` above and why patch() installs `createWasmGemm` on
// `global` as well.
// NOTE(review): eval runs the file's contents with this process's full
// privileges; that is only acceptable because the script is read from the
// local build directory. Do not point wasmRoot at untrusted content.
const js = fs.readFileSync(`${wasmRoot}/bergamot-translator-worker.js`, {encoding:'utf8'});
eval.call(global, js);
/**
 * Helper to download file into ArrayBuffer.
 *
 * @param {string} url - HTTPS URL to fetch.
 * @returns {Promise<ArrayBuffer>} Resolves with the complete response body.
 *   Rejects when the request itself fails (DNS, TLS, connection errors), when
 *   the response stream errors, or when the server answers with a non-2xx
 *   status. Redirects are not followed and are reported as failures.
 */
function download(url) {
  return new Promise((accept, reject) => {
    const request = https.get(url, (res) => {
      const {statusCode} = res;
      if (statusCode < 200 || statusCode >= 300) {
        // Drain the body so the socket is released, then fail loudly instead
        // of handing an error page to the caller as if it were model data.
        res.resume();
        reject(new Error(`Download of ${url} failed: HTTP ${statusCode}`));
        return;
      }

      const chunks = [];
      res.on('error', reject);
      res.on('data', (chunk) => chunks.push(chunk));
      res.on('end', () => {
        // Blob concatenates the chunks and yields a standalone ArrayBuffer.
        new Blob(chunks).arrayBuffer().then(accept, reject);
      });
    });

    // Without this listener a connection-level failure is emitted on the
    // request as an unhandled 'error' event and the promise never settles.
    request.on('error', reject);
  });
}
/**
 * Copies an ArrayBuffer into a freshly allocated Module.AlignedMemory.
 *
 * @param {ArrayBuffer} buffer - Bytes to copy.
 * @param {number} alignment - Passed through as the second argument of the
 *   AlignedMemory constructor.
 * @returns {Module.AlignedMemory} A new AlignedMemory of `buffer`'s byte
 *   length whose byte view has been filled with the buffer's contents.
 */
function load(buffer, alignment) {
  const source = new Int8Array(buffer);
  const aligned = new Module.AlignedMemory(source.byteLength, alignment);
  const target = aligned.getByteArrayView();
  target.set(source);
  return aligned;
}
/**
 * Called from inside the worker.js script once the wasm module is loaded
 * and all the emscripten magic and linking has been done.
 *
 * Downloads the en->de model files, translates a single sentence, prints the
 * translation, and finally dumps the recorded GEMM calls (`log`) to
 * 'intgemm-calls.json' in the current working directory.
 *
 * @returns {Promise<void>} Settles once the log file has been written; rejects
 *   if any download or translation step throws.
 */
async function onRuntimeInitialized() {
  // Root url for our models for now.
  const root = 'https://storage.googleapis.com/bergamot-models-sandbox/0.2.14';

  // Listed in the order TranslationModel's constructor takes them below:
  // model, shortlist, vocab.
  const files = [
    {url: `${root}/ende/model.ende.intgemm.alphas.bin`, alignment: 256},
    {url: `${root}/ende/lex.50.50.ende.s2t.bin`, alignment: 64},
    {url: `${root}/ende/vocab.deen.spm`, alignment: 64},
  ];

  // Start every download right away; each buffer is copied into aligned
  // memory as soon as its own download completes.
  const pendingMemories = files.map(async ({url, alignment}) => {
    const buffer = await download(url);
    return load(buffer, alignment);
  });
  const [modelMem, shortlistMem, vocabMem] = await Promise.all(pendingMemories);

  // Config yaml, one entry per line (kept as an array so the source can be
  // indented without leaking whitespace into the yaml itself).
  const config = [
    "beam-size: 1",
    "normalize: 1.0",
    "word-penalty: 0",
    "max-length-break: 128",
    "mini-batch-words: 1024",
    "workspace: 128",
    "max-length-factor: 2.0",
    "skip-cost: true",
    "cpu-threads: 0",
    "quiet: true",
    "quiet-translation: true",
    "gemm-precision: int8shiftAll",
  ].join("\n");

  // Translation service.
  const service = new Module.BlockingService({cacheSize: 0});

  // The vocab goes into its own std::vector<AlignedMemory>.
  const vocabs = new Module.AlignedMemoryList();
  vocabs.push_back(vocabMem);

  // Model from config yaml plus the AlignedMemory objects; the fifth
  // argument is passed as null.
  const model = new Module.TranslationModel(config, modelMem, shortlistMem, vocabs, null);

  // std::vector<std::string> holding the batch (of 1) to translate.
  const input = new Module.VectorString();
  input.push_back("Hello world!");

  // std::vector<ResponseOptions>, one entry per input.
  const options = new Module.VectorResponseOptions();
  options.push_back({qualityScores: false, alignment: false, html: false});

  // Translate and print the first (only) entry of std::vector<Response>.
  const output = service.translate(model, input, options);
  const translation = output.get(0).getTranslatedText();
  console.log(translation);

  // Clean-up of the per-request vectors.
  // NOTE(review): service, model, vocabs and the memory blocks are not
  // deleted here.
  for (const handle of [input, options, output]) {
    handle.delete();
  }

  fs.writeFileSync('intgemm-calls.json', JSON.stringify(log, null, 2));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment