Skip to content

Instantly share code, notes, and snippets.

@prrao87
Created January 12, 2019 23:40
Show Gist options
Save prrao87/f9cd8bc55d462eab8be4a7e68c68d5c7 to your computer and use it in GitHub Desktop.
Input transform for the classification task head of a transformer.
def transform_stance(X1):
    """Build token/position and mask arrays for the classification task head.

    Each token sequence in ``X1`` is truncated to ``max_len`` tokens and
    wrapped as ``[start] + tokens + [clf_token]``, where ``start`` is
    ``encoder['_start_']``. The result is written into a zero-padded array
    of length ``n_ctx``.

    This function reads several module-level names defined elsewhere:
    ``np``, ``encoder``, ``n_ctx``, ``max_len``, ``clf_token``, ``n_vocab``
    and ``n_special``.

    Args:
        X1: Sequence of token-id sequences, one per example. Each element
            may be a list, tuple or 1-D array of ints.

    Returns:
        A ``(xmb, mmb)`` tuple:

        - ``xmb`` is an int32 array of shape ``(len(X1), 1, n_ctx, 2)``.
          Channel 0 holds the wrapped token ids (zero-padded). Channel 1
          holds position ids ``n_vocab + n_special + [0, n_ctx)``.
        - ``mmb`` is a float32 array of shape ``(len(X1), 1, n_ctx)``,
          1.0 where a wrapped token is present and 0.0 at padding.

    Raises:
        ValueError: If a wrapped sequence (``min(len(x), max_len) + 2``
            tokens) is longer than ``n_ctx``.
    """
    n_batch = len(X1)
    xmb = np.zeros((n_batch, 1, n_ctx, 2), dtype=np.int32)
    mmb = np.zeros((n_batch, 1, n_ctx), dtype=np.float32)
    start = encoder['_start_']
    for i, x1 in enumerate(X1):
        # list() makes this a concatenation for any sequence type. With a
        # bare ndarray, `[start] + arr` would broadcast-ADD start to every
        # token id instead of prepending it. With a tuple, it raises.
        x12 = [start] + list(x1[:max_len]) + [clf_token]
        l12 = len(x12)
        if l12 > n_ctx:
            raise ValueError(
                f"example {i} has {l12} tokens after adding start/clf "
                f"tokens, which exceeds n_ctx={n_ctx}"
            )
        xmb[i, 0, :l12, 0] = x12
        mmb[i, 0, :l12] = 1
    # Position information that is added to the input embeddings in the
    # TransformerModel. The ids are offset by n_vocab + n_special. Presumably
    # this lets them index rows after the token/special-token rows of a
    # shared embedding table (confirm against TransformerModel).
    xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)
    return xmb, mmb
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment