@eladshabi
Last active February 25, 2019 12:33
Use custom getter
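# source: https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
import tensorflow as tf


def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True,
                                    *args, **kwargs):
    """Custom getter that keeps trainable variables in float32 storage.

    Non-trainable variables are created with the requested dtype. Trainable
    variables are stored in float32 and, when a lower precision (e.g. float16)
    was requested, handed back as a cast to that dtype for the model to use.
    """
    storage_dtype = tf.float32 if trainable else dtype
    variable = getter(name, shape, dtype=storage_dtype,
                      initializer=initializer, regularizer=regularizer,
                      trainable=trainable,
                      *args, **kwargs)
    if trainable and dtype != tf.float32:
        variable = tf.cast(variable, dtype)
    return variable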
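A minimal usage sketch, assuming TensorFlow 2.x with the TF1 graph API reached through `tf.compat.v1`. The getter is meant for the `custom_getter` argument of `variable_scope`, which only affects variables created with `get_variable`. The scope name `fp16_model` and the tensors `x`, `w` and `y` are hypothetical, and the getter is repeated so the block runs on its own. It checks that the variable recorded as trainable is stored in float32 while the tensor the model computes with is float16.

```python
import tensorflow as tf

# Assumption: TF 2.x, using the TF1 graph API via tf.compat.v1.
tf1 = tf.compat.v1
tf1.disable_eager_execution()


def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True, *args, **kwargs):
    """Store trainable variables in float32; return a cast to `dtype`."""
    storage_dtype = tf.float32 if trainable else dtype
    variable = getter(name, shape, dtype=storage_dtype,
                      initializer=initializer, regularizer=regularizer,
                      trainable=trainable, *args, **kwargs)
    if trainable and dtype != tf.float32:
        variable = tf.cast(variable, dtype)
    return variable


# Every get_variable() call inside this scope is routed through the getter.
with tf1.variable_scope("fp16_model",
                        custom_getter=float32_variable_storage_getter):
    x = tf1.placeholder(tf.float16, shape=[None, 4], name="x")
    w = tf1.get_variable("w", shape=[4, 2], dtype=tf.float16)  # hypothetical weight
    y = tf.matmul(x, w)  # the matmul runs in float16

# The stored (trainable) variable is float32; the model only sees the cast.
stored = tf1.trainable_variables()[0]
print(stored.dtype.base_dtype.name)  # float32
print(w.dtype.name, y.dtype.name)    # float16 float16
```

With this setup, an optimizer that uses the default trainable-variables list applies its updates to the float32 copy, and the float16 cast is recomputed from it in the graph.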