-
-
Save funktor/3a65905ff421f487dccd6385c3af6b4c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.models import Model, Input | |
from keras.layers import LSTM, Embedding, Dense, TimeDistributed, Dropout, Bidirectional | |
from keras_contrib.layers import CRF | |
from keras.models import load_model | |
import utils | |
import numpy as np | |
def pred2label(pred, tag_inverse_transformer): | |
out = [] | |
for pred_i in pred: | |
out_i = [] | |
for p in pred_i: | |
p_i = np.argmax(p) | |
out_i.append(tag_inverse_transformer[p_i].replace("PAD", "O")) | |
out.append(out_i) | |
return out | |
class BiLSTM(object): | |
def __init__(self, vocab_size, max_words, num_tags, embedding_size, model_file_path): | |
self.vocab_size = vocab_size | |
self.max_words = max_words | |
self.num_tags = num_tags | |
self.embedding_size = embedding_size | |
self.model = None | |
self.model_file_path = model_file_path | |
def build_model(self): | |
print "Building model..." | |
input = Input(shape=(self.max_words,)) | |
embed = Embedding(input_dim=self.vocab_size + 1, output_dim=self.embedding_size, input_length=self.max_words, mask_zero=True)(input) | |
bilstm = Bidirectional(LSTM(units=50, return_sequences=True, recurrent_dropout=0.1))(embed) | |
dense = TimeDistributed(Dense(50, activation="relu"))(bilstm) | |
out = TimeDistributed(Dense(self.num_tags, activation="softmax"))(dense) | |
self.model = Model(input, out) | |
self.model.compile(optimizer="adam", loss='categorical_crossentropy', metrics=['accuracy']) | |
def fit(self, X, y): | |
print "Fitting model..." | |
self.model.fit(X, np.array(y), batch_size=32, epochs=15, validation_split=0.1, verbose=1) | |
def save(self): | |
self.model.save(self.model_file_path) | |
def load(self): | |
self.model.load_weights(self.model_file_path) | |
def predict(self, X): | |
print "Predicting..." | |
return self.model.predict(X, verbose=1) | |
def score(self, X, y, tag_inverse_transformer): | |
test_pred = self.predict(X) | |
pred_labels, test_labels = pred2label(test_pred, tag_inverse_transformer), pred2label(y, tag_inverse_transformer) | |
print utils.get_accuracy(test_labels, pred_labels) | |
return utils.get_classification_score(test_labels, pred_labels) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment