@EndingCredits
Created May 21, 2018 09:58
import tensorflow as tf


def combine_weights(in_list):
    """
    Returns a 1D tensor of the input list of (nested lists of) tensors, useful
    for doing things like comparing current weights with old weights for EWC.
    1.) For all elements in the input list (ln 3):
            if it is a list, combine it recursively,
            else leave it alone.
    2.) From the resulting list, take all non-None elements and flatten them (ln 2).
    3.) If the resulting list is empty, return None (ln 1),
            else return the concatenation of the list.
    ( All in one expression :) )
    Note: tf.concat requires every tensor to share the same dtype.
    """
    return (lambda x: None if not x else tf.concat(x, axis=0))(
        [ tf.reshape(x, [-1]) for x in
          [ combine_weights(x) if isinstance(x, list) else x for x in in_list ]
          if x is not None ])
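To make the one-liner easier to follow, here is a minimal sketch of the same recursion written out step by step. It swaps TensorFlow for NumPy (`np.reshape`, `np.concatenate`) so it runs without a TF install; the name `combine_weights_np`, the example weights, and the all-ones Fisher diagonal `fisher` are hypothetical, chosen only for illustration. The usage part shows the kind of EWC comparison the docstring mentions: a quadratic penalty of the form 0.5 * sum(F_i * (theta_i - theta*_i)^2), with the usual lambda factor taken as 1.

```python
import numpy as np


def combine_weights_np(in_list):
    """NumPy analogue of combine_weights (illustrative only): flatten a
    (nested) list of arrays into one 1D array, skipping None entries.
    Returns None if nothing is left to concatenate."""
    flat = []
    for item in in_list:
        # Step 1: recurse into nested lists; leave arrays (and None) alone.
        if isinstance(item, list):
            item = combine_weights_np(item)
        # Step 2: drop None (a literal None or an empty nested list), flatten the rest.
        if item is not None:
            flat.append(np.reshape(item, [-1]))
    # Step 3: None for an empty result, else concatenate into one 1D array.
    return None if not flat else np.concatenate(flat, axis=0)


# Hypothetical "old" and "new" weights: a kernel, a bias, and a nested list with a None.
old = [np.ones((2, 2)), np.zeros(3), [np.full((1, 2), 2.0), None]]
new = [np.ones((2, 2)) * 1.5, np.zeros(3), [np.full((1, 2), 2.0), None]]

old_flat = combine_weights_np(old)
new_flat = combine_weights_np(new)
print(old_flat.shape)  # (9,)

# EWC-style quadratic penalty with a hypothetical all-ones Fisher diagonal (lambda = 1).
fisher = np.ones_like(old_flat)
penalty = 0.5 * np.sum(fisher * (new_flat - old_flat) ** 2)
print(penalty)  # 0.5
```

Because both weight lists flatten in the same traversal order, the two 1D results line up element by element, which is what makes a single vectorised difference possible.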