import tensorflow as tf
import numpy as np


class LSTMClassifier(tf.keras.Model):
    def __init__(self, feature_columns, units=32, hidden_layer=32, n_classes=2):
        """LSTMClassifier

        :param feature_columns: feature columns describing the sequence input.
        :type feature_columns: list[tf.feature_column].
        :param units: number of LSTM units.
        :type units: int.
        :param hidden_layer: width of the (currently unused) dense layer.
        :type hidden_layer: int.
        :param n_classes: number of output classes.
        :type n_classes: int.
        """
        super(LSTMClassifier, self).__init__()
        # DenseFeatures would flatten away the time dimension; SequenceFeatures
        # keeps it and also returns the per-example sequence lengths.
        # self.feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
        self.feature_layer = tf.keras.experimental.SequenceFeatures(feature_columns)
        # Note: this dense layer is defined but never used in call().
        self.fc = tf.keras.layers.Dense(hidden_layer * 4)
        self.lstm = tf.keras.layers.LSTM(units, return_sequences=False)
        self.pred = tf.keras.layers.Dense(n_classes, activation='softmax')
    def call(self, inputs):
        seq_in, seq_len = self.feature_layer(inputs)
        # Mask padded timesteps so the LSTM only sees real tokens.
        seq_mask = tf.sequence_mask(seq_len)
        x = self.lstm(seq_in, mask=seq_mask)
        out = self.pred(x)
        return out
    def default_optimizer(self):
        """Default optimizer name. Used in model.compile."""
        return 'adam'

    def default_loss(self):
        """Default loss function. Used in model.compile."""
        # Labels are integer class ids and the output is an n_classes-way
        # softmax, so sparse categorical cross-entropy is the matching loss
        # (binary_crossentropy would expect a single sigmoid output).
        return 'sparse_categorical_crossentropy'

    def default_training_epochs(self):
        """Default training epochs. Used in model.fit."""
        return 5

    def prepare_prediction_column(self, prediction):
        """Return the class label of highest probability."""
        return prediction.argmax(axis=-1)

def train_input_fn(features, labels, batch_size=32):
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)
    return dataset


def eval_input_fn(features, labels, batch_size=32):
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    dataset = dataset.batch(batch_size)
    return dataset
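
# Each element these input functions yield is a (features_dict, labels) pair,
# which is the structure tf.keras.Model.fit/evaluate expect. A quick
# sanity-check sketch (not in the original gist; assumes TF 2.x eager mode):
#
#   ds = eval_input_fn({"c1": np.zeros((4, 8), dtype=np.int64)}, [0, 1, 0, 1])
#   for f, l in ds.take(1):
#       print(f["c1"].shape, l.shape)  # -> (4, 8) (4,)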

def run():
    # 100 "sentences" of 8 integer token ids each (values 0..799); the first
    # half is labeled 0 and the second half 1.
    features = {"c1": np.array([int(x) for x in range(800)]).reshape(100, 8)}
    label = [0 for _ in range(50)] + [1 for _ in range(50)]
    # Token ids range over [0, 800), so the identity column needs 800 buckets.
    # A plain categorical_column_with_identity would not work here, because
    # SequenceFeatures requires sequence feature columns.
    fea = tf.feature_column.sequence_categorical_column_with_identity(
        key="c1",
        num_buckets=800)
    emb = tf.feature_column.embedding_column(fea, dimension=32)
    feature_columns = [emb]

    model = LSTMClassifier(feature_columns=feature_columns)
    model.compile(optimizer=model.default_optimizer(),
                  loss=model.default_loss(),
                  metrics=["accuracy"])
    model.fit(train_input_fn(features, label),
              epochs=model.default_training_epochs(),
              steps_per_epoch=100)
    loss, acc = model.evaluate(eval_input_fn(features, label))
    print("eval loss: %f, accuracy: %f" % (loss, acc))
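    # A usage sketch, not in the original gist: feed the evaluation dataset
    # through predict() and use prepare_prediction_column (defined above but
    # otherwise unused) to map softmax probabilities to class labels.
    preds = model.predict(eval_input_fn(features, label))
    print("predicted classes:", model.prepare_prediction_column(preds))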

if __name__ == '__main__':
    run()