@yoshihikoueno
Created July 17, 2020 08:43
Code snippet to fix ``WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1`` in TensorFlow 2.2
# Workaround to fix the optimizer checkpoint bug in TensorFlow 2.2
optimizer = tf.keras.optimizers.Adam(
    learning_rate=tf.Variable(0.001),
    beta_1=tf.Variable(0.9),
    beta_2=tf.Variable(0.999),
    epsilon=tf.Variable(1e-7),
)
optimizer.iterations  # accessing this property invokes optimizer._iterations and creates the optimizer.iter attribute
optimizer.decay = tf.Variable(0.0)  # Adam.__init__ assumes ``decay`` is a Python float, so it must be replaced with a tf.Variable **after** __init__.
///
The root problem is that Adam.__init__ initializes these hyperparameters as Python float objects, which are not tracked by TensorFlow.
We need to make them tracked so that they appear in Adam._checkpoint_dependencies; then the weights can be loaded without actually calling the optimizer itself.
Converting the Python floats to tf.Variable makes them tracked, because tf.Variable is a subclass of ``trackable.Trackable``.
///
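As a purely illustrative sketch of the mechanism above (plain Python, not real TensorFlow code; the class names `Trackable`, `Variable`, and `Optimizer` here are hypothetical stand-ins), attribute tracking works roughly like this: an attribute is recorded as a checkpoint dependency only when its value is a Trackable object, so a plain float never shows up in the dependency map:

```python
class Trackable:
    """Stand-in for TensorFlow's trackable.Trackable base class."""

class Variable(Trackable):
    """Stand-in for tf.Variable: a value wrapped in a Trackable object."""
    def __init__(self, value):
        self.value = value

class Optimizer(Trackable):
    """Stand-in for a Keras optimizer whose attribute assignment records
    Trackable values, roughly mirroring _checkpoint_dependencies."""
    def __init__(self, beta_1):
        self._dependencies = {}
        self.beta_1 = beta_1

    def __setattr__(self, name, value):
        # Only Trackable attributes become checkpoint dependencies;
        # plain Python floats are invisible to the checkpointing machinery.
        if isinstance(value, Trackable):
            self._dependencies[name] = value
        super().__setattr__(name, value)

plain = Optimizer(beta_1=0.9)              # float: not tracked
wrapped = Optimizer(beta_1=Variable(0.9))  # Variable: tracked

print(sorted(plain._dependencies))    # []
print(sorted(wrapped._dependencies))  # ['beta_1']
```

This is why wrapping the hyperparameters in tf.Variable at construction time (and replacing ``decay`` after ``__init__``) makes them visible to checkpoint restore.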
@Deepthi-Jain

can you kindly let us know where to do the above changes??

@yoshihikoueno
Author

yoshihikoueno commented Jul 29, 2020

@Deepthi-Jain

can you kindly let us know where to do the above changes??

If you're using the tf.keras API, I guess you're doing something like:
model.compile(loss='categorical_crossentropy', optimizer='adam')

Here, you can first manually create an optimizer instance and give it to compile like this:

adam = tf.keras.optimizers.Adam(
    learning_rate=tf.Variable(0.001),
    beta_1=tf.Variable(0.9),
    beta_2=tf.Variable(0.999),
    epsilon=tf.Variable(1e-7),
)
adam.iterations  # access once to create the iter attribute
adam.decay = tf.Variable(0.0)
model.compile(loss='categorical_crossentropy', optimizer=adam)

@lixuanhng

lixuanhng commented Dec 16, 2020

Hello @yoshihikoueno, thanks for your solution. I am using keras.Model to build my network and got stuck with the same problem. The way I added this part is as follows:

model = REModel(
    batch_size=batch_num,
    vocab_size=re_args.vocab_size,
    embedding_size=re_args.embedding_size,
    num_classes=re_args.num_classes,
    pos_num=re_args.pos_num,
    pos_size=re_args.pos_size,
    gru_units=re_args.gru_units,
    embedding_matrix=wordembedding,
)

optimizer = tf.keras.optimizers.Adam(
    learning_rate=tf.Variable(0.01),
    beta_1=tf.Variable(0.9),
    beta_2=tf.Variable(0.999),
    epsilon=tf.Variable(1e-7),
)
optimizer.iterations
optimizer.decay = tf.Variable(0.0)

# loading checkpoint
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
ckpt.restore(tf.train.latest_checkpoint(save_path))

inputs_x = [sin_word_tensor, sin_pos1_tensor, sin_pos2_tensor]
predictions = model(inputs_x)

but unfortunately, the same sort of [Unresolved object in checkpoint] warnings appear again. Could you please help me figure out how to fix this? Cheers.

@yoshihikoueno
Author

@lixuanhng Can you post the error message? It should contain a list of variable names that are not resolved.

@lixuanhng

lixuanhng commented Dec 18, 2020

@lixuanhng Can you post the error message? It should contain a list of variable names that are not resolved.

The warning message is:

WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.embedding.embeddings
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.time_distributed.layer.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.time_distributed.layer.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.biLSTM_1.forward_layer.cell.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.biLSTM_1.forward_layer.cell.recurrent_kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.biLSTM_1.forward_layer.cell.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.biLSTM_1.backward_layer.cell.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.biLSTM_1.backward_layer.cell.recurrent_kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.biLSTM_1.backward_layer.cell.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.biLSTM_2.forward_layer.cell.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.biLSTM_2.forward_layer.cell.recurrent_kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.biLSTM_2.forward_layer.cell.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.biLSTM_2.backward_layer.cell.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.biLSTM_2.backward_layer.cell.recurrent_kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.biLSTM_2.backward_layer.cell.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.embedding.embeddings
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.time_distributed.layer.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.time_distributed.layer.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.biLSTM_1.forward_layer.cell.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.biLSTM_1.forward_layer.cell.recurrent_kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.biLSTM_1.forward_layer.cell.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.biLSTM_1.backward_layer.cell.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.biLSTM_1.backward_layer.cell.recurrent_kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.biLSTM_1.backward_layer.cell.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.biLSTM_2.forward_layer.cell.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.biLSTM_2.forward_layer.cell.recurrent_kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.biLSTM_2.forward_layer.cell.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.biLSTM_2.backward_layer.cell.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.biLSTM_2.backward_layer.cell.recurrent_kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.biLSTM_2.backward_layer.cell.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.

Seems like the optimizer objects are not resolved?

@yoshihikoueno
Author

@lixuanhng Hmm, it seems like the optimizer's internal variables (like optimizer.beta_1) and its slot variables (e.g. the 'm' and 'v' states for model.biLSTM_2.backward_layer.cell.bias) are all left unresolved.
Is there any particular reason not to use the tf.keras API for saving and restoring the weights? If not, I would recommend using it.

model = REModel(
    batch_size=batch_num,
    vocab_size=re_args.vocab_size,
    embedding_size=re_args.embedding_size,
    num_classes=re_args.num_classes,
    pos_num=re_args.pos_num,
    pos_size=re_args.pos_size,
    gru_units=re_args.gru_units,
    embedding_matrix=wordembedding,
)

optimizer = tf.keras.optimizers.Adam(
    learning_rate=tf.Variable(0.01),
    beta_1=tf.Variable(0.9),
    beta_2=tf.Variable(0.999),
    epsilon=tf.Variable(1e-7),
)
optimizer.iterations
optimizer.decay = tf.Variable(0.0)

model.compile(optimizer=optimizer)
# Or, if you also want to specify a loss,
# model.compile(loss=loss, optimizer=optimizer)

# Run a model
inputs_x = [sin_word_tensor, sin_pos1_tensor, sin_pos2_tensor]
predictions = model(inputs_x)
# You need to run the model at least once so that TensorFlow creates all the variables.
# Alternatively, you can call the `model.build` method manually; then you don't have to run the model.

# save and restore
ckpt_path = 'ckpt'
model.save_weights(ckpt_path)
model.load_weights(ckpt_path)
# If you want to be sure that everything is working as expected,
# you can call the `assert_consumed` method on the load status object here.

# inference
inputs_x = [sin_word_tensor, sin_pos1_tensor, sin_pos2_tensor]
predictions = model(inputs_x)

The problem with not using the tf.keras API for saving/restoring is that the core TensorFlow API and the tf.keras API handle variable objects in saved data differently, if I remember correctly.
So you would have to bridge those gaps between the two (native TensorFlow and tf.keras) manually, which is very tedious.
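To illustrate that gap in plain Python (a hypothetical sketch, not TensorFlow code; the key strings below are made up): restoring a checkpoint is essentially matching saved keys against the keys the live object graph exposes. If the saver and the loader build different key layouts, or the live optimizer never created its slot variables, the leftover saved entries surface as "Unresolved object in checkpoint":

```python
# Keys as one saving path might have written them.
saved = {
    'optimizer/beta_1': 0.9,
    'optimizer/iter': 100,
    'model/dense/kernel': [[1.0]],
}

# Keys the restoring object graph actually exposes; the optimizer
# entries are missing because the live optimizer never created them.
expected = {'model/dense/kernel'}

# Saved entries with no matching live object stay unresolved.
unresolved = sorted(set(saved) - expected)
print(unresolved)  # ['optimizer/beta_1', 'optimizer/iter']
```

Using one API consistently for both saving and loading keeps the two key layouts identical, which is why the warnings disappear.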

@hjzhang1018

Hi, just wanted to ask whether you have resolved these problems by now? I just encountered the same issue. Thank you very much!

@adv010

adv010 commented Jul 17, 2021

Hi @lixuanhng, were you able to solve this problem? I'm encountering the same issue in another project. Yours is the most detailed account of this error that I've found.

Do reply if you were ever able to solve it; that would be really helpful! Thanks!!
