@mpekalski
Last active August 25, 2018 22:00
with tf.variable_scope("dataset"):
train_values = tf.constant([[1,2,3],[2,3,4],
[3,4,2],[2,1,5],
[1,7,3],[2,2,7],
[0,1,0],[0,1,0],
[0,1,0],[0,1,0]])
train_labels = tf.constant([1, 0, 0, 1, 1, 1,0,1,0,0])
# 1) no shuffling, so the outcome would be deterministic
# 2) make one dataset object from train_values and train_labels
ds_t = (Dataset.from_tensor_slices((train_values, tshould gist be a function or whole filerain_labels))
.batch(batch_size)
.repeat(num_epochs)
)
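# A quick way to sanity-check what this pipeline yields (a sketch; assumes
# batch_size=2 and num_epochs=3 as set further down in this gist):
#   check_it = ds_t.make_one_shot_iterator()
#   first_batch = check_it.get_next()
#   with tf.Session() as s:
#       print(s.run(first_batch))
#   # -> (array([[1, 2, 3], [2, 3, 4]], dtype=int32), array([1, 0], dtype=int32))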
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.data import Dataset
# Reset the graph so that TF does not complain about variables
# already existing in the graph when you rerun the notebook.
tf.reset_default_graph()
g = tf.Graph()
with tf.variable_scope("iterator"):
# Define iterator from_string_handle. In general it is useful to have
# this kind of iterator if one wants to switch between train and validation
# within the training loop.
iterator_handle = tf.placeholder(tf.string, shape=[], name="iterator_handle")
iterator_t = ds_t.make_initializable_iterator()
iterator = tf.data.Iterator.from_string_handle(iterator_handle,
iterator_t.output_types,
iterator_t.output_shapes)
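# The payoff of from_string_handle (illustrative sketch; val_values,
# val_labels and ds_v are assumptions, not part of this gist): the same
# graph can be pointed at a validation split just by feeding a different
# handle, with no rebuilding.
#   ds_v = Dataset.from_tensor_slices((val_values, val_labels)).batch(batch_size)
#   iterator_v = ds_v.make_initializable_iterator()
#   # inside the session:
#   handle_v = sess.run(iterator_v.string_handle())
#   sess.run(iterator_v.initializer)
#   sess.run(loss_op, feed_dict={iterator_handle: handle_v, is_training: False})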
def get_next_item():
    next_elem = iterator.get_next(name="next_element")
    x, y = tf.cast(next_elem[0], tf.float32), tf.cast(next_elem[1], tf.int32)
    return x, y
def create_model_no_resampling(inputs, batch_size, is_training):
    # Make the model deterministic by initializing the weights in a
    # deterministic manner: a kernel of zeros with a single 1 at the centre.
    def init_f(shape, dtype=None, partition_info=None):
        ker = np.zeros(shape, dtype=np.float)
        ker[tuple(map(lambda x: int(np.floor(x / 2)), ker.shape))] = 1
        return ker
    with tf.variable_scope("model", reuse=True):
        x = tf.keras.layers.Dense(2, activation=None, input_shape=[1],
                                  kernel_initializer=init_f, name="layer1")(inputs)
        x = tf.keras.layers.Dense(2, activation=None,
                                  kernel_initializer=init_f, name="layer2")(x)
    return x
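# For reference, init_f produces a kernel of zeros with a single 1 at the
# (floor-divided) centre, e.g. for a (3, 2) kernel:
#   init_f((3, 2)) -> [[0., 0.],
#                      [0., 1.],
#                      [0., 0.]]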
==========================================
1 | inputs: [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]] | target: [1, 0] | is_resample: False | is_keep_previous: False | is_optimize: False | avg_loss: 0.6931471824645996 | loss: [0.6931471824645996, 0.6931471824645996]
2 | inputs: [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]] | target: [1, 0] | is_resample: True | is_keep_previous: False | is_optimize: False | avg_loss: 1.5877577066421509 | loss: [0.12692804634571075, 3.0485873222351074]
3 | inputs: [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]] | target: [1, 0] | is_resample: True | is_keep_previous: False | is_optimize: False | avg_loss: 1.5877577066421509 | loss: [0.12692804634571075, 3.0485873222351074]
4 | inputs: [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]] | target: [1, 0] | is_resample: True | is_keep_previous: True | is_optimize: True | avg_loss: 1.5877577066421509 | loss: [0.12692804634571075, 3.0485873222351074]
5 | inputs: [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]] | target: [1, 0] | is_resample: True | is_keep_previous: True | is_optimize: False | avg_loss: 1.57833993434906 | loss: [0.12874676287174225, 3.027933120727539]
6 | inputs: [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]] | target: [1, 0] | is_resample: True | is_keep_previous: True | is_optimize: False | avg_loss: 1.57833993434906 | loss: [0.12874676287174225, 3.027933120727539]
==========================================
1 | inputs: [[3.0, 4.0, 2.0], [2.0, 1.0, 5.0]] | target: [0, 1] | is_resample: False | is_keep_previous: False | is_optimize: False | avg_loss: 1.0816088914871216 | loss: [2.113590717315674, 0.04962695762515068]
2 | inputs: [[3.0, 4.0, 2.0], [3.0, 4.0, 2.0]] | target: [0, 0] | is_resample: True | is_keep_previous: False | is_optimize: False | avg_loss: 2.648705005645752 | loss: [3.9955029487609863, 1.3019068241119385]
3 | inputs: [[3.0, 4.0, 2.0], [3.0, 4.0, 2.0]] | target: [0, 0] | is_resample: True | is_keep_previous: False | is_optimize: False | avg_loss: 3.9955029487609863 | loss: [3.9955029487609863, 3.9955029487609863]
4 | inputs: [[3.0, 4.0, 2.0], [3.0, 4.0, 2.0]] | target: [0, 0] | is_resample: True | is_keep_previous: True | is_optimize: True | avg_loss: 3.9955029487609863 | loss: [3.9955029487609863, 3.9955029487609863]
5 | inputs: [[3.0, 4.0, 2.0], [3.0, 4.0, 2.0]] | target: [0, 0] | is_resample: True | is_keep_previous: True | is_optimize: False | avg_loss: 3.9254682064056396 | loss: [3.9254682064056396, 3.9254682064056396]
6 | inputs: [[3.0, 4.0, 2.0], [3.0, 4.0, 2.0]] | target: [0, 0] | is_resample: True | is_keep_previous: True | is_optimize: False | avg_loss: 3.9254682064056396 | loss: [3.9254682064056396, 3.9254682064056396]
==========================================
1 | inputs: [[1.0, 7.0, 3.0], [2.0, 2.0, 7.0]] | target: [1, 1] | is_resample: False | is_keep_previous: False | is_optimize: False | avg_loss: 0.019930224865674973 | loss: [0.019930224865674973, 0.019930224865674973]
2 | inputs: [[2.0, 2.0, 7.0], [2.0, 2.0, 7.0]] | target: [1, 1] | is_resample: True | is_keep_previous: False | is_optimize: False | avg_loss: 0.06877464056015015 | loss: [0.001050516264513135, 0.13649876415729523]
3 | inputs: [[2.0, 2.0, 7.0], [2.0, 2.0, 7.0]] | target: [1, 1] | is_resample: True | is_keep_previous: False | is_optimize: False | avg_loss: 0.13649876415729523 | loss: [0.13649876415729523, 0.13649876415729523]
4 | inputs: [[2.0, 2.0, 7.0], [2.0, 2.0, 7.0]] | target: [1, 1] | is_resample: True | is_keep_previous: True | is_optimize: True | avg_loss: 0.13649876415729523 | loss: [0.13649876415729523, 0.13649876415729523]
5 | inputs: [[2.0, 2.0, 7.0], [2.0, 2.0, 7.0]] | target: [1, 1] | is_resample: True | is_keep_previous: True | is_optimize: False | avg_loss: 0.13820050656795502 | loss: [0.13820050656795502, 0.13820050656795502]
6 | inputs: [[2.0, 2.0, 7.0], [2.0, 2.0, 7.0]] | target: [1, 1] | is_resample: True | is_keep_previous: True | is_optimize: False | avg_loss: 0.13820050656795502 | loss: [0.13820050656795502, 0.13820050656795502]
==========================================
1 | inputs: [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] | target: [0, 1] | is_resample: False | is_keep_previous: False | is_optimize: False | avg_loss: 1.0927773714065552 | loss: [2.047354221343994, 0.13820050656795502]
2 | inputs: [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] | target: [0, 0] | is_resample: True | is_keep_previous: False | is_optimize: False | avg_loss: 1.2933204174041748 | loss: [1.2933204174041748, 1.2933204174041748]
3 | inputs: [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] | target: [0, 0] | is_resample: True | is_keep_previous: False | is_optimize: False | avg_loss: 1.2933204174041748 | loss: [1.2933204174041748, 1.2933204174041748]
4 | inputs: [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] | target: [0, 0] | is_resample: True | is_keep_previous: True | is_optimize: True | avg_loss: 1.2933204174041748 | loss: [1.2933204174041748, 1.2933204174041748]
5 | inputs: [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] | target: [0, 0] | is_resample: True | is_keep_previous: True | is_optimize: False | avg_loss: 1.288617491722107 | loss: [1.288617491722107, 1.288617491722107]
6 | inputs: [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] | target: [0, 0] | is_resample: True | is_keep_previous: True | is_optimize: False | avg_loss: 1.288617491722107 | loss: [1.288617491722107, 1.288617491722107]
==========================================
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    handle_t = sess.run(iterator_t.string_handle())
    sess.run(iterator_t.initializer)
    print("Trainable Variables:")
    [print(x) for x in tf.trainable_variables()]
    print("==========================================")
    while True:
        try:
            i = 0
            # 1 - fetch a fresh batch; no resampling, no optimization
            inputs_, target_, loss_, is_resample_, is_keep_previous_, is_optimize_ = sess.run(
                [inputs, target, loss_op, is_resample, is_keep_previous, is_optimize],
                feed_dict={iterator_handle: handle_t, is_resample: False,
                           is_training: True, is_keep_previous: False, is_optimize: False})
            print_log(inputs=inputs_, target=target_, is_resample=is_resample_,
                      is_keep_previous=is_keep_previous_, is_optimize=is_optimize_,
                      avg_loss=np.sum(loss_) / 2, loss=loss_)
            # 2 - resample the current batch (is_optimize is False, so
            #     apply_gradients resolves to a no-op)
            inputs_, target_, loss_, is_resample_, is_keep_previous_, is_optimize_, _ = sess.run(
                [inputs, target, loss_op, is_resample, is_keep_previous, is_optimize, apply_gradients],
                feed_dict={iterator_handle: handle_t, is_resample: True,
                           is_training: True, is_keep_previous: False, is_optimize: False})
            print_log(inputs=inputs_, target=target_, is_resample=is_resample_,
                      is_keep_previous=is_keep_previous_, is_optimize=is_optimize_,
                      avg_loss=np.sum(loss_) / 2, loss=loss_)
            # 3 - resample again
            inputs_, target_, loss_, is_resample_, is_keep_previous_, is_optimize_, _ = sess.run(
                [inputs, target, loss_op, is_resample, is_keep_previous, is_optimize, apply_gradients],
                feed_dict={iterator_handle: handle_t, is_resample: True,
                           is_training: True, is_keep_previous: False, is_optimize: False})
            print_log(inputs=inputs_, target=target_, is_resample=is_resample_,
                      is_keep_previous=is_keep_previous_, is_optimize=is_optimize_,
                      avg_loss=np.sum(loss_) / 2, loss=loss_)
            # 4 - keep the previous batch and take an optimization step
            inputs_, target_, is_resample_, loss_, is_keep_previous_, is_optimize_, _ = sess.run(
                [inputs, target, is_resample, loss_op, is_keep_previous, is_optimize, apply_gradients],
                feed_dict={iterator_handle: handle_t, is_resample: True,
                           is_training: True, is_keep_previous: True, is_optimize: True})
            print_log(inputs=inputs_, target=target_, is_resample=is_resample_,
                      is_keep_previous=is_keep_previous_, is_optimize=is_optimize_,
                      avg_loss=np.sum(loss_) / 2, loss=loss_)
            # 5 - keep the previous batch; the loss reflects step 4's update
            inputs_, target_, is_resample_, loss_, is_keep_previous_, is_optimize_ = sess.run(
                [inputs, target, is_resample, loss_op, is_keep_previous, is_optimize],
                feed_dict={iterator_handle: handle_t, is_resample: True,
                           is_training: True, is_keep_previous: True, is_optimize: False})
            print_log(inputs=inputs_, target=target_, is_resample=is_resample_,
                      is_keep_previous=is_keep_previous_, is_optimize=is_optimize_,
                      avg_loss=np.sum(loss_) / 2, loss=loss_)
            # 6 - same as 5; without an optimization step the loss is unchanged
            inputs_, target_, is_resample_, loss_, is_keep_previous_, is_optimize_ = sess.run(
                [inputs, target, is_resample, loss_op, is_keep_previous, is_optimize],
                feed_dict={iterator_handle: handle_t, is_resample: True,
                           is_training: True, is_keep_previous: True, is_optimize: False})
            print_log(inputs=inputs_, target=target_, is_resample=is_resample_,
                      is_keep_previous=is_keep_previous_, is_optimize=is_optimize_,
                      avg_loss=np.sum(loss_) / 2, loss=loss_)
            print("==========================================")
        except tf.errors.OutOfRangeError:
            break
batch_size = 2
num_epochs = 3
n_replace = 1
np.random.seed(123)
step_counter = tf.train.get_or_create_global_step()
# initial value for batch_range
batch_range_init = np.arange(batch_size, dtype=np.int32)
# flags
is_training = tf.placeholder(tf.bool, shape=(), name="training_flag")
is_keep_previous = tf.placeholder_with_default(tf.constant(False), shape=[], name="keep_previous_flag")
is_resample = tf.placeholder_with_default(tf.constant(False), shape=[], name="resample_flag")
is_optimize = tf.placeholder_with_default(tf.constant(False), shape=[], name="optimize_flag")
# variables/placeholders
initial_loss = tf.Variable(initial_value=tf.zeros([batch_size]),
                           dtype=tf.float32, name="initial_loss", trainable=False)
# mind that here we use use_resource=True
inputs = tf.Variable(tf.zeros(shape=[batch_size, 3]),
                     dtype=tf.float32, name="inputs", trainable=False, use_resource=True)
target = tf.Variable(tf.zeros(shape=[batch_size], dtype=tf.int32),
                     dtype=tf.int32, name="target", trainable=False, use_resource=True)
def batch_loss_fn(logits, labels):
    # per-sample cross-entropy loss (shape [batch_size])
    return tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels, name="batch_loss_fn")

def loss_fn(logits, labels):
    # scalar loss, averaged over the batch
    return tf.reduce_mean(batch_loss_fn(logits, labels), name="loss_fn")
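# Numeric sanity check (a sketch, not part of the original gist): with two
# equal logits the softmax is uniform, so the per-sample loss is
# -log(0.5) ~= 0.6931, which is exactly the loss printed in step 1 of the
# first batch in the log above (the inputs variable still holds its zero
# initial value at that point, so both logits come out equal).
#   check = tf.nn.sparse_softmax_cross_entropy_with_logits(
#       logits=tf.constant([[0.0, 0.0]]), labels=tf.constant([1]))
#   with tf.Session() as s:
#       print(s.run(check))  # [0.6931472]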
def new_data():
    # run the data layer to generate a new batch and cache it in the
    # inputs/target variables
    next_inputs, next_target = get_next_item()
    with tf.control_dependencies([tf.assign(inputs, next_inputs),
                                  tf.assign(target, next_target)]):
        return tf.identity(inputs), tf.identity(target)

def old_data():
    # just forward the existing batch
    return inputs, target
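# Note on the tf.identity wrapping in new_data above: control_dependencies
# only attaches to ops *created inside* its block, so returning `inputs`
# directly would not force the assigns to run first. Sketch of the general
# pattern (illustrative, not part of this gist):
#   with tf.control_dependencies([assign_op]):
#       read_bad = some_var              # no new op; dependency not enforced
#       read_ok = tf.identity(some_var)  # new op; runs after assign_op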
def old_resampled():
    # compute the per-sample loss on the current batch, then replace the
    # lowest-loss samples with copies of the highest-loss ones
    initial_loss_batch = batch_loss_fn(logits, target)
    initial_loss_op = tf.assign(initial_loss, initial_loss_batch)
    tf.stop_gradient(initial_loss_batch)
    batch_range = tf.Variable(batch_range_init, dtype=tf.int32, trainable=False, name="batch_range")
    with tf.control_dependencies([initial_loss_op]):
        high_loss = tf.nn.top_k(initial_loss, k=n_replace, sorted=False, name="high_loss")
        low_loss = tf.nn.top_k(tf.multiply(-1.0, initial_loss), k=n_replace, sorted=False, name="bottom_loss")
        batch_range_op = batch_range.assign(batch_range_init)
        with tf.control_dependencies([batch_range_op]):
            new_data_ind = tf.scatter_nd_update(batch_range,
                                                tf.reshape(low_loss.indices, [n_replace, 1]),
                                                tf.reshape(high_loss.indices, [n_replace]))
            index = tf.expand_dims(new_data_ind, 1)
            resampled_inputs = tf.gather_nd(inputs, index, name="resampled_z")
            resampled_target = tf.gather_nd(target, index, name="resampled_target")
            with tf.control_dependencies([tf.assign(inputs, resampled_inputs),
                                          tf.assign(target, resampled_target)]):
                return tf.identity(inputs, name="inputs"), tf.identity(target, name="target")
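# NumPy sketch of the index arithmetic above (illustrative, with
# n_replace=1 and the step-1 losses of the second batch in the log):
#   losses = np.array([2.1136, 0.0496])  # per-sample loss
#   rng = np.arange(2)                   # [0, 1], i.e. batch_range
#   high = losses.argsort()[-1:]         # [0], highest-loss sample
#   low  = losses.argsort()[:1]          # [1], lowest-loss sample
#   rng[low] = high                      # rng becomes [0, 0]
#   # gathering batch[rng] duplicates the hard sample in place of the easy
#   # one, which is why the log shows [[3,4,2],[3,4,2]] from step 2 onwards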
logits = create_model_no_resampling(inputs, batch_size, is_training)
# pick the data source for this step: keep the previous batch, resample the
# previous batch (only after the first step), or fetch a new batch
inputs, target = tf.cond(is_keep_previous,
                         old_data,
                         lambda: tf.cond(is_resample,
                                         lambda: tf.cond(tf.greater(step_counter, 0),
                                                         old_resampled,
                                                         old_data),
                                         new_data))
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

def optimize_fn():
    optimizer = tf.train.MomentumOptimizer(learning_rate=0.001,
                                           momentum=0.4,
                                           use_nesterov=False)
    grads = optimizer.compute_gradients(loss_fn(logits, target), tf.trainable_variables())
    apply_grads = optimizer.apply_gradients(grads, global_step=step_counter,
                                            name="apply_gradients")
    return apply_grads

with tf.control_dependencies([inputs, target]):
    with tf.control_dependencies(update_ops):
        loss_op = batch_loss_fn(logits, target)
        tf.stop_gradient(loss_op)
        barrier = tf.no_op(name="gradient_barrier")
        with tf.control_dependencies([barrier]):
            # take an optimization step only when is_optimize is fed as True
            apply_gradients = tf.cond(is_optimize, optimize_fn, lambda: tf.no_op("no_op_cond"))
i = 0

def print_log(**kwargs):
    global i
    i = i + 1
    print(i, ' | ', ' | '.join(['{}: {}'.format(k, v.tolist()) for k, v in kwargs.items()]))