@aleksas
Created November 29, 2018 19:03
Use the encoder part of the Tensor2Tensor Transformer
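from tensor2tensor.models import transformer
import tensorflow as tf

# Base Transformer hyperparameters (hidden_size=512, 6 layers, 8 attention heads).
hparams = transformer.transformer_base()
# Encoder-only Transformer model.
encoder = transformer.TransformerEncoder(hparams, mode=tf.estimator.ModeKeys.TRAIN)

# x = <your inputs, of shape [batch_size, timesteps, 1, hparams.hidden_size]>
# y = encoder({"inputs": x})
# To run only the encoder stack, call its body (named model_fn_body in older releases):
# y = encoder.body({"inputs": x, "target_space_id": <scalar int32 tensor>})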
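Getting the input shape right is the main pitfall when feeding the encoder. Below is a minimal sketch, assuming your features are a NumPy array of shape [batch_size, timesteps, hidden_size] whose last dimension already equals transformer_base's hidden size of 512. `to_encoder_inputs` is a hypothetical helper for illustration, not part of tensor2tensor; it only inserts the singleton third axis the encoder expects.

```python
import numpy as np

# transformer_base() sets hparams.hidden_size = 512.
TRANSFORMER_BASE_HIDDEN_SIZE = 512


def to_encoder_inputs(features, hidden_size=TRANSFORMER_BASE_HIDDEN_SIZE):
    """Reshape [batch, timesteps, hidden_size] features into the
    [batch, timesteps, 1, hidden_size] layout the encoder expects.

    Hypothetical helper for illustration; not part of tensor2tensor.
    """
    features = np.asarray(features, dtype=np.float32)
    if features.ndim != 3:
        raise ValueError(f"expected a 3-D array, got shape {features.shape}")
    if features.shape[-1] != hidden_size:
        # Assumption: the last dimension must already match hidden_size,
        # so project or embed your raw features before calling this.
        raise ValueError(
            f"last dimension is {features.shape[-1]}, expected {hidden_size}"
        )
    # Insert the singleton axis at position 2.
    return features[:, :, np.newaxis, :]
```

For example, `to_encoder_inputs(np.random.rand(8, 20, 512))` returns an array of shape `(8, 20, 1, 512)`, which you could then wrap in a tensor and pass as `"inputs"`.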