Skip to content

Instantly share code, notes, and snippets.

@albertz
Created December 14, 2018 13:37
Show Gist options
  • Save albertz/2ebec1fc0243fc6443e78c3bacc15e1e to your computer and use it in GitHub Desktop.
Save albertz/2ebec1fc0243fc6443e78c3bacc15e1e to your computer and use it in GitHub Desktop.
get_switch_op_cond_ctx: get the control_flow_ops.CondContext of a Switch tf.Operation (if possible)
def get_switch_op_cond_ctx(op):
    """
    Get the CondContext of a Switch op, if it can be determined.

    Follows the logic of control_flow_util.IsCondSwitch: Switch nodes are not part of the
    cond control flow context that they represent, so the consumers of the switch outputs
    are inspected instead. A switch is a cond switch iff all its consumers are in cond
    contexts (for a loop Enter consumer, its outer context is used).

    NOTE(review): consumers of the different switch outputs may live in different
    CondContext objects; only the context of the last consumer visited is returned.

    :param tf.Operation op: switch op ("Switch" or "RefSwitch")
    :return: the control flow context of the last visited consumer, or None if no consumer
      was visited (with assertions enabled, this means the switch outputs have no consumers)
    :rtype: tensorflow.python.ops.control_flow_ops.CondContext|None
    """
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import control_flow_util
    # Same check as control_flow_util.IsSwitch.
    assert op.type in {"Switch", "RefSwitch"}, "expected a switch op, got %r" % (op,)
    assert op.outputs, "switch op %r has no outputs" % (op,)
    is_cond_switch = True
    ctxt = None
    for o in op.outputs:
        for c in o.consumers():
            ctxt = c._get_control_flow_context()  # pylint: disable=protected-access
            # _get_control_flow_context() may return None (the is_cond_switch check below
            # already allows for that), so guard before dereferencing outer_context;
            # otherwise we would raise AttributeError instead of reporting the real problem.
            if ctxt is not None and control_flow_util.IsLoopEnter(c):
                ctxt = ctxt.outer_context
            is_cond_switch = is_cond_switch and (ctxt is not None and ctxt.IsCondContext())
    assert is_cond_switch, "switch op %r has consumers outside of a cond context" % (op,)
    if ctxt is None:
        # No consumer was visited. This can happen, if we just have constructed the switch,
        # or this is via tf.gradients.
        return None
    assert isinstance(ctxt, control_flow_ops.CondContext), (
        "expected CondContext for switch op %r, got %r" % (op, ctxt))
    return ctxt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment