Skip to content

Instantly share code, notes, and snippets.

@Mistobaan
Last active May 5, 2020 04:26
Show Gist options
  • Save Mistobaan/4047238780358ab36322b351640bb0ec to your computer and use it in GitHub Desktop.
Gists for tf.keras
# Standard tf.keras training callbacks: LR decay on plateau, early stopping,
# and best-weights checkpointing — all keyed to validation loss.
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# Halve the learning rate whenever val_loss stalls for 2 epochs, down to a 1e-6 floor.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
    verbose=1,
    mode='auto',
    min_lr=1e-6)

# Abort training after 10 epochs with no val_loss improvement.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    verbose=1,
    mode='auto')

# Persist only the weights (not the full model), and only when val_loss improves.
model_checkpoint = ModelCheckpoint(
    filepath='weights.h5',
    monitor='val_loss',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
    mode='auto')
# from https://github.com/tensorflow/tensorboard/issues/2471#issuecomment-580423961
# Some initial code which is the same for all the variants
import os
import numpy as np
import tensorflow as tf
from tensorboard.plugins import projector
def register_embedding(embedding_tensor_name, meta_data_fname, log_dir):
    """Write a TensorBoard projector config into ``log_dir``.

    The config associates the checkpointed tensor named
    ``embedding_tensor_name`` with the label file ``meta_data_fname`` so the
    projector plugin can render the embedding with per-point labels.
    """
    projector_config = projector.ProjectorConfig()
    entry = projector_config.embeddings.add()
    entry.tensor_name = embedding_tensor_name
    entry.metadata_path = meta_data_fname
    projector.visualize_embeddings(log_dir, projector_config)
def get_random_data(shape=(100, 100)):
    """Return a random (features, labels) pair.

    Features are uniform floats in [0, 1) with the given ``shape``; labels are
    binary ints (0 or 1), one per row (length ``shape[0]``).
    """
    n_rows = shape[0]
    features = np.random.rand(*shape)
    labels = np.random.randint(low=0, high=2, size=n_rows)
    return features, labels
def save_labels_tsv(labels, filepath, log_dir):
    """Write ``labels`` one per line to ``log_dir``/``filepath``.

    Produces the metadata file the TensorBoard projector reads; each label is
    stringified via str.format, so any printable type is accepted.
    """
    target = os.path.join(log_dir, filepath)
    with open(target, 'w') as out:
        out.writelines('{}\n'.format(label) for label in labels)
# --- Demo driver: checkpoint random embeddings + labels for TensorBoard ---
LOG_DIR = 'tmp'                       # Tensorboard log dir
META_DATA_FNAME = 'meta.tsv'          # Labels will be stored here
EMBEDDINGS_TENSOR_NAME = 'embeddings'
EMBEDDINGS_FPATH = os.path.join(LOG_DIR, '{}.ckpt'.format(EMBEDDINGS_TENSOR_NAME))
STEP = 0

x, y = get_random_data((100, 100))
register_embedding(EMBEDDINGS_TENSOR_NAME, META_DATA_FNAME, LOG_DIR)
save_labels_tsv(y, META_DATA_FNAME, LOG_DIR)

# Size of files created on disk: 80.5kB
tensor_embeddings = tf.Variable(x, name=EMBEDDINGS_TENSOR_NAME)
# v1 Saver requires a list or dict of variables; in TF2 eager mode it accepts
# sess=None and falls back to object-based saving.
saver = tf.compat.v1.train.Saver([tensor_embeddings])
saver.save(sess=None, global_step=STEP, save_path=EMBEDDINGS_FPATH)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment