Custom Layers in TensorFlow.js (Stateful, Configurable, and Serializable)
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
/**
 * Define a custom layer.
 *
 * This layer performs the following simple operation:
 *   output = input * (x ^ alpha);
 * where:
 * - x is a trainable scalar weight.
 * - alpha is a configurable constant.
 *
 * This custom layer is written in a way that can be saved and loaded.
 */
class TimesXToThePowerOfAlphaLayer extends tf.layers.Layer {
  constructor(config) {
    super(config);
    this.alpha = config.alpha;
  }

  /**
   * build() is called when the custom layer object is connected to an
   * upstream layer for the first time.
   * This is where the weights (if any) are created.
   */
  build(inputShape) {
    this.x = this.addWeight('x', [], 'float32', tf.initializers.ones());
  }

  /**
   * call() contains the actual numerical computation of the layer.
   *
   * It is "tensor-in-tensor-out", i.e., it receives one or more
   * tensors as the input and should produce one or more tensors as
   * the return value.
   *
   * Be sure to use tidy() to avoid WebGL memory leaks.
   */
  call(input) {
    return tf.tidy(() => {
      const k = tf.pow(this.x.read(), this.alpha);
      return tf.mul(input[0], k);
    });
  }

  /**
   * getConfig() generates the JSON object that is used
   * when saving and loading the custom layer object.
   */
  getConfig() {
    const config = super.getConfig();
    Object.assign(config, {alpha: this.alpha});
    return config;
  }

  /**
   * The static className getter is required by the
   * registration step (see below).
   */
  static get className() {
    return 'TimesXToThePowerOfAlphaLayer';
  }
}
/**
 * Register the custom layer, so TensorFlow.js knows what class constructor
 * to call when deserializing a saved instance of the custom layer.
 */
tf.serialization.registerClass(TimesXToThePowerOfAlphaLayer);
(async function main() {
  const model = tf.sequential();
  model.add(tf.layers.dense({units: 1, inputShape: [4]}));
  // Here comes an instance of the custom layer.
  model.add(new TimesXToThePowerOfAlphaLayer({alpha: 1.5}));
  model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
  model.summary();

  // Train the model using some random data.
  const xs = tf.randomNormal([2, 4]);
  const ys = tf.randomNormal([2, 1]);
  await model.fit(xs, ys, {
    epochs: 5,
    callbacks: {
      onEpochEnd: async (epoch, logs) => {
        console.log(`Epoch ${epoch}: loss = ${logs.loss}`);
      }
    }
  });

  // Save the model and load it back.
  await model.save('indexeddb://codepen-tfjs-model-example-jdBgwB-v1');
  console.log('Model saved.');
  const model2 = await tf.loadModel('indexeddb://codepen-tfjs-model-example-jdBgwB-v1');
  console.log('Model2 loaded.');

  console.log('The two predict() outputs should be identical:');
  model.predict(xs).print();
  model2.predict(xs).print();
})();
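A minimal follow-up sketch (not part of the original gist): reload the saved model in a separate step and confirm that the custom layer's configuration and trained weight survived the save/load round trip. The layer index [1] assumes the two-layer model built above, and tf.loadModel is used as in the gist (tf.loadLayersModel in tfjs 1.x+).

(async function verifyRoundTrip() {
  const restored = await tf.loadModel('indexeddb://codepen-tfjs-model-example-jdBgwB-v1');
  const customLayer = restored.layers[1];        // the TimesXToThePowerOfAlphaLayer
  console.log('Restored alpha:', customLayer.alpha);  // should print 1.5
  customLayer.getWeights()[0].print();                // the trained scalar weight x
})();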
@BenjaminWegener

https://jsfiddle.net/Lks5wq6a/ updated for tfjs 1.5
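For reference, the main change when moving the gist above to tfjs 1.x is that tf.loadModel was renamed to tf.loadLayersModel; the rest of the gist should work as-is. A one-line sketch, using the same IndexedDB key as above:

const model2 = await tf.loadLayersModel('indexeddb://codepen-tfjs-model-example-jdBgwB-v1');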

@ierezell

Hi,
Thanks a lot for this example!

Do you know if it's possible to use tf.layers inside a custom layer?
I tried, but the weights of a tf.layers.dense are not counted as trainable (0 parameters in model.summary()).

A fix would be to call addWeight for each dense layer's weights in the build method, but that's a bit heavy...

Does anyone have an idea about that? If yes, it would allow me to have a transformer layer to share... (I will still share the addWeight version; see the sketch below.)

Thanks in advance for any help,
Have a great day
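For what it's worth, here is a minimal sketch of the addWeight-based workaround mentioned above: the dense transform is written by hand, so its kernel and bias are registered on the custom layer and counted as trainable in model.summary(). The layer name DenseLikeLayer is hypothetical, and the sketch follows the same pattern as the gist above.

class DenseLikeLayer extends tf.layers.Layer {
  constructor(config) {
    super(config);
    this.units = config.units;
  }
  build(inputShape) {
    const inputDim = inputShape[inputShape.length - 1];
    // Registering the weights with addWeight() is what makes them show up
    // as trainable parameters in model.summary().
    this.kernel = this.addWeight(
        'kernel', [inputDim, this.units], 'float32',
        tf.initializers.glorotUniform({}));
    this.bias = this.addWeight(
        'bias', [this.units], 'float32', tf.initializers.zeros());
  }
  computeOutputShape(inputShape) {
    return [inputShape[0], this.units];
  }
  call(inputs) {
    return tf.tidy(() =>
        tf.add(tf.matMul(inputs[0], this.kernel.read()), this.bias.read()));
  }
  getConfig() {
    const config = super.getConfig();
    Object.assign(config, {units: this.units});
    return config;
  }
  static get className() {
    return 'DenseLikeLayer';
  }
}
tf.serialization.registerClass(DenseLikeLayer);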

@BenjaminWegener

Should be possible, but I haven't tried so far. But why don't you just use a model inside a model? I used a model as an attention layer in a transformer-like architecture. It also reduces memory load if the block is used more than once. A sketch follows below.
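A minimal sketch of the model-inside-a-model approach (the shapes and unit counts are arbitrary): a tf.LayersModel is itself a layer, so it can be applied inside another model and its weights remain trainable.

// The inner model plays the role of a reusable sub-block (e.g. an attention block).
const inner = tf.sequential({
  layers: [tf.layers.dense({units: 8, activation: 'relu', inputShape: [16]})]
});

// Use the inner model like a layer inside an outer functional model.
const input = tf.input({shape: [16]});
const hidden = inner.apply(input);
const output = tf.layers.dense({units: 1}).apply(hidden);
const outer = tf.model({inputs: input, outputs: output});

outer.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
outer.summary();  // the inner model's parameters are listed as trainable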

@ierezell

Hi @BenjaminWegener,

Because all the models I found aren't trainable, so I thought I'd create my own.
I was only able to load BERT models as a tf.GraphModel, which is not trainable (inference only). If you have found a way to do it, please help me.

I just need to finetune a BERT (or any transformer-like attention model) and then put a classifier head on top.

However, being able to use tf.layers inside a custom layer would be a nice feature to have.

Thanks in advance for any help.

Have a great day
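As a point of reference, a GraphModel can still serve as a frozen feature extractor even though it cannot be finetuned: only a small layers-model head is trained on its outputs, and the encoder itself is never updated. This is a minimal sketch only; the model URL, the 768-dimensional embedding shape, and the tokenizedInputs/labels tensors are all hypothetical placeholders.

(async () => {
  // Load the converted encoder as an inference-only GraphModel (URL is a placeholder).
  const encoder = await tf.loadGraphModel('https://example.com/bert/model.json');

  // A small trainable classifier head on top of the frozen embeddings.
  const head = tf.sequential({
    layers: [tf.layers.dense({units: 2, activation: 'softmax', inputShape: [768]})]
  });
  head.compile({loss: 'categoricalCrossentropy', optimizer: 'adam'});

  // Pre-compute embeddings with the frozen encoder, then train only the head.
  // (tokenizedInputs and labels are assumed to exist; the encoder weights never change.)
  const embeddings = tf.tidy(() => encoder.predict(tokenizedInputs));
  await head.fit(embeddings, labels, {epochs: 3});
})();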

@BenjaminWegener

BenjaminWegener commented Apr 14, 2021

I have no running example at hand; I am trying to build GPT-2 in tfjs. So far I'm stuck with the beam-search/top-k sampling: the results are not so promising (repeating words, etc.),
but I can give you this, hope it helps.
https://gist.github.com/BenjaminWegener/311292080a71becbe5a8c0cc7657657d

@ierezell

I've opened a PR to tfjs for an AttentionLayer, in case it's helpful to you.

@BenjaminWegener

I've opened a PR to tfjs for an AttentionLayer, in case it's helpful to you.

Nice, thank you.
