@MikeOfZen
Created November 5, 2019 19:06
[Cut TF model] Use this to slice a Keras model in half: given a cutoff layer, it collects that layer and every layer downstream of it. Useful for transfer learning in large models. #tf #python
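```python
def wrap_list(val):
    """Return val as a list (Keras returns a single tensor or a list of tensors)."""
    if isinstance(val, (list, tuple)):
        return list(val)
    return [val]


def get_next_level(layer, model):
    """Return the layers of `model` that consume any output tensor of `layer`."""
    r = []
    for output_t in wrap_list(layer.output):
        r += [x for x in model.layers
              if output_t.name in [y.name for y in wrap_list(x.input)]]
    return r


def get_layers_above(cutoff_layer, model):
    """Return cutoff_layer plus every layer downstream of it (towards the outputs)."""
    visited = set()
    to_visit = {cutoff_layer}
    while to_visit:
        layer = to_visit.pop()
        visited.add(layer)
        # Only queue unvisited layers: shared descendants are processed once, and an
        # InputLayer (whose input and output are the same tensor) cannot re-queue
        # itself forever.
        to_visit.update(l for l in get_next_level(layer, model) if l not in visited)
    return list(visited)
```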
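The sketch below exercises the traversal without TensorFlow, on a toy functional graph with a skip branch. `FakeTensor`, `FakeLayer` and `FakeModel` are hypothetical stand-ins that model only the attributes the gist relies on (`.name`, `.input`, `.output`, `.layers`), plus `.trainable`. The helpers are repeated (with the visited-set check) so the block runs on its own. `freeze_below` is also a hypothetical helper, showing one way to use the result for transfer learning: layers at or downstream of the cut stay trainable and everything else is frozen. On a real Keras model, setting `layer.trainable` is the usual way to freeze a layer, and the model should be recompiled afterwards.

```python
class FakeTensor:
    """Hypothetical stand-in for a Keras tensor: only .name is used."""
    def __init__(self, name):
        self.name = name


class FakeLayer:
    """Hypothetical stand-in for a Keras layer (hashable, unlike SimpleNamespace)."""
    def __init__(self, name, inputs, outputs):
        self.name, self.input, self.output = name, inputs, outputs
        self.trainable = True


class FakeModel:
    """Hypothetical stand-in for a Keras functional model: only .layers is used."""
    def __init__(self, layers):
        self.layers = layers


def wrap_list(val):
    return list(val) if isinstance(val, (list, tuple)) else [val]


def get_next_level(layer, model):
    r = []
    for output_t in wrap_list(layer.output):
        r += [x for x in model.layers
              if output_t.name in [y.name for y in wrap_list(x.input)]]
    return r


def get_layers_above(cutoff_layer, model):
    visited, to_visit = set(), {cutoff_layer}
    while to_visit:
        layer = to_visit.pop()
        visited.add(layer)
        to_visit.update(l for l in get_next_level(layer, model) if l not in visited)
    return list(visited)


def freeze_below(cutoff_layer, model):
    """Hypothetical helper: keep only the cutoff layer and its descendants trainable."""
    keep = set(get_layers_above(cutoff_layer, model))
    for layer in model.layers:
        layer.trainable = layer in keep


# Toy graph: input -> conv1 -> {conv2, branch} -> add -> head
t_in, t1, t2, t3, t4, t5 = (FakeTensor(n) for n in ["in", "c1", "c2", "br", "add", "head"])
inp = FakeLayer("input", t_in, t_in)  # like an InputLayer: input and output are the same tensor
conv1 = FakeLayer("conv1", t_in, t1)
conv2 = FakeLayer("conv2", t1, t2)
branch = FakeLayer("branch", t1, t3)
add = FakeLayer("add", [t2, t3], t4)
head = FakeLayer("head", t4, t5)
model = FakeModel([inp, conv1, conv2, branch, add, head])

print(sorted(l.name for l in get_layers_above(conv2, model)))  # ['add', 'conv2', 'head']
freeze_below(conv2, model)
print([l.name for l in model.layers if l.trainable])  # ['conv2', 'add', 'head']
```

Note that the cut follows the graph topology, not the order of `model.layers`: `branch` is not downstream of `conv2`, so it is frozen even though it is listed after `conv2`.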