Skip to content

Instantly share code, notes, and snippets.

@JustASquid
Created October 30, 2022 21:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JustASquid/aece279ac63cf399744cc0f3f28a21e1 to your computer and use it in GitHub Desktop.
Save JustASquid/aece279ac63cf399744cc0f3f28a21e1 to your computer and use it in GitHub Desktop.
import random
import tensorflow as tf
import numpy as np
# Number of candidate options per example; also the number of logits per slot.
NUM_OPTIONS: int = 10
# Feature width of each option; the model Input has shape (NUM_OPTIONS, OPTIONS_DEPTH).
OPTIONS_DEPTH: int = 1
# Number of selection slots; the model emits one row of NUM_OPTIONS logits per slot.
NUM_SLOTS: int = 1
class Selector(tf.keras.layers.Layer):
    """Select among ``options`` using per-slot logits.

    A temperature-scaled softmax over the last axis of ``selections_logits``
    gives selection weights, which are matrix-multiplied with ``options``.
    With ``hard=True`` the forward pass uses a one-hot argmax selection while
    gradients flow through the softmax (straight-through estimator).

    Args:
        temperature: Softmax temperature applied to the logits. Must be > 0.
        hard: If True, use the straight-through hard selection described above.
        **kwargs: Standard ``tf.keras.layers.Layer`` arguments (e.g. ``name``).

    Raises:
        ValueError: If ``temperature`` is not positive.
    """

    def __init__(self, temperature=0.6, hard=False, **kwargs):
        super().__init__(**kwargs)
        # A zero or negative temperature would divide by zero or flip the
        # ordering of the logits, so the softmax would no longer agree with
        # the argmax used for the hard selection.
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature!r}")
        self._temperature = temperature
        self._hard = hard

    def call(self, options, selections_logits):
        """Return ``selections @ options``.

        Args:
            options: Candidate values. Its second-to-last axis must match the
                last axis of ``selections_logits`` (matmul requirement). In
                ``get_model`` it is (batch, NUM_OPTIONS, OPTIONS_DEPTH).
            selections_logits: Unnormalised scores over the options, one row
                per slot. In ``get_model`` it is (batch, NUM_SLOTS, NUM_OPTIONS).

        Returns:
            ``tf.linalg.matmul(selections, options)``. In ``get_model`` this
            has shape (batch, NUM_SLOTS, OPTIONS_DEPTH).
        """
        softmax = tf.nn.softmax(selections_logits / self._temperature, axis=-1)
        if not self._hard:
            return tf.linalg.matmul(softmax, options)

        # Derive the one-hot depth from the logits themselves rather than a
        # global constant. Prefer the static size so Keras keeps a known shape,
        # and fall back to the dynamic size when it is not known.
        num_options = selections_logits.shape[-1]
        if num_options is None:
            num_options = tf.shape(selections_logits)[-1]
        hardmax = tf.one_hot(
            tf.math.argmax(selections_logits, axis=-1),
            depth=num_options,
            dtype=softmax.dtype,  # match softmax dtype (e.g. under mixed precision)
        )
        # Forward value equals hardmax. stop_gradient hides (hardmax - softmax)
        # from backprop, so the gradient is that of softmax alone.
        selections = tf.stop_gradient(hardmax - softmax) + softmax
        return tf.linalg.matmul(selections, options)

    def get_config(self):
        """Return the layer config, including ``temperature`` and ``hard``."""
        config = super().get_config()
        config.update({"temperature": self._temperature, "hard": self._hard})
        return config
def get_model(*, hidden_units=64, num_hidden_layers=2, temperature=1.0, hard=True):
    """Build a model that selects among its input options, one pick per slot.

    The options are flattened and passed through ``num_hidden_layers`` ReLU
    ``Dense`` layers. A final ``Dense`` layer then emits
    ``NUM_OPTIONS * NUM_SLOTS`` logits, which are reshaped to
    ``(NUM_SLOTS, NUM_OPTIONS)``. ``Selector`` applies those logits to the
    original options, and the selected values are flattened into the model
    output.

    All arguments are keyword-only; the defaults reproduce the original
    fixed architecture.

    Args:
        hidden_units: Width of each hidden ``Dense`` layer.
        num_hidden_layers: Number of hidden ReLU ``Dense`` layers.
        temperature: Softmax temperature passed to ``Selector``.
        hard: Whether ``Selector`` uses straight-through hard selection.

    Returns:
        A ``tf.keras.Model``. Its single input has shape
        (NUM_OPTIONS, OPTIONS_DEPTH) per example. Its output is the flattened
        selection, with NUM_SLOTS * OPTIONS_DEPTH values per example.
    """
    options_input = tf.keras.layers.Input(shape=(NUM_OPTIONS, OPTIONS_DEPTH))
    net = tf.keras.layers.Flatten()(options_input)
    for _ in range(num_hidden_layers):
        net = tf.keras.layers.Dense(hidden_units, activation="relu")(net)
    net = tf.keras.layers.Dense(NUM_OPTIONS * NUM_SLOTS)(net)
    # One row of NUM_OPTIONS logits per slot, as Selector expects.
    logits = tf.keras.layers.Reshape((NUM_SLOTS, NUM_OPTIONS))(net)
    choices = Selector(temperature=temperature, hard=hard)(options_input, logits)
    choices = tf.keras.layers.Flatten()(choices)
    return tf.keras.Model([options_input], choices)
def batch_gen(*, num_options=None, target=5):
    """Yield an endless stream of ``(x, y)`` training examples.

    Despite the name, each yielded pair is a single example; batching is
    done later by the tf.data pipeline.

    ``x`` is a NumPy array holding the integers ``0 .. num_options - 1`` in a
    random order. ``y`` is the list ``[target]``, i.e. the model should learn
    to pick the option whose value is ``target``, wherever it was shuffled to.

    Args:
        num_options: Number of options per example. ``None`` (the default)
            means the module-level ``NUM_OPTIONS``, read when iteration starts.
        target: The option value the model should select.

    Raises:
        ValueError: On the first ``next()``, if ``target`` is not one of the
            generated options (the task would then be unlearnable).
    """
    if num_options is None:
        num_options = NUM_OPTIONS
    if not 0 <= target < num_options:
        raise ValueError(
            f"target {target!r} is not in the option range 0..{num_options - 1}"
        )
    while True:
        # A fresh list each time keeps the shuffle sequence identical to
        # shuffling list(range(num_options)) anew per example.
        options = list(range(num_options))
        random.shuffle(options)
        yield np.array(options), [target]
# Training pipeline: an endless stream of single examples from batch_gen,
# reshaped to the model's input layout and batched.
dataset = tf.data.Dataset.from_generator(
    batch_gen,
    # output_signature (which replaces the deprecated output_types) also
    # declares element shapes, giving the pipeline a static element spec.
    output_signature=(
        tf.TensorSpec(shape=(NUM_OPTIONS,), dtype=tf.float32),
        tf.TensorSpec(shape=(1,), dtype=tf.float32),
    ),
)
# batch_gen yields each example's options as a flat vector. Reshape it to the
# (NUM_OPTIONS, OPTIONS_DEPTH) layout that get_model's Input declares, so the
# data matches the declared input shape exactly.
dataset = dataset.map(
    lambda x, y: (tf.reshape(x, (NUM_OPTIONS, OPTIONS_DEPTH)), y)
)
dataset = dataset.batch(32)
dataset = dataset.repeat()
dataset = dataset.prefetch(1)

model = get_model()
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.MeanSquaredError(),
    # Eager execution is slower but easier to step through while debugging.
    run_eagerly=True,
)

# Stop after 15 epochs without any improvement in training loss and roll
# back to the best weights seen.
stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor="loss",
    min_delta=0,
    patience=15,
    verbose=1,
    baseline=None,
    restore_best_weights=True,
)
# Multiply the learning rate by 0.75 after 5 epochs whose loss improves by
# less than 0.001.
reduce_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="loss",
    factor=0.75,
    patience=5,
    verbose=1,
    min_delta=0.001,
)

# The generator never ends, so steps_per_epoch defines what an "epoch" is.
model.fit(
    dataset,
    steps_per_epoch=250,
    epochs=100,
    callbacks=[
        reduce_lr_callback,
        stopping_callback,
    ],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment