

@netsatsawat
Created August 17, 2020 13:53
Snippet of RNN - GRU using Tensorflow
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Requires installation first: !pip install -q git+https://github.com/tensorflow/docs
import tensorflow_docs as tfdocs
import tensorflow_docs.plots
import tensorflow_docs.modeling  # provides tfdocs.modeling.EpochDots
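# Four stacked GRU layers (32 units each) with dropout, ending in a single
# regression output. n_steps (the input window length) must be defined beforehand.
# Every GRU except the last returns the full sequence, so the next GRU receives
# 3-D input of shape (batch, n_steps, 32); the last one returns only its final state.
gru_model = keras.Sequential([
    layers.GRU(32, return_sequences=True, input_shape=(n_steps, 1), activation='tanh'),
    layers.GRU(32, return_sequences=True, activation='tanh'),
    layers.Dropout(0.2),
    layers.GRU(32, return_sequences=True, activation='tanh'),
    layers.GRU(32, return_sequences=False, activation='tanh'),
    layers.Dropout(0.2),
    layers.Dense(1)
])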
# Mean squared error loss; MAE and MSE are tracked as metrics.
gru_model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.MeanSquaredError(),
                  metrics=['mae', 'mse'])
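# Stop training once val_loss has not improved for 50 consecutive epochs.
# Without restore_best_weights=True, the model keeps the weights from the last epoch run.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                  patience=50,
                                                  mode='min')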
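# X_train (shape (samples, n_steps, 1)) and y_train (shape (samples,)) are assumed
# to be prepared beforehand. validation_split=0.2 holds out the last 20% of the
# samples for validation; with verbose=0, EpochDots prints compact progress dots.
gru_hist = gru_model.fit(
    X_train, y_train, epochs=500, validation_split=0.2,
    batch_size=32, verbose=0,
    callbacks=[tfdocs.modeling.EpochDots(), early_stopping]
)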
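The snippet uses n_steps, X_train and y_train without defining them. A minimal sketch of one way to build them, assuming a univariate series and a one-step-ahead forecasting target: each sample is a window of n_steps consecutive values and its label is the value that follows. The helper make_windows and the noisy sine data are hypothetical illustrations, not the author's data pipeline.

```python
import numpy as np

def make_windows(series, n_steps):
    """Hypothetical helper: split a 1-D series into (X, y) for one-step-ahead forecasting.

    X has shape (num_windows, n_steps, 1), matching the GRU input_shape (n_steps, 1);
    y has shape (num_windows,) and holds the value right after each window.
    """
    series = np.asarray(series, dtype=np.float32)
    num_windows = len(series) - n_steps
    if num_windows <= 0:
        raise ValueError("series must be longer than n_steps")
    # Window i covers series[i : i + n_steps]; its target is series[i + n_steps].
    X = np.stack([series[i:i + n_steps] for i in range(num_windows)])
    y = series[n_steps:]
    # Add a trailing feature axis so each time step has exactly one feature.
    return X[..., np.newaxis], y

# Hypothetical toy series: a noisy sine wave (not the author's data).
rng = np.random.default_rng(0)
t = np.arange(1000, dtype=np.float32)
series = np.sin(0.05 * t) + 0.1 * rng.standard_normal(t.shape)

n_steps = 50
X_train, y_train = make_windows(series, n_steps)
print(X_train.shape, y_train.shape)  # (950, 50, 1) (950,)
```

In practice the series is usually scaled (for example, standardized) before windowing, since tanh-activated GRUs train more reliably on inputs of order one.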
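To show what each layers.GRU(32, ...) computes and why return_sequences=True is needed for stacking, here is a NumPy sketch of a GRU forward pass. It assumes the classic formulation where the reset gate is applied before the recurrent matrix multiply (Keras GRU with reset_after=False); the Keras default, reset_after=True, applies it after that multiply, so exact values differ. The functions gru_step and gru_layer, the gate ordering, and the random weights are illustrative assumptions, not Keras internals.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, W, U, b):
    """One GRU time step (reset gate applied before the recurrent matmul).

    W: (input_dim, 3*units), U: (units, 3*units), b: (3*units,).
    Assumed gate order in the stacked weights: update z, reset r, candidate.
    """
    Wz, Wr, Wh = np.split(W, 3, axis=1)
    Uz, Ur, Uh = np.split(U, 3, axis=1)
    bz, br, bh = np.split(b, 3)
    z = sigmoid(x_t @ Wz + h_prev @ Uz + bz)                  # update gate
    r = sigmoid(x_t @ Wr + h_prev @ Ur + br)                  # reset gate
    h_cand = np.tanh(x_t @ Wh + (r * h_prev) @ Uh + bh)       # candidate state
    return z * h_prev + (1.0 - z) * h_cand                    # blend old and new

def gru_layer(x, W, U, b, return_sequences=False):
    """Run gru_step over x of shape (batch, n_steps, input_dim)."""
    batch, n_steps, _ = x.shape
    h = np.zeros((batch, U.shape[0]))
    outputs = []
    for t in range(n_steps):
        h = gru_step(x[:, t, :], h, W, U, b)
        outputs.append(h)
    # return_sequences=True keeps every step (3-D), which a stacked GRU needs;
    # otherwise only the final hidden state (2-D) is returned, as fed to Dense(1).
    return np.stack(outputs, axis=1) if return_sequences else h

rng = np.random.default_rng(0)
units, n_steps = 32, 10
x = rng.standard_normal((4, n_steps, 1))
W = rng.standard_normal((1, 3 * units)) * 0.1
U = rng.standard_normal((units, 3 * units)) * 0.1
b = np.zeros(3 * units)

seq = gru_layer(x, W, U, b, return_sequences=True)
last = gru_layer(x, W, U, b, return_sequences=False)
print(seq.shape, last.shape)  # (4, 10, 32) (4, 32)
```

The final state returned with return_sequences=False equals the last time step of the full sequence, which is why only the last GRU in the stack drops the time axis before the Dense(1) regression head.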