Skip to content

Instantly share code, notes, and snippets.

@idleuncle
Created October 21, 2019 12:59
Show Gist options
  • Save idleuncle/7ecd41511e3761ad2e96815e01cc0ddf to your computer and use it in GitHub Desktop.
[Keras FastText]
# coding=utf-8
from keras import Input, Model
from keras.layers import Embedding, GlobalAveragePooling1D, Dense
class FastText(object):
    """Factory for a fastText-style text classifier built with Keras.

    Architecture, as assembled by ``get_model``::

        token ids -> Embedding -> GlobalAveragePooling1D -> Dense

    The sentence representation is the average of its token embeddings. That
    average is fed to a single dense output layer.

    Args:
        maxlen: Number of token ids per input sequence. It is used as the
            ``Input`` shape and as the Embedding ``input_length``. ``None`` is
            accepted and passed through to Keras unchanged.
        max_features: Vocabulary size, i.e. the Embedding ``input_dim``.
            Token ids are expected to lie in ``[0, max_features)``.
        embedding_dims: Dimensionality of each token embedding.
        class_num: Number of units in the final Dense layer. Defaults to 1.
        last_activation: Activation of the final Dense layer. Defaults to
            ``'sigmoid'``. A ``'softmax'`` is the usual choice for
            single-label multi-class output.

    Raises:
        ValueError: If ``max_features``, ``embedding_dims`` or ``class_num``
            is below 1, or ``maxlen`` is neither ``None`` nor at least 1.
    """

    def __init__(self, maxlen, max_features, embedding_dims,
                 class_num=1,
                 last_activation='sigmoid'):
        # Fail fast with a clear message. Otherwise a bad size would only
        # surface later as an opaque error inside Keras layer construction.
        if maxlen is not None and maxlen < 1:
            raise ValueError(
                'maxlen must be a positive integer or None, got %r' % (maxlen,))
        for name, value in (('max_features', max_features),
                            ('embedding_dims', embedding_dims),
                            ('class_num', class_num)):
            if value < 1:
                raise ValueError(
                    '%s must be a positive integer, got %r' % (name, value))

        self.maxlen = maxlen
        self.max_features = max_features
        self.embedding_dims = embedding_dims
        self.class_num = class_num
        self.last_activation = last_activation

    def get_model(self):
        """Build and return a new, uncompiled Keras ``Model``.

        The layers are created inside this call. Each call therefore returns
        an independent model that shares no weights with models from earlier
        calls. The caller is responsible for ``compile`` and ``fit``.
        """
        # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
        inputs = Input((self.maxlen,))
        # NOTE(review): ``input_length`` is deprecated in Keras 3 (it is ignored
        # there, with a warning). It is kept for Keras 2 compatibility.
        embedded = Embedding(self.max_features, self.embedding_dims,
                             input_length=self.maxlen)(inputs)
        # Average over the sequence axis to get one fixed-size vector per
        # example, regardless of sequence length.
        pooled = GlobalAveragePooling1D()(embedded)
        outputs = Dense(self.class_num, activation=self.last_activation)(pooled)
        return Model(inputs=inputs, outputs=outputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.