Keras BiLSTM Attention
@idleuncle · Created October 21, 2019 12:54
# coding=utf-8
from keras import Input, Model
from keras.layers import Embedding, Dense, Bidirectional, CuDNNLSTM

from attention import Attention  # custom attention layer; a sketch is given below


class TextAttBiRNN(object):
    def __init__(self, maxlen, max_features, embedding_dims,
                 class_num=1,
                 last_activation='sigmoid'):
        self.maxlen = maxlen
        self.max_features = max_features
        self.embedding_dims = embedding_dims
        self.class_num = class_num
        self.last_activation = last_activation

    def get_model(self):
        input = Input((self.maxlen,))
        # Map token ids to dense vectors: (batch, maxlen) -> (batch, maxlen, embedding_dims)
        embedding = Embedding(self.max_features, self.embedding_dims, input_length=self.maxlen)(input)
        # CuDNNLSTM requires a GPU; swap in LSTM or GRU for a CPU-only setup
        x = Bidirectional(CuDNNLSTM(128, return_sequences=True))(embedding)
        # Collapse the sequence of hidden states into one attention-weighted vector
        x = Attention(self.maxlen)(x)
        output = Dense(self.class_num, activation=self.last_activation)(x)
        model = Model(inputs=input, outputs=output)
        return model
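
The attention module imported above is not included in the gist. Below is a minimal sketch of one compatible implementation, assuming simple feed-forward (additive) attention pooling over the time axis; the internals here are an assumption, not the gist author's code:

# Hedged sketch of a custom Attention layer (Keras 2.x), not the original module
from keras import backend as K
from keras.layers import Layer


class Attention(Layer):
    """Feed-forward attention pooling: scores each timestep, then takes a
    softmax-weighted sum of the hidden states along the time axis."""

    def __init__(self, step_dim, **kwargs):
        self.step_dim = step_dim
        self.features_dim = 0
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        self.features_dim = input_shape[-1]
        # One scoring weight per feature, plus a per-timestep bias
        self.W = self.add_weight(shape=(self.features_dim,),
                                 initializer='glorot_uniform',
                                 name='{}_W'.format(self.name))
        self.b = self.add_weight(shape=(self.step_dim,),
                                 initializer='zeros',
                                 name='{}_b'.format(self.name))
        super(Attention, self).build(input_shape)

    def call(self, x):
        # Score each timestep: e = tanh(x . W + b), shape (batch, steps)
        e = K.reshape(K.dot(K.reshape(x, (-1, self.features_dim)),
                            K.expand_dims(self.W)),
                      (-1, self.step_dim))
        e = K.tanh(e + self.b)
        # Normalise the scores into attention weights via softmax over timesteps
        a = K.exp(e)
        a = a / K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())
        # Weighted sum of hidden states -> (batch, features)
        return K.sum(x * K.expand_dims(a), axis=1)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.features_dim

Hypothetical usage, with illustrative hyperparameter values (vocabulary size, sequence length, and embedding size are placeholders):

model = TextAttBiRNN(maxlen=400, max_features=5000, embedding_dims=100).get_model()
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# model.fit(x_train, y_train, batch_size=32, epochs=10, validation_data=(x_test, y_test))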