@shawwn
Created November 15, 2019 12:47
diff --git a/requirements.txt b/requirements.txt
index 0d2556c..a573910 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ tqdm==4.31.1
toposort==1.5
tensor2tensor>=1.14.1
h5py
+tensorflow_addons>=0.6.0
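
The new tensorflow_addons pin backs the tfa.activations.gelu path referenced (commented out for now) in src/model.py below. A minimal sketch of that substitution, assuming TF 2.x with tensorflow_addons 0.6+ installed:

    import tensorflow as tf
    import tensorflow_addons as tfa

    x = tf.constant([-1.0, 0.0, 1.0])
    y = tfa.activations.gelu(x)  # close to 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x**3)))
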
diff --git a/src/generate_unconditional_samples.py b/src/generate_unconditional_samples.py
index 87e5132..176cccd 100755
--- a/src/generate_unconditional_samples.py
+++ b/src/generate_unconditional_samples.py
@@ -54,9 +54,9 @@ def sample_model(
elif length > hparams.n_ctx:
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
- with tf.Session(graph=tf.Graph()) as sess:
+ with tf.compat.v1.Session(graph=tf.Graph()) as sess:
np.random.seed(seed)
- tf.set_random_seed(seed)
+ tf.compat.v1.set_random_seed(seed)
output = sample.sample_sequence(
hparams=hparams, length=length,
@@ -71,6 +71,7 @@ def sample_model(
generated = 0
while nsamples == 0 or generated < nsamples:
+ print('Generating...')
out = sess.run(output)
for i in range(batch_size):
generated += 1
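
These are mechanical TF1 -> tf.compat.v1 renames; the graph/session pattern itself is unchanged. A minimal sketch of that pattern under TF 2.x (toy op, not part of the patch):

    import numpy as np
    import tensorflow as tf

    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
        np.random.seed(0)
        tf.compat.v1.set_random_seed(0)
        x = tf.compat.v1.random_uniform([2, 2])  # built into the session's graph, not run eagerly
        print(sess.run(x))
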
diff --git a/src/hparam.py b/src/hparam.py
new file mode 100644
index 0000000..6495838
--- /dev/null
+++ b/src/hparam.py
@@ -0,0 +1,747 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Hyperparameter values."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import numbers
+import re
+
+import six
+
+import hparam_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.util import compat
+from tensorflow.python.util import deprecation
+
+# Define the regular expression for parsing a single clause of the input
+# (delimited by commas). A legal clause looks like:
+# <variable name>[<index>]? = <rhs>
+# where <rhs> is either a single token or [] enclosed list of tokens.
+# For example: "var[1] = a" or "x = [1,2,3]"
+PARAM_RE = re.compile(r"""
+ (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
+ (\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None
+ \s*=\s*
+ ((?P<val>[^,\[]*) # single value: "a" or None
+ |
+ \[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3"
+ ($|,\s*)""", re.VERBOSE)
+
+
+def _parse_fail(name, var_type, value, values):
+ """Helper function for raising a value error for bad assignment."""
+ raise ValueError(
+ 'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' %
+ (name, var_type.__name__, value, values))
+
+
+def _reuse_fail(name, values):
+ """Helper function for raising a value error for reuse of name."""
+ raise ValueError('Multiple assignments to variable \'%s\' in %s' % (name,
+ values))
+
+
+def _process_scalar_value(name, parse_fn, var_type, m_dict, values,
+ results_dictionary):
+ """Update results_dictionary with a scalar value.
+
+ Used to update the results_dictionary to be returned by parse_values when
+ encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".)
+
+ Mutates results_dictionary.
+
+ Args:
+ name: Name of variable in assignment ("s" or "arr").
+ parse_fn: Function for parsing the actual value.
+ var_type: Type of named variable.
+ m_dict: Dictionary constructed from regex parsing.
+ m_dict['val']: RHS value (scalar)
+ m_dict['index']: List index value (or None)
+ values: Full expression being parsed
+ results_dictionary: The dictionary being updated for return by the parsing
+ function.
+
+ Raises:
+ ValueError: If the name has already been used.
+ """
+ try:
+ parsed_value = parse_fn(m_dict['val'])
+ except ValueError:
+ _parse_fail(name, var_type, m_dict['val'], values)
+
+ # If no index is provided
+ if not m_dict['index']:
+ if name in results_dictionary:
+ _reuse_fail(name, values)
+ results_dictionary[name] = parsed_value
+ else:
+ if name in results_dictionary:
+ # If the name has already been used as a scalar, it will be in this
+ # dictionary and map to a non-dictionary.
+ if not isinstance(results_dictionary.get(name), dict):
+ _reuse_fail(name, values)
+ else:
+ results_dictionary[name] = {}
+
+ index = int(m_dict['index'])
+ # Make sure the index position hasn't already been assigned a value.
+ if index in results_dictionary[name]:
+ _reuse_fail('{}[{}]'.format(name, index), values)
+ results_dictionary[name][index] = parsed_value
+
+
+def _process_list_value(name, parse_fn, var_type, m_dict, values,
+ results_dictionary):
+ """Update results_dictionary from a list of values.
+
+ Used to update results_dictionary to be returned by parse_values when
+ encountering a clause with a list RHS (e.g. "arr=[1,2,3]".)
+
+ Mutates results_dictionary.
+
+ Args:
+ name: Name of variable in assignment ("arr").
+ parse_fn: Function for parsing individual values.
+ var_type: Type of named variable.
+ m_dict: Dictionary constructed from regex parsing.
+ m_dict['val']: RHS value (scalar)
+ values: Full expression being parsed
+ results_dictionary: The dictionary being updated for return by the parsing
+ function.
+
+ Raises:
+ ValueError: If the name has an index or the values cannot be parsed.
+ """
+ if m_dict['index'] is not None:
+ raise ValueError('Assignment of a list to a list index.')
+ elements = filter(None, re.split('[ ,]', m_dict['vals']))
+ # Make sure the name hasn't already been assigned a value
+ if name in results_dictionary:
+ raise _reuse_fail(name, values)
+ try:
+ results_dictionary[name] = [parse_fn(e) for e in elements]
+ except ValueError:
+ _parse_fail(name, var_type, m_dict['vals'], values)
+
+
+def _cast_to_type_if_compatible(name, param_type, value):
+ """Cast hparam to the provided type, if compatible.
+
+ Args:
+ name: Name of the hparam to be cast.
+ param_type: The type of the hparam.
+ value: The value to be cast, if compatible.
+
+ Returns:
+ The result of casting `value` to `param_type`.
+
+ Raises:
+ ValueError: If the type of `value` is not compatible with param_type.
+ * If `param_type` is a string type, but `value` is not.
+ * If `param_type` is a boolean, but `value` is not, or vice versa.
+ * If `param_type` is an integer type, but `value` is not.
+ * If `param_type` is a float type, but `value` is not a numeric type.
+ """
+ fail_msg = (
+ "Could not cast hparam '%s' of type '%s' from value %r" %
+ (name, param_type, value))
+
+ # If `value` is already of type `param_type`, return it directly.
+ # `isinstance` is too weak (e.g. isinstance(True, int) == True).
+ if type(value) == param_type: # pylint: disable=unidiomatic-typecheck
+ return value
+
+ # Some callers use None, for which we can't do any casting/checking. :(
+ if issubclass(param_type, type(None)):
+ return value
+
+ # Avoid converting a non-string type to a string.
+ if (issubclass(param_type, (six.string_types, six.binary_type)) and
+ not isinstance(value, (six.string_types, six.binary_type))):
+ raise ValueError(fail_msg)
+
+ # Avoid converting a number or string type to a boolean or vice versa.
+ if issubclass(param_type, bool) != isinstance(value, bool):
+ raise ValueError(fail_msg)
+
+ # Avoid converting float to an integer (the reverse is fine).
+ if (issubclass(param_type, numbers.Integral) and
+ not isinstance(value, numbers.Integral)):
+ raise ValueError(fail_msg)
+
+ # Avoid converting a non-numeric type to a numeric type.
+ if (issubclass(param_type, numbers.Number) and
+ not isinstance(value, numbers.Number)):
+ raise ValueError(fail_msg)
+
+ return param_type(value)
+
+
+def parse_values(values, type_map, ignore_unknown=False):
+ """Parses hyperparameter values from a string into a python map.
+
+ `values` is a string containing comma-separated `name=value` pairs.
+ For each pair, the value of the hyperparameter named `name` is set to
+ `value`.
+
+ If a hyperparameter name appears multiple times in `values`, a ValueError
+ is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').
+
+ If a hyperparameter name appears in both an index assignment and a scalar
+ assignment, a ValueError is raised (e.g. 'a=[1,2,3],a[0] = 1').
+
+ The hyperparameter name may contain '.' symbols, which will result in an
+ attribute name that is only accessible through the getattr and setattr
+ functions. (And must first be explicitly added through add_hparam.)
+
+ WARNING: Use of '.' in your variable names is allowed, but is not well
+ supported and not recommended.
+
+ The `value` in `name=value` must follow the syntax according to the
+ type of the parameter:
+
+ * Scalar integer: A Python-parsable integer value. E.g.: 1,
+ 100, -12.
+ * Scalar float: A Python-parsable floating point value. E.g.: 1.0,
+ -.54e89.
+ * Boolean: Either true or false.
+ * Scalar string: A non-empty sequence of characters, excluding comma,
+ spaces, and square brackets. E.g.: foo, bar_1.
+ * List: A comma separated list of scalar values of the parameter type
+ enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].
+
+ When index assignment is used, the corresponding type_map key should be the
+ list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not
+ "arr[1]").
+
+ Args:
+ values: String. Comma separated list of `name=value` pairs where
+ 'value' must follow the syntax described above.
+ type_map: A dictionary mapping hyperparameter names to types. Note every
+ parameter name in values must be a key in type_map. The values must
+ conform to the types indicated, where a value V is said to conform to a
+ type T if either V has type T, or V is a list of elements of type T.
+ Hence, for a multidimensional parameter 'x' taking float values,
+ 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
+ ignore_unknown: Bool. Whether values that are missing a type in type_map
+ should be ignored. If set to True, a ValueError will not be raised for
+ unknown hyperparameter type.
+
+ Returns:
+ A python map mapping each name to either:
+ * A scalar value.
+ * A list of scalar values.
+ * A dictionary mapping index numbers to scalar values.
+ (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}")
+
+ Raises:
+ ValueError: If there is a problem with input.
+ * If `values` cannot be parsed.
+ * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
+ * If the same name is assigned two different values (e.g. 'a=1,a=2',
+ 'a[1]=1,a[1]=2', or 'a=1,a=[1]')
+ """
+ results_dictionary = {}
+ pos = 0
+ while pos < len(values):
+ m = PARAM_RE.match(values, pos)
+ if not m:
+ raise ValueError('Malformed hyperparameter value: %s' % values[pos:])
+ # Check that there is a comma between parameters and move past it.
+ pos = m.end()
+ # Parse the values.
+ m_dict = m.groupdict()
+ name = m_dict['name']
+ if name not in type_map:
+ if ignore_unknown:
+ continue
+ raise ValueError('Unknown hyperparameter type for %s' % name)
+ type_ = type_map[name]
+
+ # Set up correct parsing function (depending on whether type_ is a bool)
+ if type_ == bool:
+
+ def parse_bool(value):
+ if value in ['true', 'True']:
+ return True
+ elif value in ['false', 'False']:
+ return False
+ else:
+ try:
+ return bool(int(value))
+ except ValueError:
+ _parse_fail(name, type_, value, values)
+
+ parse = parse_bool
+ else:
+ parse = type_
+
+ # If a single value is provided
+ if m_dict['val'] is not None:
+ _process_scalar_value(name, parse, type_, m_dict, values,
+ results_dictionary)
+
+ # If the assigned value is a list:
+ elif m_dict['vals'] is not None:
+ _process_list_value(name, parse, type_, m_dict, values,
+ results_dictionary)
+
+ else: # Not assigned a list or value
+ _parse_fail(name, type_, '', values)
+
+ return results_dictionary
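
A small usage sketch of parse_values (hypothetical flag string; assumes this module is importable as `hparam`):

    import hparam

    type_map = {'learning_rate': float, 'layers': int, 'use_bias': bool}
    parsed = hparam.parse_values('learning_rate=0.3,layers=[2,4],use_bias=true', type_map)
    # parsed == {'learning_rate': 0.3, 'layers': [2, 4], 'use_bias': True}
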
+
+
+class HParams(object):
+ """Class to hold a set of hyperparameters as name-value pairs.
+
+ A `HParams` object holds hyperparameters used to build and train a model,
+ such as the number of hidden units in a neural net layer or the learning rate
+ to use when training.
+
+ You first create a `HParams` object by specifying the names and values of the
+ hyperparameters.
+
+ To make them easily accessible the parameter names are added as direct
+ attributes of the class. A typical usage is as follows:
+
+ ```python
+ # Create a HParams object specifying names and values of the model
+ # hyperparameters:
+ hparams = HParams(learning_rate=0.1, num_hidden_units=100)
+
+ # The hyperparameters are available as attributes of the HParams object:
+ hparams.learning_rate ==> 0.1
+ hparams.num_hidden_units ==> 100
+ ```
+
+ Hyperparameters have type, which is inferred from the type of their value
+ passed at construction time. The currently supported types are: integer,
+ float, boolean, string, and list of integer, float, boolean, or string.
+
+ You can override hyperparameter values by calling the
+ [`parse()`](#HParams.parse) method, passing a string of comma separated
+ `name=value` pairs. This is intended to make it possible to override
+ any hyperparameter values from a single command-line flag to which
+ the user passes 'hyper-param=value' pairs. It avoids having to define
+ one flag for each hyperparameter.
+
+ The syntax expected for each value depends on the type of the parameter.
+ See `parse()` for a description of the syntax.
+
+ Example:
+
+ ```python
+ # Define a command line flag to pass name=value pairs.
+ # For example using argparse:
+ import argparse
+ parser = argparse.ArgumentParser(description='Train my model.')
+ parser.add_argument('--hparams', type=str,
+ help='Comma separated list of "name=value" pairs.')
+ args = parser.parse_args()
+ ...
+ def my_program():
+ # Create a HParams object specifying the names and values of the
+ # model hyperparameters:
+ hparams = tf.contrib.training.HParams(
+ learning_rate=0.1,
+ num_hidden_units=100,
+ activations=['relu', 'tanh'])
+
+ # Override hyperparameter values by parsing the command line
+ hparams.parse(args.hparams)
+
+ # If the user passed `--hparams=learning_rate=0.3` on the command line
+ # then 'hparams' has the following attributes:
+ hparams.learning_rate ==> 0.3
+ hparams.num_hidden_units ==> 100
+ hparams.activations ==> ['relu', 'tanh']
+
+ # If the hyperparameters are in json format use parse_json:
+ hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
+ ```
+ """
+
+ _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks.
+
+ def __init__(self, hparam_def=None, model_structure=None, **kwargs):
+ """Create an instance of `HParams` from keyword arguments.
+
+ The keyword arguments specify name-values pairs for the hyperparameters.
+ The parameter types are inferred from the type of the values passed.
+
+ The parameter names are added as attributes of `HParams` object, so they
+ can be accessed directly with the dot notation `hparams._name_`.
+
+ Example:
+
+ ```python
+ # Define 3 hyperparameters: 'learning_rate' is a float parameter,
+ # 'num_hidden_units' an integer parameter, and 'activation' a string
+ # parameter.
+ hparams = tf.contrib.training.HParams(
+ learning_rate=0.1, num_hidden_units=100, activation='relu')
+
+ hparams.activation ==> 'relu'
+ ```
+
+ Note that a few names are reserved and cannot be used as hyperparameter
+ names. If you use one of the reserved names, the constructor raises a
+ `ValueError`.
+
+ Args:
+ hparam_def: Serialized hyperparameters, encoded as a hparam_pb2.HParamDef
+ protocol buffer. If provided, this object is initialized by
+ deserializing hparam_def. Otherwise **kwargs is used.
+ model_structure: An instance of ModelStructure, defining the feature
+ crosses to be used in the Trial.
+ **kwargs: Key-value pairs where the key is the hyperparameter name and
+ the value is the value for the parameter.
+
+ Raises:
+ ValueError: If both `hparam_def` and initialization values are provided,
+ or if one of the arguments is invalid.
+
+ """
+ # Register the hyperparameters and their type in _hparam_types.
+ # This simplifies the implementation of parse().
+ # _hparam_types maps the parameter name to a tuple (type, bool).
+ # The type value is the type of the parameter for scalar hyperparameters,
+ # or the type of the list elements for multidimensional hyperparameters.
+ # The bool value is True if the value is a list, False otherwise.
+ self._hparam_types = {}
+ self._model_structure = model_structure
+ if hparam_def:
+ self._init_from_proto(hparam_def)
+ if kwargs:
+ raise ValueError('hparam_def and initialization values are '
+ 'mutually exclusive')
+ else:
+ for name, value in six.iteritems(kwargs):
+ self.add_hparam(name, value)
+
+ def _init_from_proto(self, hparam_def):
+ """Creates a new HParams from `HParamDef` protocol buffer.
+
+ Args:
+ hparam_def: `HParamDef` protocol buffer.
+ """
+ assert isinstance(hparam_def, hparam_pb2.HParamDef)
+ for name, value in hparam_def.hparam.items():
+ kind = value.WhichOneof('kind')
+ if kind.endswith('_value'):
+ # Single value.
+ if kind.startswith('int64'):
+ # Setting attribute value to be 'int' to ensure the type is compatible
+ # with both Python2 and Python3.
+ self.add_hparam(name, int(getattr(value, kind)))
+ elif kind.startswith('bytes'):
+ # Setting attribute value to be 'str' to ensure the type is compatible
+ # with both Python2 and Python3. UTF-8 encoding is assumed.
+ self.add_hparam(name, compat.as_str(getattr(value, kind)))
+ else:
+ self.add_hparam(name, getattr(value, kind))
+ else:
+ # List of values.
+ if kind.startswith('int64'):
+ # Setting attribute value to be 'int' to ensure the type is compatible
+ # with both Python2 and Python3.
+ self.add_hparam(name, [int(v) for v in getattr(value, kind).value])
+ elif kind.startswith('bytes'):
+ # Setting attribute value to be 'str' to ensure the type is compatible
+ # with both Python2 and Python3. UTF-8 encoding is assumed.
+ self.add_hparam(
+ name, [compat.as_str(v) for v in getattr(value, kind).value])
+ else:
+ self.add_hparam(name, [v for v in getattr(value, kind).value])
+
+ def add_hparam(self, name, value):
+ """Adds {name, value} pair to hyperparameters.
+
+ Args:
+ name: Name of the hyperparameter.
+ value: Value of the hyperparameter. Can be one of the following types:
+ int, float, string, int list, float list, or string list.
+
+ Raises:
+ ValueError: if one of the arguments is invalid.
+ """
+ # Keys in kwargs are unique, but 'name' could be the name of a pre-existing
+ # attribute of this object. In that case we refuse to use it as a
+ # hyperparameter name.
+ if getattr(self, name, None) is not None:
+ raise ValueError('Hyperparameter name is reserved: %s' % name)
+ if isinstance(value, (list, tuple)):
+ if not value:
+ raise ValueError(
+ 'Multi-valued hyperparameters cannot be empty: %s' % name)
+ self._hparam_types[name] = (type(value[0]), True)
+ else:
+ self._hparam_types[name] = (type(value), False)
+ setattr(self, name, value)
+
+ def set_hparam(self, name, value):
+ """Set the value of an existing hyperparameter.
+
+ This function verifies that the type of the value matches the type of the
+ existing hyperparameter.
+
+ Args:
+ name: Name of the hyperparameter.
+ value: New value of the hyperparameter.
+
+ Raises:
+ KeyError: If the hyperparameter doesn't exist.
+ ValueError: If there is a type mismatch.
+ """
+ param_type, is_list = self._hparam_types[name]
+ if isinstance(value, list):
+ if not is_list:
+ raise ValueError(
+ 'Must not pass a list for single-valued parameter: %s' % name)
+ setattr(self, name, [
+ _cast_to_type_if_compatible(name, param_type, v) for v in value])
+ else:
+ if is_list:
+ raise ValueError(
+ 'Must pass a list for multi-valued parameter: %s.' % name)
+ setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
+
+ def del_hparam(self, name):
+ """Removes the hyperparameter with key 'name'.
+
+ Does nothing if it isn't present.
+
+ Args:
+ name: Name of the hyperparameter.
+ """
+ if hasattr(self, name):
+ delattr(self, name)
+ del self._hparam_types[name]
+
+ def parse(self, values):
+ """Override existing hyperparameter values, parsing new values from a string.
+
+ See parse_values for more detail on the allowed format for values.
+
+ Args:
+ values: String. Comma separated list of `name=value` pairs where 'value'
+ must follow the syntax described above.
+
+ Returns:
+ The `HParams` instance.
+
+ Raises:
+ ValueError: If `values` cannot be parsed or a hyperparameter in `values`
+ doesn't exist.
+ """
+ type_map = {}
+ for name, t in self._hparam_types.items():
+ param_type, _ = t
+ type_map[name] = param_type
+
+ values_map = parse_values(values, type_map)
+ return self.override_from_dict(values_map)
+
+ def override_from_dict(self, values_dict):
+ """Override existing hyperparameter values, parsing new values from a dictionary.
+
+ Args:
+ values_dict: Dictionary of name:value pairs.
+
+ Returns:
+ The `HParams` instance.
+
+ Raises:
+ KeyError: If a hyperparameter in `values_dict` doesn't exist.
+ ValueError: If `values_dict` cannot be parsed.
+ """
+ for name, value in values_dict.items():
+ self.set_hparam(name, value)
+ return self
+
+ @deprecation.deprecated(None, 'Use `override_from_dict`.')
+ def set_from_map(self, values_map):
+ """DEPRECATED. Use override_from_dict."""
+ return self.override_from_dict(values_dict=values_map)
+
+ def set_model_structure(self, model_structure):
+ self._model_structure = model_structure
+
+ def get_model_structure(self):
+ return self._model_structure
+
+ def to_json(self, indent=None, separators=None, sort_keys=False):
+ """Serializes the hyperparameters into JSON.
+
+ Args:
+ indent: If a non-negative integer, JSON array elements and object members
+ will be pretty-printed with that indent level. An indent level of 0, or
+ negative, will only insert newlines. `None` (the default) selects the
+ most compact representation.
+ separators: Optional `(item_separator, key_separator)` tuple. Default is
+ `(', ', ': ')`.
+ sort_keys: If `True`, the output dictionaries will be sorted by key.
+
+ Returns:
+ A JSON string.
+ """
+ return json.dumps(
+ self.values(),
+ indent=indent,
+ separators=separators,
+ sort_keys=sort_keys)
+
+ def parse_json(self, values_json):
+ """Override existing hyperparameter values, parsing new values from a json object.
+
+ Args:
+ values_json: String containing a json object of name:value pairs.
+
+ Returns:
+ The `HParams` instance.
+
+ Raises:
+ KeyError: If a hyperparameter in `values_json` doesn't exist.
+ ValueError: If `values_json` cannot be parsed.
+ """
+ values_map = json.loads(values_json)
+ return self.override_from_dict(values_map)
+
+ def values(self):
+ """Return the hyperparameter values as a Python dictionary.
+
+ Returns:
+ A dictionary with hyperparameter names as keys. The values are the
+ hyperparameter values.
+ """
+ return {n: getattr(self, n) for n in self._hparam_types.keys()}
+
+ def get(self, key, default=None):
+ """Returns the value of `key` if it exists, else `default`."""
+ if key in self._hparam_types:
+ # Ensure that default is compatible with the parameter type.
+ if default is not None:
+ param_type, is_param_list = self._hparam_types[key]
+ type_str = 'list<%s>' % param_type if is_param_list else str(param_type)
+ fail_msg = ("Hparam '%s' of type '%s' is incompatible with "
+ 'default=%s' % (key, type_str, default))
+
+ is_default_list = isinstance(default, list)
+ if is_param_list != is_default_list:
+ raise ValueError(fail_msg)
+
+ try:
+ if is_default_list:
+ for value in default:
+ _cast_to_type_if_compatible(key, param_type, value)
+ else:
+ _cast_to_type_if_compatible(key, param_type, default)
+ except ValueError as e:
+ raise ValueError('%s. %s' % (fail_msg, e))
+
+ return getattr(self, key)
+
+ return default
+
+ def __contains__(self, key):
+ return key in self._hparam_types
+
+ def __str__(self):
+ hpdict = self.values()
+ output_list = ['{}={}'.format(key, hpdict[key]) for key in hpdict]
+ return ','.join(output_list)
+
+ def __repr__(self):
+ strval = str(sorted(self.values().items()))
+ return '%s(%s)' % (type(self).__name__, strval)
+
+ @staticmethod
+ def _get_kind_name(param_type, is_list):
+ """Returns the field name given parameter type and is_list.
+
+ Args:
+ param_type: Data type of the hparam.
+ is_list: Whether this is a list.
+
+ Returns:
+ A string representation of the field name.
+
+ Raises:
+ ValueError: If parameter type is not recognized.
+ """
+ if issubclass(param_type, bool):
+ # This check must happen before issubclass(param_type, six.integer_types),
+ # since Python considers bool to be a subclass of int.
+ typename = 'bool'
+ elif issubclass(param_type, six.integer_types):
+ # Setting 'int' and 'long' types to be 'int64' to ensure the type is
+ # compatible with both Python2 and Python3.
+ typename = 'int64'
+ elif issubclass(param_type, (six.string_types, six.binary_type)):
+ # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
+ # compatible with both Python2 and Python3.
+ typename = 'bytes'
+ elif issubclass(param_type, float):
+ typename = 'float'
+ else:
+ raise ValueError('Unsupported parameter type: %s' % str(param_type))
+
+ suffix = 'list' if is_list else 'value'
+ return '_'.join([typename, suffix])
+
+ def to_proto(self, export_scope=None): # pylint: disable=unused-argument
+ """Converts a `HParams` object to a `HParamDef` protocol buffer.
+
+ Args:
+ export_scope: Optional `string`. Name scope to remove.
+
+ Returns:
+ A `HParamDef` protocol buffer.
+ """
+ hparam_proto = hparam_pb2.HParamDef()
+ for name in self._hparam_types:
+ # Parse the values.
+ param_type, is_list = self._hparam_types.get(name, (None, None))
+ kind = HParams._get_kind_name(param_type, is_list)
+
+ if is_list:
+ if kind.startswith('bytes'):
+ v_list = [compat.as_bytes(v) for v in getattr(self, name)]
+ else:
+ v_list = [v for v in getattr(self, name)]
+ getattr(hparam_proto.hparam[name], kind).value.extend(v_list)
+ else:
+ v = getattr(self, name)
+ if kind.startswith('bytes'):
+ v = compat.as_bytes(getattr(self, name))
+ setattr(hparam_proto.hparam[name], kind, v)
+
+ return hparam_proto
+
+ @staticmethod
+ def from_proto(hparam_def, import_scope=None): # pylint: disable=unused-argument
+ return HParams(hparam_def=hparam_def)
+
+
+ops.register_proto_function(
+ 'hparams',
+ proto_type=hparam_pb2.HParamDef,
+ to_proto=HParams.to_proto,
+ from_proto=HParams.from_proto)
+
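
For reference, a sketch of the HParamDef round-trip this vendored module supports (hypothetical values; needs the hparam_pb2 module added below):

    import hparam

    hp = hparam.HParams(n_layer=12, precision='float32', lr_schedule=[1e-4, 5e-5])
    proto = hp.to_proto()                      # serialize to an hparam_pb2.HParamDef
    restored = hparam.HParams(hparam_def=proto)
    assert restored.n_layer == 12 and restored.precision == 'float32'
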
diff --git a/src/hparam_pb2.py b/src/hparam_pb2.py
new file mode 100644
index 0000000..ba10c30
--- /dev/null
+++ b/src/hparam_pb2.py
@@ -0,0 +1,399 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: tensorflow/contrib/training/python/training/hparam.proto
+
+import sys
+_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor.FileDescriptor(
+ name='tensorflow/contrib/training/python/training/hparam.proto',
+ package='tensorflow',
+ syntax='proto3',
+ serialized_options=_b('\370\001\001'),
+ serialized_pb=_b('\n8tensorflow/contrib/training/python/training/hparam.proto\x12\ntensorflow\"\xd6\x04\n\tHParamDef\x12\x31\n\x06hparam\x18\x01 \x03(\x0b\x32!.tensorflow.HParamDef.HparamEntry\x1a\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c\x1a\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01\x1a\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01\x1a\x1d\n\x08\x42oolList\x12\x11\n\x05value\x18\x01 \x03(\x08\x42\x02\x10\x01\x1a\xc9\x02\n\nHParamType\x12\x15\n\x0bint64_value\x18\x01 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x15\n\x0b\x62ytes_value\x18\x03 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x07 \x01(\x08H\x00\x12\x35\n\nint64_list\x18\x04 \x01(\x0b\x32\x1f.tensorflow.HParamDef.Int64ListH\x00\x12\x35\n\nfloat_list\x18\x05 \x01(\x0b\x32\x1f.tensorflow.HParamDef.FloatListH\x00\x12\x35\n\nbytes_list\x18\x06 \x01(\x0b\x32\x1f.tensorflow.HParamDef.BytesListH\x00\x12\x33\n\tbool_list\x18\x08 \x01(\x0b\x32\x1e.tensorflow.HParamDef.BoolListH\x00\x42\x06\n\x04kind\x1aO\n\x0bHparamEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .tensorflow.HParamDef.HParamType:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3')
+)
+
+
+
+
+_HPARAMDEF_BYTESLIST = _descriptor.Descriptor(
+ name='BytesList',
+ full_name='tensorflow.HParamDef.BytesList',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='value', full_name='tensorflow.HParamDef.BytesList.value', index=0,
+ number=1, type=12, cpp_type=9, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=137,
+ serialized_end=163,
+)
+
+_HPARAMDEF_FLOATLIST = _descriptor.Descriptor(
+ name='FloatList',
+ full_name='tensorflow.HParamDef.FloatList',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='value', full_name='tensorflow.HParamDef.FloatList.value', index=0,
+ number=1, type=2, cpp_type=6, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=_b('\020\001'), file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=165,
+ serialized_end=195,
+)
+
+_HPARAMDEF_INT64LIST = _descriptor.Descriptor(
+ name='Int64List',
+ full_name='tensorflow.HParamDef.Int64List',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='value', full_name='tensorflow.HParamDef.Int64List.value', index=0,
+ number=1, type=3, cpp_type=2, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=_b('\020\001'), file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=197,
+ serialized_end=227,
+)
+
+_HPARAMDEF_BOOLLIST = _descriptor.Descriptor(
+ name='BoolList',
+ full_name='tensorflow.HParamDef.BoolList',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='value', full_name='tensorflow.HParamDef.BoolList.value', index=0,
+ number=1, type=8, cpp_type=7, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=_b('\020\001'), file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=229,
+ serialized_end=258,
+)
+
+_HPARAMDEF_HPARAMTYPE = _descriptor.Descriptor(
+ name='HParamType',
+ full_name='tensorflow.HParamDef.HParamType',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='int64_value', full_name='tensorflow.HParamDef.HParamType.int64_value', index=0,
+ number=1, type=3, cpp_type=2, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='float_value', full_name='tensorflow.HParamDef.HParamType.float_value', index=1,
+ number=2, type=2, cpp_type=6, label=1,
+ has_default_value=False, default_value=float(0),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='bytes_value', full_name='tensorflow.HParamDef.HParamType.bytes_value', index=2,
+ number=3, type=12, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b(""),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='bool_value', full_name='tensorflow.HParamDef.HParamType.bool_value', index=3,
+ number=7, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='int64_list', full_name='tensorflow.HParamDef.HParamType.int64_list', index=4,
+ number=4, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='float_list', full_name='tensorflow.HParamDef.HParamType.float_list', index=5,
+ number=5, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='bytes_list', full_name='tensorflow.HParamDef.HParamType.bytes_list', index=6,
+ number=6, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='bool_list', full_name='tensorflow.HParamDef.HParamType.bool_list', index=7,
+ number=8, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ _descriptor.OneofDescriptor(
+ name='kind', full_name='tensorflow.HParamDef.HParamType.kind',
+ index=0, containing_type=None, fields=[]),
+ ],
+ serialized_start=261,
+ serialized_end=590,
+)
+
+_HPARAMDEF_HPARAMENTRY = _descriptor.Descriptor(
+ name='HparamEntry',
+ full_name='tensorflow.HParamDef.HparamEntry',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='key', full_name='tensorflow.HParamDef.HparamEntry.key', index=0,
+ number=1, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b("").decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='value', full_name='tensorflow.HParamDef.HparamEntry.value', index=1,
+ number=2, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=_b('8\001'),
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=592,
+ serialized_end=671,
+)
+
+_HPARAMDEF = _descriptor.Descriptor(
+ name='HParamDef',
+ full_name='tensorflow.HParamDef',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='hparam', full_name='tensorflow.HParamDef.hparam', index=0,
+ number=1, type=11, cpp_type=10, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[_HPARAMDEF_BYTESLIST, _HPARAMDEF_FLOATLIST, _HPARAMDEF_INT64LIST, _HPARAMDEF_BOOLLIST, _HPARAMDEF_HPARAMTYPE, _HPARAMDEF_HPARAMENTRY, ],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=73,
+ serialized_end=671,
+)
+
+_HPARAMDEF_BYTESLIST.containing_type = _HPARAMDEF
+_HPARAMDEF_FLOATLIST.containing_type = _HPARAMDEF
+_HPARAMDEF_INT64LIST.containing_type = _HPARAMDEF
+_HPARAMDEF_BOOLLIST.containing_type = _HPARAMDEF
+_HPARAMDEF_HPARAMTYPE.fields_by_name['int64_list'].message_type = _HPARAMDEF_INT64LIST
+_HPARAMDEF_HPARAMTYPE.fields_by_name['float_list'].message_type = _HPARAMDEF_FLOATLIST
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_list'].message_type = _HPARAMDEF_BYTESLIST
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bool_list'].message_type = _HPARAMDEF_BOOLLIST
+_HPARAMDEF_HPARAMTYPE.containing_type = _HPARAMDEF
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['int64_value'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['int64_value'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['float_value'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['float_value'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_value'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_value'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['bool_value'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bool_value'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['int64_list'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['int64_list'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['float_list'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['float_list'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_list'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_list'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['bool_list'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bool_list'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMENTRY.fields_by_name['value'].message_type = _HPARAMDEF_HPARAMTYPE
+_HPARAMDEF_HPARAMENTRY.containing_type = _HPARAMDEF
+_HPARAMDEF.fields_by_name['hparam'].message_type = _HPARAMDEF_HPARAMENTRY
+DESCRIPTOR.message_types_by_name['HParamDef'] = _HPARAMDEF
+_sym_db.RegisterFileDescriptor(DESCRIPTOR)
+
+HParamDef = _reflection.GeneratedProtocolMessageType('HParamDef', (_message.Message,), dict(
+
+ BytesList = _reflection.GeneratedProtocolMessageType('BytesList', (_message.Message,), dict(
+ DESCRIPTOR = _HPARAMDEF_BYTESLIST,
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2'
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.BytesList)
+ ))
+ ,
+
+ FloatList = _reflection.GeneratedProtocolMessageType('FloatList', (_message.Message,), dict(
+ DESCRIPTOR = _HPARAMDEF_FLOATLIST,
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2'
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.FloatList)
+ ))
+ ,
+
+ Int64List = _reflection.GeneratedProtocolMessageType('Int64List', (_message.Message,), dict(
+ DESCRIPTOR = _HPARAMDEF_INT64LIST,
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2'
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.Int64List)
+ ))
+ ,
+
+ BoolList = _reflection.GeneratedProtocolMessageType('BoolList', (_message.Message,), dict(
+ DESCRIPTOR = _HPARAMDEF_BOOLLIST,
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2'
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.BoolList)
+ ))
+ ,
+
+ HParamType = _reflection.GeneratedProtocolMessageType('HParamType', (_message.Message,), dict(
+ DESCRIPTOR = _HPARAMDEF_HPARAMTYPE,
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2'
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.HParamType)
+ ))
+ ,
+
+ HparamEntry = _reflection.GeneratedProtocolMessageType('HparamEntry', (_message.Message,), dict(
+ DESCRIPTOR = _HPARAMDEF_HPARAMENTRY,
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2'
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.HparamEntry)
+ ))
+ ,
+ DESCRIPTOR = _HPARAMDEF,
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2'
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef)
+ ))
+_sym_db.RegisterMessage(HParamDef)
+_sym_db.RegisterMessage(HParamDef.BytesList)
+_sym_db.RegisterMessage(HParamDef.FloatList)
+_sym_db.RegisterMessage(HParamDef.Int64List)
+_sym_db.RegisterMessage(HParamDef.BoolList)
+_sym_db.RegisterMessage(HParamDef.HParamType)
+_sym_db.RegisterMessage(HParamDef.HparamEntry)
+
+
+DESCRIPTOR._options = None
+_HPARAMDEF_FLOATLIST.fields_by_name['value']._options = None
+_HPARAMDEF_INT64LIST.fields_by_name['value']._options = None
+_HPARAMDEF_BOOLLIST.fields_by_name['value']._options = None
+_HPARAMDEF_HPARAMENTRY._options = None
+# @@protoc_insertion_point(module_scope)
+
diff --git a/src/interactive_conditional_samples.py b/src/interactive_conditional_samples.py
index 26390e6..0bde72e 100755
--- a/src/interactive_conditional_samples.py
+++ b/src/interactive_conditional_samples.py
@@ -14,6 +14,23 @@ import tflex
import model, sample, encoder
+from load_dataset import load_dataset, Sampler
+
+def generate_samples(context, enc, sampler, data_sampler, session, batch_size, count=1):
+ print('Generating samples...')
+ context_tokens = data_sampler.sample(1)
+ index = 0
+ while index < count:
+ out = session.run(
+ sampler,
+ feed_dict={context: batch_size * [context_tokens]})
+ for i in range(len(out)):
+ text = enc.decode(out[i])
+ text = '======== SAMPLE {} ========\n{}\n'.format(
+ index + 1, text)
+ print(text)
+ index += 1
+
def interact_model(
model_name='117M',
seed=None,
@@ -22,7 +39,9 @@ def interact_model(
length=None,
temperature=1,
top_k=0,
- top_p=0.0
+ top_p=0.0,
+ dataset=None,
+ combine=50000
):
"""
Interactively run the model
@@ -58,10 +77,10 @@ def interact_model(
elif length > hparams.n_ctx:
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
- with tf.Session(graph=tf.Graph()) as sess:
- context = tf.placeholder(tf.int32, [batch_size, None])
+ with tf.compat.v1.Session(graph=tf.Graph()) as sess:
+ context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
np.random.seed(seed)
- tf.set_random_seed(seed)
+ tf.compat.v1.set_random_seed(seed)
output = sample.sample_sequence(
hparams=hparams, length=length,
context=context,
@@ -73,11 +92,21 @@ def interact_model(
ckpt = tflex.latest_checkpoint(os.path.join('models', model_name))
saver.restore(sess, ckpt)
+ print('Loading dataset...')
+ chunks = load_dataset(enc, dataset, combine=combine)
+ data_sampler = Sampler(chunks, seed=seed)
+ print('dataset has', data_sampler.total_size, 'tokens', len(chunks), 'chunks')
+
+ import pdb
+ pdb.set_trace()
+
while True:
raw_text = input("Model prompt >>> ")
while not raw_text:
print('Prompt should not be empty!')
raw_text = input("Model prompt >>> ")
+ raw_text = raw_text.rstrip()
+ print(repr(raw_text))
context_tokens = enc.encode(raw_text)
generated = 0
for _ in range(nsamples // batch_size):
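
The new generate_samples helper is defined above but not called in the hunks shown; a hypothetical call site inside interact_model's session block, using the names already created there, would be:

    generate_samples(context, enc, output, data_sampler, sess, batch_size, count=3)
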
diff --git a/src/memory_saving_gradients.py b/src/memory_saving_gradients.py
index 659691f..f2cd20e 100644
--- a/src/memory_saving_gradients.py
+++ b/src/memory_saving_gradients.py
@@ -2,7 +2,7 @@ from toposort import toposort
import contextlib
import numpy as np
import tensorflow as tf
-import tensorflow.contrib.graph_editor as ge
+import tflex_graph_editor as ge
import time
import sys
sys.setrecursionlimit(10000)
@@ -10,7 +10,7 @@ sys.setrecursionlimit(10000)
util = sys.modules[__name__]
# getting rid of "WARNING:tensorflow:VARIABLES collection name is deprecated"
-setattr(tf.GraphKeys, "VARIABLES", "variables")
+setattr(tf.compat.v1.GraphKeys, "VARIABLES", "variables")
# save original gradients since tf.gradient could be monkey-patched to point
# to our version
diff --git a/src/model.py b/src/model.py
index d0cde2a..dda5b58 100644
--- a/src/model.py
+++ b/src/model.py
@@ -1,9 +1,14 @@
import numpy as np
import tensorflow as tf
-from tensorflow.contrib.training import HParams
+import tensorflow_addons as tfa
+#from tensorflow.contrib.training import HParams
+import hparam
+import collections
+import tflex
+
def default_hparams():
- return HParams(
+ return hparam.HParams(
n_vocab=50257,
n_ctx=1024,
n_embd=768,
@@ -16,12 +21,22 @@ def default_hparams():
import os
-def get_variable(name):
- name = os.path.join(tf.get_variable_scope().name, name)
- vs = tf.trainable_variables()
+def get_variable(name, unset=None):
+ fqn = tflex.join_variable_scope(name)
+ vs = tf.compat.v1.trainable_variables()
for x in vs:
- if x.name.startswith(name + ':'):
+ # should this next test be this instead?
+ # if x.name == fqn + ':0'
+ #
+ if x.name.startswith(fqn + ':'):
return x
+ if callable(unset):
+ return unset()
+
+def init_variable(name, *args, **kws):
+ def unset():
+ return tf.compat.v1.get_variable(name, *args, **kws)
+ return get_variable(name, unset=unset)
def shape_list(x):
"""Deal with dynamic shape in tensorflow cleanly."""
@@ -29,24 +44,26 @@ def shape_list(x):
dynamic = tf.shape(x)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
-def softmax(x, axis=-1):
+def softmax(x, axis=-1, name=None):
x = x - tf.reduce_max(x, axis=axis, keepdims=True)
ex = tf.exp(x)
return ex / tf.reduce_sum(ex, axis=axis, keepdims=True)
+ #return tf.nn.softmax(x, axis=axis, name=name)
def gelu(x):
return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3))))
+ #return tfa.activations.gelu(x)
def norm(x, scope, *, axis=-1, epsilon=1e-5, hparams=None):
"""Normalize to mean = 0, std = 1, then do a diagonal affine transform."""
dtype = hparams.dtype if hparams else tf.float32
- with tf.variable_scope(scope, dtype=dtype):
- n_state = x.shape[-1].value
- g = get_variable('g') or tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1, dtype=dtype))
- b = get_variable('b') or tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0, dtype=dtype))
+ with tflex.variable_scope(scope, dtype=dtype):
+ n_state = shape_list(x)[-1]
+ g = init_variable('g', [n_state], initializer=tf.compat.v1.constant_initializer(1, dtype=dtype))
+ b = init_variable('b', [n_state], initializer=tf.compat.v1.constant_initializer(0, dtype=dtype))
u = tf.reduce_mean(x, axis=axis, keepdims=True)
s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True)
- x = (x - u) * tf.rsqrt(s + epsilon)
+ x = (x - u) * tf.compat.v1.rsqrt(s + epsilon)
x = x*g + b
return x
@@ -62,10 +79,10 @@ def merge_states(x):
def conv1d(x, scope, nf, *, w_init_stdev=0.02, hparams=None):
dtype = hparams.dtype if hparams else tf.float32
- with tf.variable_scope(scope, dtype=dtype):
+ with tflex.variable_scope(scope, dtype=dtype):
*start, nx = shape_list(x)
- w = get_variable('w') or tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev, dtype=dtype))
- b = get_variable('b') or tf.get_variable('b', [nf], initializer=tf.constant_initializer(0, dtype=dtype))
+ w = init_variable('w', [1, nx, nf], initializer=tf.compat.v1.random_normal_initializer(stddev=w_init_stdev, dtype=dtype))
+ b = init_variable('b', [nf], initializer=tf.compat.v1.constant_initializer(0, dtype=dtype))
c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf])
return c
@@ -105,7 +122,7 @@ def attn(x, scope, n_state, *, past, hparams):
def multihead_attn(q, k, v):
# q, k, v have shape [batch, heads, sequence, features]
w = tf.matmul(q, k, transpose_b=True)
- w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))
+ w = w * tf.compat.v1.rsqrt(tf.cast(v.shape[-1], w.dtype))
w = mask_attn_weights(w)
w = softmax(w)
@@ -114,7 +131,7 @@ def attn(x, scope, n_state, *, past, hparams):
return a
dtype = hparams.dtype if hparams else tf.float32
- with tf.variable_scope(scope, dtype=dtype):
+ with tflex.variable_scope(scope, dtype=dtype):
c = conv1d(x, 'c_attn', n_state*3, hparams=hparams)
q, k, v = map(split_heads, tf.split(c, 3, axis=2))
present = tf.stack([k, v], axis=1)
@@ -131,9 +148,10 @@ def attn(x, scope, n_state, *, past, hparams):
def mlp(x, scope, n_state, *, hparams):
dtype = hparams.dtype if hparams else tf.float32
- with tf.variable_scope(scope, dtype=dtype):
- nx = x.shape[-1].value
- h = gelu(conv1d(x, 'c_fc', n_state, hparams=hparams))
+ with tflex.variable_scope(scope, dtype=dtype):
+ nx = shape_list(x)[-1]
+ h0 = conv1d(x, 'c_fc', n_state, hparams=hparams)
+ h = gelu(h0)
h2 = conv1d(h, 'c_proj', nx, hparams=hparams)
h2 = dropout(h2, hparams.res_dropout)
return h2
@@ -145,8 +163,8 @@ def dropout(x, pdrop=0.1, train=True):
def block(x, scope, *, past, hparams):
dtype = hparams.dtype if hparams else tf.float32
- with tf.variable_scope(scope, dtype=dtype):
- nx = x.shape[-1].value
+ with tflex.variable_scope(scope, dtype=dtype):
+ nx = shape_list(x)[-1]
a, present = attn(norm(x, 'ln_1', hparams=hparams), 'attn', nx, past=past, hparams=hparams)
x = x + a
m = mlp(norm(x, 'ln_2', hparams=hparams), 'mlp', nx*4, hparams=hparams)
@@ -168,16 +186,16 @@ def positions_for(tokens, past_length):
return expand_tile(past_length + tf.range(nsteps), batch_size)
-def model(hparams, X, past=None, scope='model', reuse=tf.AUTO_REUSE):
+def model(hparams, X, past=None, scope='model', reuse=tf.compat.v1.AUTO_REUSE):
dtype = hparams.dtype if hparams else tf.float32
- with tf.variable_scope(scope, reuse=reuse, dtype=dtype):
+ with tflex.variable_scope(scope, reuse=reuse, dtype=dtype):
results = {}
batch, sequence = shape_list(X)
- wpe = get_variable('wpe') or tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
- initializer=tf.random_normal_initializer(stddev=0.01, dtype=dtype))
- wte = get_variable('wte') or tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
- initializer=tf.random_normal_initializer(stddev=0.02, dtype=dtype))
+ wpe = init_variable('wpe', [hparams.n_ctx, hparams.n_embd],
+ initializer=tf.compat.v1.random_normal_initializer(stddev=0.01, dtype=dtype))
+ wte = init_variable('wte', [hparams.n_vocab, hparams.n_embd],
+ initializer=tf.compat.v1.random_normal_initializer(stddev=0.02, dtype=dtype))
past_length = 0 if past is None else tf.shape(past)[-2]
h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))
@@ -188,7 +206,7 @@ def model(hparams, X, past=None, scope='model', reuse=tf.AUTO_REUSE):
for layer, past in enumerate(pasts):
h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
if layer == 10:
- tf.add_to_collection('checkpoints', h)
+ tf.compat.v1.add_to_collection('checkpoints', h)
presents.append(present)
results['present'] = tf.stack(presents, axis=1)
h = norm(h, 'ln_f', hparams=hparams)
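
The x.shape[-1].value -> shape_list(x)[-1] changes above follow from TF 2.x dropping Dimension objects from TensorShape. A quick sketch of the behavior relied on (assuming TF 2.x semantics):

    import tensorflow as tf

    x = tf.zeros([2, 768])
    n_state = x.shape[-1]   # plain int (or None) in TF 2.x; was a Dimension in TF 1.x
    assert n_state == 768
    # shape_list(x) additionally falls back to tf.shape(x) for dimensions that are None
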
diff --git a/src/sample.py b/src/sample.py
index ff517be..fb4ad80 100644
--- a/src/sample.py
+++ b/src/sample.py
@@ -44,7 +44,7 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte
context = tf.fill([batch_size, 1], start_token)
def step(hparams, tokens, past=None):
- lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)
+ lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.compat.v1.AUTO_REUSE)
if hparams.dtype != tf.float32:
lm_output["logits"] = tf.cast(lm_output["logits"], tf.float32)
@@ -64,12 +64,12 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte
def body(past, prev, output):
next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
- logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
+ logits = next_outputs['logits'][:, -1, :] / tf.compat.v1.to_float(temperature)
if top_p > 0.0:
logits = top_p_logits(logits, p=top_p, epsilon=epsilon)
else:
logits = top_k_logits(logits, k=top_k, epsilon=epsilon)
- samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
+ samples = tf.compat.v1.multinomial(logits, num_samples=1, output_dtype=tf.int32)
return [
tf.concat([past, next_outputs['presents']], axis=-2),
tf.squeeze(samples, axis=[1]),
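
The compat.v1 shims used here also have native TF 2.x replacements if a later pass wants to drop compat.v1 entirely; a sketch of the mapping (not applied by this patch):

    import tensorflow as tf

    logits = tf.random.uniform([1, 5])
    scaled = logits / tf.cast(0.7, tf.float32)                               # ~ tf.compat.v1.to_float(temperature)
    samples = tf.random.categorical(scaled, num_samples=1, dtype=tf.int32)   # ~ tf.compat.v1.multinomial(...)
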
diff --git a/tflex.py b/tflex.py
index 7c76c44..4ec65b4 100644
--- a/tflex.py
+++ b/tflex.py
@@ -8,6 +8,7 @@ import tqdm
import h5py
import shutil
import tempfile
+import traceback
def split_by_params(vs, n=200e6, f=None):
if f is None:
@@ -25,7 +26,7 @@ def split_by_params(vs, n=200e6, f=None):
yield xs
def latest_checkpoint(checkpoint_dir, latest_filename=None):
- ctrs = np.array([[int(y) for y in re.findall(r'model-([0-9]+)(?:-[0-9]+)?[.](?:npy|hdf5)', x)] for x in glob(os.path.join(checkpoint_dir, 'model-*.*'))]).flatten()
+ ctrs = np.array([[int(y) for y in re.findall(r'model-([0-9]+)(?:-[0-9]+)?[.](?:npy|hdf5)', x)] for x in glob(os.path.join(checkpoint_dir, 'model-*.*')) if not x.endswith('.tmp')]).flatten()
if len(ctrs) <= 0:
ckpt = tf.train.latest_checkpoint(checkpoint_dir, latest_filename=latest_filename)
return ckpt
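
The added `if not x.endswith('.tmp')` guard matters because the counter regex also matches inside a temporary filename; a quick sketch of that edge case (hypothetical paths):

    import re

    pat = r'model-([0-9]+)(?:-[0-9]+)?[.](?:npy|hdf5)'
    re.findall(pat, 'models/run1/model-1000.hdf5')      # ['1000']
    re.findall(pat, 'models/run1/model-1000.hdf5.tmp')  # also ['1000'], hence the explicit .tmp filter
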
@@ -65,23 +66,55 @@ def assign_values(variables, values, session=None):
def load_snapshot(ckpt, session=None, var_list=None, reshape=False):
session = session or tf.get_default_session()
reader = pywrap_tensorflow.NewCheckpointReader(ckpt)
- vs = var_list or tf.trainable_variables()
+ vs = var_list or tf.compat.v1.trainable_variables()
for variables in tqdm.tqdm(list(split_by_params(vs))):
values = [value for variable, value in grab_values(variables, reader, reshape=reshape)]
assign_values(variables, values, session=session)
+def current_variable_scope():
+ current = tf.compat.v1.get_variable_scope().name
+ return current
+
+def fix_variable_scope(scope):
+ current = current_variable_scope()
+ print('TKTK', 'current', current, 'scope', scope)
+ if current == scope:
+ return current
+ if current.endswith('/' + scope):
+ return current
+ return scope
+
+def join_variable_scope(name, base=None):
+ if base is None:
+ base = current_variable_scope()
+ if len(base) > 0 and name.startswith(base):
+ import pdb
+ pdb.set_trace()
+ if len(name) > 0 and base.startswith(name):
+ import pdb
+ pdb.set_trace()
+ return os.path.join(base, name)
+
+def variable_scope(name=None, **kws):
+ if name is None:
+ name = current_variable_scope()
+ scope = fix_variable_scope(name)
+ #import pdb
+ #pdb.set_trace()
+ return tf.compat.v1.variable_scope(scope, **kws)
+
def get_variable(name, var_list=None):
name, num = name.split(':') if ':' in name else (name, '0')
num = int(num)
- name = os.path.join(tf.get_variable_scope().name, name)
- vs = var_list or tf.trainable_variables()
+ name = join_variable_scope(name)
+ vs = var_list or tf.compat.v1.trainable_variables()
for x in vs:
if x.name.startswith(name + ':%d' % num):
return x
def load_weights(ckpt, session=None, var_list=None, reshape=False):
session = session or tf.get_default_session()
- vs = var_list or tf.trainable_variables()
+ vs = var_list or tf.compat.v1.trainable_variables()
files = list(sorted(glob(ckpt + '-*.npy')))
for out in tqdm.tqdm(files):
for name, value in np.load(out, allow_pickle=True):
@@ -94,10 +127,10 @@ def load_weights(ckpt, session=None, var_list=None, reshape=False):
def load_variables(ckpt, session=None, var_list=None, reshape=False):
session = session or tf.get_default_session()
- vs = var_list or tf.trainable_variables()
+ vs = var_list or tf.compat.v1.trainable_variables()
with h5py.File(ckpt) as f:
for variables in tqdm.tqdm(list(split_by_params(vs))):
- values = [truncate_value(x, f[x.name], reshape=reshape) for x in variables]
+ values = [truncate_value(x, f[x.name], reshape=reshape) for x in variables if x.name in f]
assign_values(variables, values, session=session)
def maketree(path):
@@ -107,21 +140,31 @@ def maketree(path):
pass
def save_variables(ckpt, session=None, var_list=None):
- session = session or tf.get_default_session()
- vs = var_list or tf.trainable_variables()
- maketree(os.path.dirname(ckpt))
- fname = ckpt+'.tmp'
- with h5py.File(fname, "w") as f:
- for variables in tqdm.tqdm(list(split_by_params(vs))):
- values = session.run(variables)
- for value, variable in zip(values, variables):
- name = variable.name
- shape = variable.shape.as_list()
- dtype = variable.dtype
- dset = f.create_dataset(name, shape, dtype=np.float32)
- dset[:] = value
- print('Writing snapshot %s' % ckpt)
- os.rename(ckpt+'.tmp', ckpt)
+ while True:
+ try:
+ session = session or tf.get_default_session()
+ vs = var_list or tf.compat.v1.trainable_variables()
+ maketree(os.path.dirname(ckpt))
+ fname = ckpt+'.tmp'
+ with h5py.File(fname, "w") as f:
+ for variables in tqdm.tqdm(list(split_by_params(vs))):
+ values = session.run(variables)
+ for value, variable in zip(values, variables):
+ name = variable.name
+ shape = variable.shape.as_list()
+ dtype = variable.dtype
+ dset = f.create_dataset(name, shape, dtype=np.float32)
+ dset[:] = value
+ print('Writing snapshot %s' % ckpt)
+ os.rename(ckpt+'.tmp', ckpt)
+ break
+ except:
+ traceback.print_exc(file=sys.stderr)
+ print('Exception while saving checkpoint. To try again, press "c" then enter.')
+ print('(Is your disk full?)')
+ print('Attaching debugger...')
+ import pdb
+ pdb.set_trace()
class Saver(object):
def __init__(
@@ -137,7 +180,7 @@ class Saver(object):
builder=None,
defer_build=False,
allow_empty=False,
- write_version=tf.train.SaverDef.V2,
+ write_version=tf.compat.v1.train.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None):
@@ -159,6 +202,7 @@ class Saver(object):
self.checkpoints = []
def restore(self, sess, save_path):
+ print('Restoring from {}...'.format(save_path))
if save_path.endswith('.ckpt'):
load_snapshot(save_path, session=sess, var_list=self.var_list, reshape=self.reshape)
elif save_path.endswith('.hdf5'):
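Not part of the patch: a rough usage sketch of the hdf5 checkpoint helpers changed
above. The variable and the 'checkpoint/model.hdf5' path are placeholders, and
TF 1.x graph mode with a default session is assumed.

# Round-trip variables through tflex's hdf5 checkpoint format.
import tensorflow as tf
import tflex

with tf.compat.v1.Session() as sess:
    w = tf.compat.v1.get_variable('w', shape=[2, 2])  # toy trainable variable
    sess.run(tf.compat.v1.global_variables_initializer())
    # save_variables writes to '<ckpt>.tmp' first, then renames it into place.
    tflex.save_variables('checkpoint/model.hdf5', session=sess)
    # load_variables restores every variable whose name appears in the file.
    tflex.load_variables('checkpoint/model.hdf5', session=sess)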
diff --git a/tflex_graph_editor/__init__.py b/tflex_graph_editor/__init__.py
new file mode 100644
index 0000000..5434d4f
--- /dev/null
+++ b/tflex_graph_editor/__init__.py
@@ -0,0 +1,41 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorFlow Graph Editor."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+from tflex_graph_editor.edit import *
+from tflex_graph_editor.reroute import *
+from tflex_graph_editor.select import *
+from tflex_graph_editor.subgraph import *
+from tflex_graph_editor.transform import *
+from tflex_graph_editor.util import *
+# pylint: enable=wildcard-import
+
+# some useful aliases
+# pylint: disable=g-bad-import-order
+from tflex_graph_editor import subgraph as _subgraph
+from tflex_graph_editor import util as _util
+# pylint: enable=g-bad-import-order
+ph = _util.make_placeholder_from_dtype_and_shape
+sgv = _subgraph.make_view
+sgv_scope = _subgraph.make_view_from_scope
+
+del absolute_import
+del division
+del print_function
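For orientation (not in the patch): a small sketch of the package-level aliases this
__init__ exposes; the op names and placeholder shape are invented for illustration.

# Build a tiny graph and wrap an op in a SubGraphView via the `sgv` alias.
import tensorflow as tf
import tflex_graph_editor as ge

graph = tf.Graph()
with graph.as_default():
    a = tf.constant(1.0, name='a')
    b = tf.add(a, 1.0, name='b')
    view = ge.sgv(b.op)                     # ge.sgv is subgraph.make_view
    x = ge.ph(tf.float32, shape=[None, 3])  # ge.ph builds a placeholder from dtype/shape

print(view.inputs, view.outputs)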
diff --git a/tflex_graph_editor/edit.py b/tflex_graph_editor/edit.py
new file mode 100644
index 0000000..411a86a
--- /dev/null
+++ b/tflex_graph_editor/edit.py
@@ -0,0 +1,221 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Various function for graph editing."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tflex_graph_editor import reroute
+from tflex_graph_editor import select
+from tflex_graph_editor import subgraph
+from tflex_graph_editor import util
+from tensorflow.python.ops import array_ops as tf_array_ops
+
+__all__ = [
+ "detach_control_inputs",
+ "detach_control_outputs",
+ "detach_inputs",
+ "detach_outputs",
+ "detach",
+ "connect",
+ "bypass",
+]
+
+
+def detach_control_inputs(sgv):
+ """Detach all the external control inputs of the subgraph sgv.
+
+ Args:
+ sgv: the subgraph view to be detached. This argument is converted to a
+ subgraph using the same rules as the function subgraph.make_view.
+ """
+ sgv = subgraph.make_view(sgv)
+ for op in sgv.ops:
+ cops = [cop for cop in op.control_inputs if cop not in sgv.ops]
+ reroute.remove_control_inputs(op, cops)
+
+
+def detach_control_outputs(sgv, control_outputs):
+ """Detach all the external control outputs of the subgraph sgv.
+
+ Args:
+ sgv: the subgraph view to be detached. This argument is converted to a
+ subgraph using the same rules as the function subgraph.make_view.
+ control_outputs: a util.ControlOutputs instance.
+ """
+ if not isinstance(control_outputs, util.ControlOutputs):
+    raise TypeError("Expected a util.ControlOutputs, got: {}".format(
+        type(control_outputs)))
+ control_outputs.update()
+ sgv = subgraph.make_view(sgv)
+ for op in sgv.ops:
+ for cop in control_outputs.get(op):
+ if cop not in sgv.ops:
+ reroute.remove_control_inputs(cop, op)
+
+
+def detach_inputs(sgv, control_inputs=False):
+ """Detach the inputs of a subgraph view.
+
+ Args:
+ sgv: the subgraph view to be detached. This argument is converted to a
+ subgraph using the same rules as the function subgraph.make_view.
+ Note that sgv is modified in place.
+ control_inputs: if True control_inputs are also detached.
+ Returns:
+ A tuple `(sgv, input_placeholders)` where
+ `sgv` is a new subgraph view of the detached subgraph;
+ `input_placeholders` is a list of the created input placeholders.
+ Raises:
+ StandardError: if sgv cannot be converted to a SubGraphView using
+      the same rules as the function subgraph.make_view.
+ """
+ sgv = subgraph.make_view(sgv)
+
+ with sgv.graph.as_default():
+ input_placeholders = [
+ tf_array_ops.placeholder(
+ dtype=input_t.dtype, name=util.placeholder_name(input_t))
+ for input_t in sgv.inputs
+ ]
+
+ reroute.swap_inputs(sgv, input_placeholders)
+ if control_inputs:
+ detach_control_inputs(sgv)
+ return sgv, input_placeholders
+
+
+def detach_outputs(sgv, control_outputs=None):
+ """Detach the output of a subgraph view.
+
+ Args:
+ sgv: the subgraph view to be detached. This argument is converted to a
+ subgraph using the same rules as the function subgraph.make_view.
+ Note that sgv is modified in place.
+ control_outputs: a util.ControlOutputs instance or None. If not None the
+ control outputs are also detached.
+ Returns:
+ A tuple `(sgv, output_placeholders)` where
+ `sgv` is a new subgraph view of the detached subgraph;
+ `output_placeholders` is a list of the created output placeholders.
+ Raises:
+ StandardError: if sgv cannot be converted to a SubGraphView using
+      the same rules as the function subgraph.make_view.
+ """
+ sgv = subgraph.make_view(sgv)
+ # only select outputs with consumers
+ sgv_ = sgv.remap_outputs([output_id
+ for output_id, output_t in enumerate(sgv.outputs)
+ if output_t.consumers()])
+ # create consumer subgraph and remap
+ consumers_sgv = subgraph.SubGraphView(sgv_.consumers())
+ consumers_sgv = consumers_sgv.remap_inputs(
+ [input_id for input_id, input_t in enumerate(consumers_sgv.inputs)
+ if input_t in sgv_.outputs])
+
+ with sgv_.graph.as_default():
+ output_placeholders = [
+ util.make_placeholder_from_tensor(input_t)
+ for input_t in consumers_sgv.inputs
+ ]
+
+ reroute.swap_outputs(sgv_, output_placeholders)
+ if control_outputs is not None:
+ detach_control_outputs(sgv_, control_outputs)
+ return sgv_, output_placeholders
+
+
+def detach(sgv, control_inputs=False, control_outputs=None, control_ios=None):
+ """Detach both the inputs and the outputs of a subgraph view.
+
+ Args:
+ sgv: the subgraph view to be detached. This argument is converted to a
+ subgraph using the same rules as the function subgraph.make_view.
+ Note that sgv is modified in place.
+ control_inputs: A boolean indicating whether control inputs are enabled.
+ control_outputs: An instance of util.ControlOutputs or None. If not None,
+ control outputs are enabled.
+ control_ios: An instance of util.ControlOutputs or None. If not None, both
+ control inputs and control outputs are enabled. This is equivalent to set
+      control inputs and control outputs are enabled. This is equivalent to setting
+ instance.
+ Returns:
+ A tuple `(sgv, detached_inputs, detached_outputs)` where:
+ `sgv` is a new subgraph view of the detached subgraph;
+    `detached_inputs` is a list of the created input placeholders;
+    `detached_outputs` is a list of the created output placeholders.
+ Raises:
+ StandardError: if sgv cannot be converted to a SubGraphView using
+      the same rules as the function subgraph.make_view.
+ """
+ control_inputs, control_outputs = select.check_cios(control_inputs,
+ control_outputs,
+ control_ios)
+ _, detached_inputs = detach_inputs(sgv, control_inputs)
+ _, detached_outputs = detach_outputs(sgv, control_outputs)
+ return sgv, detached_inputs, detached_outputs
+
+
+def connect(sgv0, sgv1, disconnect_first=False):
+ """Connect the outputs of sgv0 to the inputs of sgv1.
+
+ Args:
+    sgv0: the first subgraph to have its outputs connected. This argument is
+ converted to a subgraph using the same rules as the function
+ subgraph.make_view.
+ Note that sgv0 is modified in place.
+    sgv1: the second subgraph to have its inputs connected. This argument is
+ converted to a subgraph using the same rules as the function
+ subgraph.make_view.
+ Note that sgv1 is modified in place.
+ disconnect_first: if True the current outputs of sgv0 are disconnected.
+ Returns:
+ A tuple `(sgv0, sgv1)` of the now connected subgraphs.
+ Raises:
+ StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
+      the same rules as the function subgraph.make_view.
+ """
+ sgv0 = subgraph.make_view(sgv0)
+ sgv1 = subgraph.make_view(sgv1)
+ util.check_graphs(sgv0, sgv1)
+ if disconnect_first:
+ detach_outputs(sgv0)
+ sgv0_outputs = subgraph.SubGraphView(passthrough_ts=sgv0.outputs)
+ reroute.reroute_inputs(sgv0_outputs, sgv1)
+ return sgv0, sgv1
+
+
+def bypass(sgv):
+ """Bypass the given subgraph by connecting its inputs to its outputs.
+
+ Args:
+ sgv: the subgraph view to be bypassed. This argument is converted to a
+      subgraph using the same rules as the function subgraph.make_view.
+ Note that sgv is modified in place.
+ Returns:
+ A tuple `(sgv, detached_inputs)` where:
+ `sgv` is a new subgraph view of the bypassed subgraph;
+ `detached_inputs` is a list of the created input placeholders.
+ Raises:
+ StandardError: if sgv cannot be converted to a SubGraphView using
+      the same rules as the function subgraph.make_view.
+ """
+ # TODO(fkp): allows to plug sgv.inputs to individual sgv.outputs consumers
+ sgv = subgraph.make_view(sgv)
+ sgv_inputs = list(sgv.inputs)
+ sgv, detached_inputs = detach_inputs(sgv)
+ reroute.reroute_ts(sgv_inputs, sgv.outputs)
+ return sgv, detached_inputs
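A minimal sketch of edit.detach_inputs on a tiny graph (not in the patch; the op
names are invented for illustration).

# Detach the external input of a two-op subgraph; 'x' is swapped for a placeholder.
import tensorflow as tf
from tflex_graph_editor import edit, subgraph

graph = tf.Graph()
with graph.as_default():
    x = tf.constant([1.0, 2.0], name='x')
    y = tf.square(x, name='y')
    z = tf.sqrt(y, name='z')

view = subgraph.make_view([y.op, z.op])        # subgraph that consumes 'x' from outside
view, placeholders = edit.detach_inputs(view)  # 'y' now reads from a new placeholder
print([p.name for p in placeholders])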
diff --git a/tflex_graph_editor/reroute.py b/tflex_graph_editor/reroute.py
new file mode 100644
index 0000000..998991f
--- /dev/null
+++ b/tflex_graph_editor/reroute.py
@@ -0,0 +1,502 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Various function for graph rerouting."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tflex_graph_editor import subgraph as _subgraph
+from tflex_graph_editor import util as _util
+from tensorflow.python.framework import ops as _tf_ops
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ "swap_ts",
+ "reroute_ts",
+ "swap_inputs",
+ "reroute_inputs",
+ "swap_outputs",
+ "reroute_outputs",
+ "swap_ios",
+ "reroute_ios",
+ "remove_control_inputs",
+ "add_control_inputs",
+]
+
+
+def _check_ts_compatibility(ts0, ts1):
+  """Make sure the shapes and dtypes of the two lists of tensors are compatible.
+
+ Args:
+ ts0: an object convertible to a list of `tf.Tensor`.
+ ts1: an object convertible to a list of `tf.Tensor`.
+ Raises:
+ ValueError: if any pair of tensors (same index in ts0 and ts1) have
+ a dtype or a shape which is not compatible.
+ """
+ ts0 = _util.make_list_of_t(ts0)
+ ts1 = _util.make_list_of_t(ts1)
+ if len(ts0) != len(ts1):
+ raise ValueError("ts0 and ts1 have different sizes: {} != {}".format(
+ len(ts0), len(ts1)))
+ for t0, t1 in zip(ts0, ts1):
+ # check dtype
+ dtype0, dtype1 = t0.dtype, t1.dtype
+ if not dtype0.is_compatible_with(dtype1):
+ raise ValueError("Dtypes {} and {} are not compatible.".format(dtype0,
+ dtype1))
+ # check shape
+ shape0, shape1 = t0.get_shape(), t1.get_shape()
+ if not shape0.is_compatible_with(shape1):
+ raise ValueError("Shapes {} and {} are not compatible.".format(shape0,
+ shape1))
+
+
+class _RerouteMode(object):
+ """Enums for reroute's mode.
+
+  swap: the ends of tensors a and b are swapped.
+  a2b: the end of tensor a is also rerouted to the end of tensor b
+    (the end of b is left dangling).
+  b2a: the end of tensor b is also rerouted to the end of tensor a
+    (the end of a is left dangling).
+ """
+ swap, a2b, b2a = range(3)
+
+ @classmethod
+ def check(cls, mode):
+ """Check swap mode.
+
+ Args:
+ mode: an integer representing one of the modes.
+ Returns:
+      A tuple `(a2b, b2a)` of booleans indicating what rerouting needs doing.
+ Raises:
+ ValueError: if mode is outside the enum range.
+ """
+ if mode == cls.swap:
+ return True, True
+ elif mode == cls.b2a:
+ return False, True
+ elif mode == cls.a2b:
+ return True, False
+ else:
+ raise ValueError("Unknown _RerouteMode: {}".format(mode))
+
+
+def _reroute_t(t0, t1, consumers1, can_modify=None, cannot_modify=None):
+ """Reroute the end of the tensors (t0,t1).
+
+ Warning: this function is directly manipulating the internals of the
+ `tf.Graph`.
+
+ Args:
+ t0: a tf.Tensor.
+ t1: a tf.Tensor.
+    consumers1: The consumers of t1 which need to be rerouted.
+    can_modify: iterable of operations which can be modified. Any operation
+      outside can_modify will be left untouched by this function.
+ cannot_modify: iterable of operations which cannot be modified.
+ Any operation within cannot_modify will be left untouched by this
+ function.
+ Returns:
+ The number of individual modifications made by the function.
+ """
+ nb_update_inputs = 0
+ if can_modify is not None:
+ consumers1 &= can_modify
+ if cannot_modify is not None:
+ consumers1 -= cannot_modify
+ consumers1_indices = {}
+ for consumer1 in consumers1:
+ consumers1_indices[consumer1] = [i for i, t in enumerate(consumer1.inputs)
+ if t is t1]
+ for consumer1 in consumers1:
+ for i in consumers1_indices[consumer1]:
+ consumer1._update_input(i, t0) # pylint: disable=protected-access
+ nb_update_inputs += 1
+ return nb_update_inputs
+
+
+def _reroute_ts(ts0, ts1, mode, can_modify=None, cannot_modify=None):
+ """Reroute the end of the tensors in each pair (t0,t1) in ts0 x ts1.
+
+ This function is the back-bone of the Graph-Editor. It is essentially a thin
+ wrapper on top of the tf.Operation._update_input.
+
+  Given a pair of tensors t0, t1 in ts0 x ts1, this function re-routes the ends
+  of t0 and t1 in three possible ways:
+  1) The reroute mode is "a<->b" or "b<->a": the tensors' ends are swapped. After
+  this operation, the previous consumers of t0 are now consumers of t1 and
+  vice-versa.
+  2) The reroute mode is "a->b": the consumer end of t1 is re-routed to the
+  tensor t0 (t1 is left dangling). After this operation, the previous consumers
+  of t0 are still consuming t0, the previous consumers of t1 are now also
+  consuming t0, and the tensor t1 has no consumer.
+ 3) The reroute mode is "b->a": this mode is the symmetric of the "a->b" mode.
+
+ Note that this function is re-routing the end of two tensors, not the start.
+ Re-routing the start of two tensors is not supported by this library. The
+ reason for that is the following: TensorFlow, by design, creates a strong bond
+ between an op and its output tensor. This Graph editor follows this design and
+ treats an operation A and its generating tensors {t_i} as an entity which
+ cannot be broken. In other words, an op cannot be detached from any of its
+ output tensors, ever. But it is possible to detach an op from its input
+ tensors, which is what this function concerns itself with.
+
+ Warning: this function is directly manipulating the internals of the tf.Graph.
+
+ Args:
+ ts0: an object convertible to a list of `tf.Tensor`.
+ ts1: an object convertible to a list of `tf.Tensor`.
+    mode: what to do with those tensors: "a<->b" or "b<->a" for swapping and
+      "a->b" or "b->a" for one-direction re-routing.
+ can_modify: iterable of operations which can be modified. Any operation
+      outside can_modify will be left untouched by this function.
+ cannot_modify: iterable of operations which cannot be modified.
+ Any operation within cannot_modify will be left untouched by this
+ function.
+ Returns:
+ The number of individual modifications made by the function.
+ Raises:
+ TypeError: if `ts0` or `ts1` cannot be converted to a list of `tf.Tensor`.
+ TypeError: if `can_modify` or `cannot_modify` is not `None` and cannot be
+ converted to a list of `tf.Operation`.
+ """
+ a2b, b2a = _RerouteMode.check(mode)
+ ts0 = _util.make_list_of_t(ts0)
+ ts1 = _util.make_list_of_t(ts1)
+ _check_ts_compatibility(ts0, ts1)
+ if cannot_modify is not None:
+ cannot_modify = frozenset(_util.make_list_of_op(cannot_modify))
+ if can_modify is not None:
+ can_modify = frozenset(_util.make_list_of_op(can_modify))
+ nb_update_inputs = 0
+ precomputed_consumers = []
+ # precompute consumers to avoid issue with repeated tensors:
+ for t0, t1 in zip(ts0, ts1):
+ consumers0 = set(t0.consumers())
+ consumers1 = set(t1.consumers())
+ precomputed_consumers.append((consumers0, consumers1))
+ for t0, t1, consumers in zip(ts0, ts1, precomputed_consumers):
+ if t0 is t1:
+ continue # Silently ignore identical tensors.
+ consumers0, consumers1 = consumers
+ if a2b:
+ nb_update_inputs += _reroute_t(t0, t1, consumers1, can_modify,
+ cannot_modify)
+ if b2a:
+ nb_update_inputs += _reroute_t(t1, t0, consumers0, can_modify,
+ cannot_modify)
+ return nb_update_inputs
+
+
+def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None):
+  """For each pair of tensors, swap the ends of (t0, t1).
+
+ B0 B1 B0 B1
+ | | => X
+ A0 A1 A0 A1
+
+ Args:
+ ts0: an object convertible to a list of `tf.Tensor`.
+ ts1: an object convertible to a list of `tf.Tensor`.
+ can_modify: iterable of operations which can be modified. Any operation
+      outside can_modify will be left untouched by this function.
+ cannot_modify: iterable of operations which cannot be modified.
+ Any operation within cannot_modify will be left untouched by this
+ function.
+ Returns:
+ The number of individual modifications made by the function.
+ Raises:
+ TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
+ TypeError: if can_modify or cannot_modify is not None and cannot be
+ converted to a list of tf.Operation.
+ """
+ return _reroute_ts(ts0, ts1, _RerouteMode.swap, can_modify, cannot_modify)
+
+
+def reroute_ts(ts0, ts1, can_modify=None, cannot_modify=None):
+  """For each pair of tensors, replace the end of t1 by the end of t0.
+
+ B0 B1 B0 B1
+ | | => |/
+ A0 A1 A0 A1
+
+  The ends of the tensors in ts1 are left dangling.
+
+ Args:
+ ts0: an object convertible to a list of `tf.Tensor`.
+ ts1: an object convertible to a list of `tf.Tensor`.
+ can_modify: iterable of operations which can be modified. Any operation
+      outside can_modify will be left untouched by this function.
+ cannot_modify: iterable of operations which cannot be modified. Any
+ operation within cannot_modify will be left untouched by this function.
+ Returns:
+ The number of individual modifications made by the function.
+ Raises:
+ TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
+ TypeError: if can_modify or cannot_modify is not None and cannot be
+ converted to a list of tf.Operation.
+ """
+ return _reroute_ts(ts0, ts1, _RerouteMode.a2b, can_modify, cannot_modify)
+
+
+def _reroute_sgv_remap(sgv0, sgv1, mode):
+ """Remap in place the inputs of two subgraph views to mimic the reroute.
+
+  This function is meant to be used by reroute_inputs only.
+
+ Args:
+ sgv0: the first subgraph to have its inputs remapped.
+ sgv1: the second subgraph to have its inputs remapped.
+ mode: reroute mode, see _reroute_ts(...).
+ Raises:
+    TypeError: if sgv0 or sgv1 is not a SubGraphView.
+ ValueError: if sgv0 and sgv1 do not belong to the same graph.
+ """
+ a2b, b2a = _RerouteMode.check(mode)
+ if not isinstance(sgv0, _subgraph.SubGraphView):
+ raise TypeError("Expected a SubGraphView, got {}".format(type(sgv0)))
+ if not isinstance(sgv1, _subgraph.SubGraphView):
+ raise TypeError("Expected a SubGraphView, got {}".format(type(sgv1)))
+ _util.check_graphs(sgv0, sgv1)
+ sgv0_ = sgv0.copy()
+ sgv1_ = sgv1.copy()
+ # pylint: disable=protected-access
+ if a2b and b2a:
+ (sgv0_._input_ts, sgv1_._input_ts) = (sgv1_._input_ts, sgv0_._input_ts)
+ (sgv0_._passthrough_ts, sgv1_._passthrough_ts) = (sgv1_._passthrough_ts,
+ sgv0_._passthrough_ts)
+ elif a2b:
+ sgv1_._input_ts = sgv0_._input_ts[:]
+ sgv1_._passthrough_ts = sgv0_._passthrough_ts[:]
+ elif b2a:
+ sgv0_._input_ts = sgv1_._input_ts[:]
+ sgv0_._passthrough_ts = sgv1_._passthrough_ts[:]
+ # pylint: enable=protected-access
+
+ # Update the passthrough outputs as well.
+ def update_passthrough_outputs(a, b):
+ # pylint: disable=protected-access
+ for i, t in enumerate(b._output_ts):
+ if t in a._passthrough_ts:
+ ii = a._input_ts.index(t)
+ b._output_ts[i] = b._input_ts[ii]
+ # pylint: enable=protected-access
+
+ if a2b:
+ update_passthrough_outputs(sgv0_, sgv1_)
+ if b2a:
+ update_passthrough_outputs(sgv1_, sgv0_)
+
+ # in-place
+ # pylint: disable=protected-access
+ sgv0._assign_from(sgv0_)
+ sgv1._assign_from(sgv1_)
+ # pylint: enable=protected-access
+
+
+def _reroute_sgv_inputs(sgv0, sgv1, mode):
+ """Re-route all the inputs of two subgraphs.
+
+ Args:
+ sgv0: the first subgraph to have its inputs swapped. This argument is
+      converted to a subgraph using the same rules as the function
+      subgraph.make_view.
+    sgv1: the second subgraph to have its inputs swapped. This argument is
+      converted to a subgraph using the same rules as the function
+ subgraph.make_view.
+ mode: reroute mode, see _reroute_ts(...).
+ Returns:
+ A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped.
+ Note that the function argument sgv0 and sgv1 are also modified in place.
+ Raises:
+ StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
+      the same rules as the function subgraph.make_view.
+ """
+ sgv0 = _subgraph.make_view(sgv0)
+ sgv1 = _subgraph.make_view(sgv1)
+ _util.check_graphs(sgv0, sgv1)
+ can_modify = sgv0.ops + sgv1.ops
+ # also allow consumers of passthrough to be modified:
+ can_modify += _util.get_consuming_ops(sgv0.passthroughs)
+ can_modify += _util.get_consuming_ops(sgv1.passthroughs)
+ _reroute_ts(sgv0.inputs, sgv1.inputs, mode, can_modify=can_modify)
+ _reroute_sgv_remap(sgv0, sgv1, mode)
+ return sgv0, sgv1
+
+
+def _reroute_sgv_outputs(sgv0, sgv1, mode):
+  """Re-route all the outputs of two subgraphs.
+
+ Args:
+ sgv0: the first subgraph to have its outputs swapped. This argument is
+      converted to a subgraph using the same rules as the function
+      subgraph.make_view.
+    sgv1: the second subgraph to have its outputs swapped. This argument is
+      converted to a subgraph using the same rules as the function
+ subgraph.make_view.
+ mode: reroute mode, see _reroute_ts(...).
+ Returns:
+ A tuple `(sgv0, sgv1)` of subgraph views with their outputs swapped.
+ Note that the function argument sgv0 and sgv1 are also modified in place.
+ Raises:
+ StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
+      the same rules as the function subgraph.make_view.
+ """
+ sgv0 = _subgraph.make_view(sgv0)
+ sgv1 = _subgraph.make_view(sgv1)
+ _util.check_graphs(sgv0, sgv1)
+ cannot_modify = sgv0.ops + sgv1.ops
+ _reroute_ts(sgv0.outputs, sgv1.outputs, mode, cannot_modify=cannot_modify)
+ return sgv0, sgv1
+
+
+def _reroute_sgv(sgv0, sgv1, mode):
+ """Re-route both the inputs and the outputs of the two subgraph views.
+
+ This involves swapping all the inputs/outputs of the two subgraph views.
+
+ Args:
+ sgv0: the first subgraph to be swapped. This argument is converted to a
+      subgraph using the same rules as the function subgraph.make_view.
+    sgv1: the second subgraph to be swapped. This argument is converted to a
+      subgraph using the same rules as the function subgraph.make_view.
+ mode: reroute mode, see _reroute_ts(...).
+ Returns:
+ A tuple `(sgv0, sgv1)` of subgraph views with their outputs and inputs
+ swapped.
+ Note that the function argument sgv0 and sgv1 are also modified in place.
+ Raises:
+ StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
+      the same rules as the function subgraph.make_view.
+ """
+ _reroute_sgv_outputs(sgv0, sgv1, mode)
+ _reroute_sgv_inputs(sgv0, sgv1, mode)
+ return sgv0, sgv1
+
+
+def swap_inputs(sgv0, sgv1):
+ """Swap all the inputs of sgv0 and sgv1 (see reroute_inputs)."""
+ return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.swap)
+
+
+def reroute_inputs(sgv0, sgv1):
+ """Re-route all the inputs of two subgraphs.
+
+ Args:
+ sgv0: the first subgraph to have its inputs swapped. This argument is
+      converted to a subgraph using the same rules as the function
+      subgraph.make_view.
+    sgv1: the second subgraph to have its inputs swapped. This argument is
+      converted to a subgraph using the same rules as the function
+ subgraph.make_view.
+ Returns:
+ A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped.
+ Note that the function argument sgv0 and sgv1 are also modified in place.
+ Raises:
+ StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
+      the same rules as the function subgraph.make_view.
+ """
+ return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.a2b)
+
+
+def swap_outputs(sgv0, sgv1):
+ """Swap all the outputs of sgv0 and sgv1 (see reroute_outputs)."""
+ return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.swap)
+
+
+def reroute_outputs(sgv0, sgv1):
+  """Re-route all the outputs of two subgraphs.
+
+ Args:
+ sgv0: the first subgraph to have its outputs swapped. This argument is
+      converted to a subgraph using the same rules as the function
+      subgraph.make_view.
+    sgv1: the second subgraph to have its outputs swapped. This argument is
+      converted to a subgraph using the same rules as the function
+ subgraph.make_view.
+ Returns:
+ A tuple `(sgv0, sgv1)` of subgraph views with their outputs swapped.
+ Note that the function argument sgv0 and sgv1 are also modified in place.
+ Raises:
+ StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
+      the same rules as the function subgraph.make_view.
+ """
+ return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.a2b)
+
+
+def swap_ios(sgv0, sgv1):
+  """Swap the inputs and outputs of sgv0 and sgv1 (see _reroute_sgv)."""
+ return _reroute_sgv(sgv0, sgv1, _RerouteMode.swap)
+
+
+def reroute_ios(sgv0, sgv1):
+ """Re-route the inputs and outputs of sgv0 to sgv1 (see _reroute_sgv)."""
+ return _reroute_sgv(sgv0, sgv1, _RerouteMode.a2b)
+
+
+def remove_control_inputs(op, cops):
+  """Remove the control inputs cops from op.
+
+ Warning: this function is directly manipulating the internals of the
+ `tf.Graph`.
+
+ Args:
+ op: a `tf.Operation` from which to remove the control inputs.
+ cops: an object convertible to a list of `tf.Operation`.
+ Raises:
+ TypeError: if op is not a `tf.Operation`.
+ ValueError: if any cop in cops is not a control input of op.
+ """
+ if not isinstance(op, _tf_ops.Operation):
+    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
+ cops = _util.make_list_of_op(cops, allow_graph=False)
+ for cop in cops:
+ if cop not in op.control_inputs:
+      raise ValueError("{} is not a control_input of {}".format(cop.name,
+                                                                 op.name))
+ control_inputs = [cop for cop in op.control_inputs if cop not in cops]
+ # pylint: disable=protected-access
+ op._remove_all_control_inputs()
+ op._add_control_inputs(control_inputs)
+ # pylint: enable=protected-access
+
+
+def add_control_inputs(op, cops):
+ """Add the control inputs cops to op.
+
+ Warning: this function is directly manipulating the internals of the tf.Graph.
+
+ Args:
+ op: a tf.Operation to which the control inputs are added.
+ cops: an object convertible to a list of `tf.Operation`.
+ Raises:
+ TypeError: if op is not a tf.Operation
+ ValueError: if any cop in cops is already a control input of op.
+ """
+ if not isinstance(op, _tf_ops.Operation):
+    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
+ cops = _util.make_list_of_op(cops, allow_graph=False)
+ for cop in cops:
+ if cop in op.control_inputs:
+ raise ValueError("{} is already a control_input of {}".format(cop.name,
+ op.name))
+ op._add_control_inputs(cops) # pylint: disable=protected-access
+
+remove_undocumented(__name__, _allowed_symbols)
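A minimal sketch of one-direction rerouting with reroute_ts (not in the patch;
toy ops invented for illustration).

# Make the consumers of tensor `a` read from tensor `b` instead.
import tensorflow as tf
from tflex_graph_editor import reroute

graph = tf.Graph()
with graph.as_default():
    a = tf.constant(1.0, name='a')
    b = tf.constant(2.0, name='b')
    c = tf.identity(a, name='c')   # c initially consumes a

reroute.reroute_ts([b], [a])       # consumers of a are rerouted onto b; a is left dangling
print(c.op.inputs[0].name)         # now 'b:0'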
diff --git a/tflex_graph_editor/select.py b/tflex_graph_editor/select.py
new file mode 100644
index 0000000..30d83de
--- /dev/null
+++ b/tflex_graph_editor/select.py
@@ -0,0 +1,773 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Various ways of selecting operations and tensors in a graph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from six import iteritems
+from six import string_types
+
+from tflex_graph_editor import util
+from tensorflow.python.ops import op_selector
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.util import deprecation
+
+
+
+__all__ = [
+ "can_be_regex",
+ "make_regex",
+ "filter_ts",
+ "filter_ts_from_regex",
+ "filter_ops",
+ "filter_ops_from_regex",
+ "get_name_scope_ops",
+ "check_cios",
+ "get_ops_ios",
+ "compute_boundary_ts",
+ "get_within_boundary_ops",
+ "get_forward_walk_ops",
+ "get_backward_walk_ops",
+ "get_walks_intersection_ops",
+ "get_walks_union_ops",
+ "select_ops",
+ "select_ts",
+ "select_ops_and_ts",
+]
+
+_RE_TYPE = type(re.compile(""))
+
+
+def can_be_regex(obj):
+ """Return True if obj can be turned into a regular expression."""
+ return isinstance(obj, string_types + (_RE_TYPE,))
+
+
+def make_regex(obj):
+ """Return a compiled regular expression.
+
+ Args:
+ obj: a string or a regular expression.
+ Returns:
+ A compiled regular expression.
+ Raises:
+ ValueError: if obj could not be converted to a regular expression.
+ """
+ if not can_be_regex(obj):
+ raise ValueError("Expected a string or a regex, got: {}".format(type(obj)))
+
+ if isinstance(obj, string_types):
+ return re.compile(obj)
+ else:
+ return obj
+
+
+def _get_input_ts(ops):
+  """Compute the list of unique input tensors of all the ops in ops.
+
+ Args:
+ ops: an object convertible to a list of `tf.Operation`.
+ Returns:
+    The list of unique input tensors of all the ops in ops.
+ Raises:
+ TypeError: if ops cannot be converted to a list of `tf.Operation`.
+ """
+ ops = util.make_list_of_op(ops)
+ ts = []
+ ts_set = set()
+ for op in ops:
+ for t in op.inputs:
+ if t not in ts_set:
+ ts.append(t)
+ ts_set.add(t)
+ return ts
+
+
+def _get_output_ts(ops):
+  """Compute the list of unique output tensors of all the ops in ops.
+
+ Args:
+ ops: an object convertible to a list of tf.Operation.
+ Returns:
+    The list of unique output tensors of all the ops in ops.
+ Raises:
+ TypeError: if ops cannot be converted to a list of tf.Operation.
+ """
+ ops = util.make_list_of_op(ops)
+ ts = []
+ for op in ops:
+ ts += op.outputs
+ return ts
+
+
+def filter_ts(ops, positive_filter):
+ """Get all the tensors which are input or output of an op in ops.
+
+ Args:
+ ops: an object convertible to a list of `tf.Operation`.
+ positive_filter: a function deciding whether to keep a tensor or not.
+ If `True`, all the tensors are returned.
+ Returns:
+ A list of `tf.Tensor`.
+ Raises:
+ TypeError: if ops cannot be converted to a list of `tf.Operation`.
+ """
+ ops = util.make_list_of_op(ops)
+ ts = _get_input_ts(ops)
+ util.concatenate_unique(ts, _get_output_ts(ops))
+ if positive_filter is not True:
+ ts = [t for t in ts if positive_filter(t)]
+ return ts
+
+
+def filter_ts_from_regex(ops, regex):
+ r"""Get all the tensors linked to ops that match the given regex.
+
+ Args:
+ ops: an object convertible to a list of tf.Operation.
+ regex: a regular expression matching the tensors' name.
+ For example, "^foo(/.*)?:\d+$" will match all the tensors in the "foo"
+ scope.
+ Returns:
+ A list of tf.Tensor.
+ Raises:
+ TypeError: if ops cannot be converted to a list of tf.Operation.
+ """
+ ops = util.make_list_of_op(ops)
+ regex_obj = make_regex(regex)
+ return filter_ts(ops, positive_filter=lambda op: regex_obj.search(op.name))
+
+
+def filter_ops(ops, positive_filter):
+ """Get the ops passing the given filter.
+
+ Args:
+ ops: an object convertible to a list of tf.Operation.
+    positive_filter: a function deciding whether to keep an operation or not.
+ If True, all the operations are returned.
+ Returns:
+ A list of selected tf.Operation.
+ Raises:
+ TypeError: if ops cannot be converted to a list of tf.Operation.
+ """
+ ops = util.make_list_of_op(ops)
+ if positive_filter is not True: # pylint: disable=g-explicit-bool-comparison
+ ops = [op for op in ops if positive_filter(op)]
+ return ops
+
+
+def filter_ops_from_regex(ops, regex):
+ """Get all the operations that match the given regex.
+
+ Args:
+ ops: an object convertible to a list of `tf.Operation`.
+ regex: a regular expression matching the operation's name.
+ For example, `"^foo(/.*)?$"` will match all the operations in the "foo"
+ scope.
+ Returns:
+ A list of `tf.Operation`.
+ Raises:
+ TypeError: if ops cannot be converted to a list of `tf.Operation`.
+ """
+ ops = util.make_list_of_op(ops)
+ regex_obj = make_regex(regex)
+ return filter_ops(ops, lambda op: regex_obj.search(op.name))
+
+
+def get_name_scope_ops(ops, scope):
+ """Get all the operations under the given scope path.
+
+ Args:
+ ops: an object convertible to a list of tf.Operation.
+ scope: a scope path.
+ Returns:
+ A list of tf.Operation.
+ Raises:
+ TypeError: if ops cannot be converted to a list of tf.Operation.
+ """
+ if scope and scope[-1] == "/":
+ scope = scope[:-1]
+ return filter_ops_from_regex(ops, "^{}(/.*)?$".format(scope))
+
+
+def check_cios(control_inputs=False, control_outputs=None, control_ios=None):
+  """Do various checks on control_inputs and control_outputs.
+
+ Args:
+ control_inputs: A boolean indicating whether control inputs are enabled.
+ control_outputs: An instance of util.ControlOutputs or None. If not None,
+ control outputs are enabled.
+ control_ios: An instance of util.ControlOutputs or None. If not None, both
+      control inputs and control outputs are enabled. This is equivalent to setting
+ control_inputs to True and control_outputs to the util.ControlOutputs
+ instance.
+ Returns:
+ A tuple `(control_inputs, control_outputs)` where:
+ `control_inputs` is a boolean indicating whether to use control inputs.
+ `control_outputs` is an instance of util.ControlOutputs or None
+ Raises:
+    ValueError: if control_ios is not None and control_outputs is also
+      not None.
+ TypeError: if control_outputs is not None and is not a util.ControlOutputs.
+ """
+ if control_ios is not None:
+ if not isinstance(control_ios, util.ControlOutputs):
+ raise TypeError("Expected a util.ControlOutputs, got: {}".format(
+ type(control_ios)))
+ if control_outputs is not None:
+ raise ValueError("control_outputs should be None when using control_ios.")
+ control_inputs = True
+ control_outputs = control_ios
+ elif control_outputs is not None:
+ if not isinstance(control_outputs, util.ControlOutputs):
+ raise TypeError("Expected a util.ControlOutputs, got: {}".format(
+ type(control_outputs)))
+
+ if control_outputs is not None:
+ control_outputs.update()
+ return control_inputs, control_outputs
+
+
+def get_ops_ios(ops, control_inputs=False, control_outputs=None,
+ control_ios=None):
+ """Return all the `tf.Operation` which are connected to an op in ops.
+
+ Args:
+ ops: an object convertible to a list of `tf.Operation`.
+ control_inputs: A boolean indicating whether control inputs are enabled.
+ control_outputs: An instance of `util.ControlOutputs` or `None`. If not
+ `None`, control outputs are enabled.
+ control_ios: An instance of `util.ControlOutputs` or `None`. If not `None`,
+ both control inputs and control outputs are enabled. This is equivalent to
+      setting `control_inputs` to `True` and `control_outputs` to the
+ `util.ControlOutputs` instance.
+ Returns:
+ All the `tf.Operation` surrounding the given ops.
+ Raises:
+ TypeError: if `ops` cannot be converted to a list of `tf.Operation`.
+ """
+ control_inputs, control_outputs = check_cios(control_inputs, control_outputs,
+ control_ios)
+ ops = util.make_list_of_op(ops)
+ res = []
+ for op in ops:
+ util.concatenate_unique(res, [t.op for t in op.inputs])
+ for t in op.outputs:
+ util.concatenate_unique(res, t.consumers())
+ if control_outputs is not None:
+ util.concatenate_unique(res, control_outputs.get(op))
+ if control_inputs:
+ util.concatenate_unique(res, op.control_inputs)
+ return res
+
+
+def compute_boundary_ts(ops):
+ """Compute the tensors at the boundary of a set of ops.
+
+ This function looks at all the tensors connected to the given ops (in/out)
+  and classifies them into three categories:
+ 1) input tensors: tensors whose generating operation is not in ops.
+ 2) output tensors: tensors whose consumer operations are not in ops
+ 3) inside tensors: tensors which are neither input nor output tensors.
+
+ Note that a tensor can be both an inside tensor and an output tensor if it is
+ consumed by operations both outside and inside of `ops`.
+
+ Args:
+ ops: an object convertible to a list of tf.Operation.
+ Returns:
+ A tuple `(outside_input_ts, outside_output_ts, inside_ts)` where:
+ `outside_input_ts` is a Python list of input tensors;
+ `outside_output_ts` is a python list of output tensors;
+ `inside_ts` is a python list of inside tensors.
+ Since a tensor can be both an inside tensor and an output tensor,
+ `outside_output_ts` and `inside_ts` might intersect.
+ Raises:
+ TypeError: if ops cannot be converted to a list of tf.Operation.
+ """
+ ops = util.make_list_of_op(ops)
+ input_ts = _get_input_ts(ops)
+ output_ts = _get_output_ts(ops)
+ output_ts_set = frozenset(output_ts)
+ ops_set = frozenset(ops)
+
+ # Compute inside tensors.
+ inside_ts = []
+ only_inside_ts = []
+ for t in input_ts:
+ # Skip if the input tensor is not also an output tensor.
+ if t not in output_ts_set:
+ continue
+ # Mark as "inside".
+ inside_ts.append(t)
+ # Mark as "only inside" if the tensor is not both inside and output.
+ consumers = frozenset(t.consumers())
+ if consumers - ops_set:
+ continue
+ only_inside_ts.append(t)
+
+ inside_ts_set = frozenset(inside_ts)
+ only_inside_ts_set = frozenset(only_inside_ts)
+ outside_output_ts = [t for t in output_ts if t not in only_inside_ts_set]
+ outside_input_ts = [t for t in input_ts if t not in inside_ts_set]
+ return outside_input_ts, outside_output_ts, inside_ts
+
+
+def get_within_boundary_ops(ops,
+ seed_ops,
+ boundary_ops=(),
+ inclusive=True,
+ control_inputs=False,
+ control_outputs=None,
+ control_ios=None):
+ """Return all the `tf.Operation` within the given boundary.
+
+ Args:
+    ops: an object convertible to a list of `tf.Operation`. Those ops define the
+ set in which to perform the operation (if a `tf.Graph` is given, it
+ will be converted to the list of all its operations).
+ seed_ops: the operations from which to start expanding.
+ boundary_ops: the ops forming the boundary.
+ inclusive: if `True`, the result will also include the boundary ops.
+ control_inputs: A boolean indicating whether control inputs are enabled.
+ control_outputs: An instance of `util.ControlOutputs` or `None`. If not
+ `None`, control outputs are enabled.
+ control_ios: An instance of `util.ControlOutputs` or `None`. If not
+ `None`, both control inputs and control outputs are enabled. This is
+      equivalent to setting control_inputs to True and control_outputs to
+ the `util.ControlOutputs` instance.
+ Returns:
+ All the `tf.Operation` surrounding the given ops.
+ Raises:
+ TypeError: if `ops` or `seed_ops` cannot be converted to a list of
+ `tf.Operation`.
+ ValueError: if the boundary is intersecting with the seeds.
+ """
+ control_inputs, control_outputs = check_cios(control_inputs, control_outputs,
+ control_ios)
+ ops = util.make_list_of_op(ops)
+ seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
+ boundary_ops = set(util.make_list_of_op(boundary_ops))
+ res = set(seed_ops)
+ if boundary_ops & res:
+ raise ValueError("Boundary is intersecting with the seeds.")
+ wave = set(seed_ops)
+ while wave:
+ new_wave = set()
+ ops_io = get_ops_ios(wave, control_inputs, control_outputs)
+ for op in ops_io:
+ if op in res:
+ continue
+ if op in boundary_ops:
+ if inclusive:
+ res.add(op)
+ else:
+ new_wave.add(op)
+ res.update(new_wave)
+ wave = new_wave
+ return [op for op in ops if op in res]
+
+
+def get_forward_walk_ops(seed_ops,
+ inclusive=True,
+ within_ops=None,
+ within_ops_fn=None,
+ stop_at_ts=(),
+ control_outputs=None):
+ """Do a forward graph walk and return all the visited ops.
+
+ Args:
+ seed_ops: an iterable of operations from which the forward graph
+ walk starts. If a list of tensors is given instead, the seed_ops are set
+ to be the consumers of those tensors.
+ inclusive: if True the given seed_ops are also part of the resulting set.
+ within_ops: an iterable of `tf.Operation` within which the search is
+ restricted. If `within_ops` is `None`, the search is performed within
+ the whole graph.
+ within_ops_fn: if provided, a function on ops that should return True iff
+      the op is within the graph traversal. This can be used along with within_ops,
+ in which case an op is within if it is also in within_ops.
+ stop_at_ts: an iterable of tensors at which the graph walk stops.
+ control_outputs: a `util.ControlOutputs` instance or None.
+ If not `None`, it will be used while walking the graph forward.
+ Returns:
+ A Python set of all the `tf.Operation` ahead of `seed_ops`.
+ Raises:
+ TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
+ `tf.Operation`.
+ """
+ _, control_outputs = check_cios(False, control_outputs)
+ if not util.is_iterable(seed_ops):
+ seed_ops = [seed_ops]
+ if not seed_ops:
+ return []
+ if isinstance(seed_ops[0], tf_ops.Tensor):
+ ts = util.make_list_of_t(seed_ops, allow_graph=False)
+ seed_ops = util.get_consuming_ops(ts)
+ else:
+ seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
+
+ seed_ops = frozenset(seed_ops)
+ stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
+ if within_ops:
+ within_ops = util.make_list_of_op(within_ops, allow_graph=False)
+ within_ops = frozenset(within_ops)
+ seed_ops &= within_ops
+
+ def is_within(op):
+ return (within_ops is None or op in within_ops) and (
+ within_ops_fn is None or within_ops_fn(op))
+
+ result = list(seed_ops)
+ wave = set(seed_ops)
+ while wave:
+ new_wave = set()
+ for op in wave:
+ for new_t in op.outputs:
+ if new_t in stop_at_ts:
+ continue
+ for new_op in new_t.consumers():
+ if new_op not in result and is_within(new_op):
+ new_wave.add(new_op)
+ if control_outputs is not None:
+ for new_op in control_outputs.get(op):
+ if new_op not in result and is_within(new_op):
+ new_wave.add(new_op)
+ util.concatenate_unique(result, new_wave)
+ wave = new_wave
+ if not inclusive:
+ result = [op for op in result if op not in seed_ops]
+ return result
+
+
+@deprecation.deprecated(
+ "2019-06-06",
+ "Please use tensorflow.python.ops.op_selector.get_backward_walk_ops.",
+ warn_once=True)
+def get_backward_walk_ops(seed_ops,
+ inclusive=True,
+ within_ops=None,
+ within_ops_fn=None,
+ stop_at_ts=(),
+ control_inputs=False):
+ """Do a backward graph walk and return all the visited ops.
+
+ Args:
+ seed_ops: an iterable of operations from which the backward graph
+ walk starts. If a list of tensors is given instead, the seed_ops are set
+ to be the generators of those tensors.
+ inclusive: if True the given seed_ops are also part of the resulting set.
+ within_ops: an iterable of `tf.Operation` within which the search is
+ restricted. If `within_ops` is `None`, the search is performed within
+ the whole graph.
+ within_ops_fn: if provided, a function on ops that should return True iff
+      the op is within the graph traversal. This can be used along with within_ops,
+ in which case an op is within if it is also in within_ops.
+ stop_at_ts: an iterable of tensors at which the graph walk stops.
+ control_inputs: if True, control inputs will be used while moving backward.
+ Returns:
+ A Python set of all the `tf.Operation` behind `seed_ops`.
+ Raises:
+ TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
+ `tf.Operation`.
+ """
+ return op_selector.get_backward_walk_ops(
+ seed_ops,
+ inclusive=inclusive,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ stop_at_ts=stop_at_ts,
+ control_inputs=control_inputs)
+
+
+def get_walks_intersection_ops(forward_seed_ops,
+ backward_seed_ops,
+ forward_inclusive=True,
+ backward_inclusive=True,
+ within_ops=None,
+ within_ops_fn=None,
+ control_inputs=False,
+ control_outputs=None,
+ control_ios=None):
+ """Return the intersection of a forward and a backward walk.
+
+ Args:
+ forward_seed_ops: an iterable of operations from which the forward graph
+ walk starts. If a list of tensors is given instead, the seed_ops are set
+ to be the consumers of those tensors.
+ backward_seed_ops: an iterable of operations from which the backward graph
+ walk starts. If a list of tensors is given instead, the seed_ops are set
+ to be the generators of those tensors.
+ forward_inclusive: if True the given forward_seed_ops are also part of the
+ resulting set.
+ backward_inclusive: if True the given backward_seed_ops are also part of the
+ resulting set.
+ within_ops: an iterable of tf.Operation within which the search is
+ restricted. If within_ops is None, the search is performed within
+ the whole graph.
+ within_ops_fn: if provided, a function on ops that should return True iff
+      the op is within the graph traversal. This can be used along with within_ops,
+ in which case an op is within if it is also in within_ops.
+ control_inputs: A boolean indicating whether control inputs are enabled.
+ control_outputs: An instance of util.ControlOutputs or None. If not None,
+ control outputs are enabled.
+ control_ios: An instance of util.ControlOutputs or None. If not None, both
+      control inputs and control outputs are enabled. This is equivalent to setting
+ control_inputs to True and control_outputs to the util.ControlOutputs
+ instance.
+ Returns:
+ A Python set of all the tf.Operation in the intersection of a forward and a
+ backward walk.
+ Raises:
+ TypeError: if `forward_seed_ops` or `backward_seed_ops` or `within_ops`
+ cannot be converted to a list of `tf.Operation`.
+ """
+ control_inputs, control_outputs = check_cios(control_inputs, control_outputs,
+ control_ios)
+ forward_ops = get_forward_walk_ops(
+ forward_seed_ops,
+ inclusive=forward_inclusive,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ control_outputs=control_outputs)
+ backward_ops = get_backward_walk_ops(
+ backward_seed_ops,
+ inclusive=backward_inclusive,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ control_inputs=control_inputs)
+ return [op for op in forward_ops if op in backward_ops]
+
+
+def get_walks_union_ops(forward_seed_ops,
+ backward_seed_ops,
+ forward_inclusive=True,
+ backward_inclusive=True,
+ within_ops=None,
+ within_ops_fn=None,
+ control_inputs=False,
+ control_outputs=None,
+ control_ios=None):
+ """Return the union of a forward and a backward walk.
+
+ Args:
+ forward_seed_ops: an iterable of operations from which the forward graph
+ walk starts. If a list of tensors is given instead, the seed_ops are set
+ to be the consumers of those tensors.
+ backward_seed_ops: an iterable of operations from which the backward graph
+ walk starts. If a list of tensors is given instead, the seed_ops are set
+ to be the generators of those tensors.
+ forward_inclusive: if True the given forward_seed_ops are also part of the
+ resulting set.
+ backward_inclusive: if True the given backward_seed_ops are also part of the
+ resulting set.
+ within_ops: restrict the search within those operations. If within_ops is
+ None, the search is done within the whole graph.
+ within_ops_fn: if provided, a function on ops that should return True iff
+      the op is within the graph traversal. This can be used along with within_ops,
+ in which case an op is within if it is also in within_ops.
+ control_inputs: A boolean indicating whether control inputs are enabled.
+ control_outputs: An instance of util.ControlOutputs or None. If not None,
+ control outputs are enabled.
+ control_ios: An instance of util.ControlOutputs or None. If not None, both
+      control inputs and control outputs are enabled. This is equivalent to setting
+ control_inputs to True and control_outputs to the util.ControlOutputs
+ instance.
+ Returns:
+ A Python set of all the tf.Operation in the union of a forward and a
+ backward walk.
+ Raises:
+ TypeError: if forward_seed_ops or backward_seed_ops or within_ops cannot be
+ converted to a list of tf.Operation.
+ """
+ control_inputs, control_outputs = check_cios(control_inputs, control_outputs,
+ control_ios)
+ forward_ops = get_forward_walk_ops(
+ forward_seed_ops,
+ inclusive=forward_inclusive,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ control_outputs=control_outputs)
+ backward_ops = get_backward_walk_ops(
+ backward_seed_ops,
+ inclusive=backward_inclusive,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ control_inputs=control_inputs)
+ return util.concatenate_unique(forward_ops, backward_ops)
+
+
+def select_ops(*args, **kwargs):
+ """Helper to select operations.
+
+ Args:
+ *args: list of 1) regular expressions (compiled or not) or 2) (array of)
+ `tf.Operation`. `tf.Tensor` instances are silently ignored.
+    **kwargs: 'graph': `tf.Graph` in which to perform the regex query. This is
+      required when using regex.
+      'positive_filter': an elem is selected only if `positive_filter(elem)` is
+ `True`. This is optional.
+ 'restrict_ops_regex': a regular expression is ignored if it doesn't start
+ with the substring "(?#ops)".
+ Returns:
+ A list of `tf.Operation`.
+ Raises:
+ TypeError: if the optional keyword argument graph is not a `tf.Graph`
+ or if an argument in args is not an (array of) `tf.Operation`
+ or an (array of) `tf.Tensor` (silently ignored) or a string
+ or a regular expression.
+ ValueError: if one of the keyword arguments is unexpected or if a regular
+ expression is used without passing a graph as a keyword argument.
+ """
+ # get keywords arguments
+ graph = None
+ positive_filter = None
+ restrict_ops_regex = False
+ for k, v in iteritems(kwargs):
+ if k == "graph":
+ graph = v
+ if graph is not None and not isinstance(graph, tf_ops.Graph):
+ raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
+ elif k == "positive_filter":
+ positive_filter = v
+ elif k == "restrict_ops_regex":
+ restrict_ops_regex = v
+ elif k == "restrict_ts_regex":
+ pass
+ else:
+ raise ValueError("Wrong keywords argument: {}.".format(k))
+
+ ops = []
+
+ for arg in args:
+ if can_be_regex(arg):
+ if graph is None:
+ raise ValueError("Use the keyword argument 'graph' to use regex.")
+ regex = make_regex(arg)
+ if regex.pattern.startswith("(?#ts)"):
+ continue
+ if restrict_ops_regex and not regex.pattern.startswith("(?#ops)"):
+ continue
+ ops_ = filter_ops_from_regex(graph, regex)
+ for op_ in ops_:
+ if op_ not in ops:
+ if positive_filter is None or positive_filter(op_):
+ ops.append(op_)
+ else:
+ ops_aux = util.make_list_of_op(arg, ignore_ts=True)
+ if positive_filter is not None:
+ ops_aux = [op for op in ops_aux if positive_filter(op)]
+ ops_aux = [op for op in ops_aux if op not in ops]
+ ops += ops_aux
+
+ return ops
+
+
+def select_ts(*args, **kwargs):
+ """Helper to select tensors.
+
+ Args:
+ *args: list of 1) regular expressions (compiled or not) or 2) (array of)
+ `tf.Tensor`. `tf.Operation` instances are silently ignored.
+    **kwargs: 'graph': `tf.Graph` in which to perform the regex query. This is
+      required when using regex.
+      'positive_filter': an elem is selected only if `positive_filter(elem)` is
+ `True`. This is optional.
+ 'restrict_ts_regex': a regular expression is ignored if it doesn't start
+ with the substring "(?#ts)".
+ Returns:
+ A list of `tf.Tensor`.
+ Raises:
+ TypeError: if the optional keyword argument graph is not a `tf.Graph`
+ or if an argument in args is not an (array of) `tf.Tensor`
+ or an (array of) `tf.Operation` (silently ignored) or a string
+ or a regular expression.
+ ValueError: if one of the keyword arguments is unexpected or if a regular
+ expression is used without passing a graph as a keyword argument.
+ """
+ # get keywords arguments
+ graph = None
+ positive_filter = None
+ restrict_ts_regex = False
+ for k, v in iteritems(kwargs):
+ if k == "graph":
+ graph = v
+ if graph is not None and not isinstance(graph, tf_ops.Graph):
+ raise TypeError("Expected a tf.Graph, got {}".format(type(graph)))
+ elif k == "positive_filter":
+ positive_filter = v
+ elif k == "restrict_ts_regex":
+ restrict_ts_regex = v
+ elif k == "restrict_ops_regex":
+ pass
+ else:
+ raise ValueError("Wrong keywords argument: {}.".format(k))
+
+ ts = []
+
+ for arg in args:
+ if can_be_regex(arg):
+ if graph is None:
+ raise ValueError("Use the keyword argument 'graph' to use regex.")
+ regex = make_regex(arg)
+ if regex.pattern.startswith("(?#ops)"):
+ continue
+ if restrict_ts_regex and not regex.pattern.startswith("(?#ts)"):
+ continue
+ ts_ = filter_ts_from_regex(graph, regex)
+ for t_ in ts_:
+ if t_ not in ts:
+ if positive_filter is None or positive_filter(t_):
+ ts.append(t_)
+ else:
+ ts_aux = util.make_list_of_t(arg, ignore_ops=True)
+ if positive_filter is not None:
+ ts_aux = [t for t in ts_aux if positive_filter(t)]
+ ts_aux = [t for t in ts_aux if t not in ts]
+ ts += ts_aux
+
+ return ts
+
+
+def select_ops_and_ts(*args, **kwargs):
+ """Helper to select operations and tensors.
+
+ Args:
+ *args: list of 1) regular expressions (compiled or not) or 2) (array of)
+ `tf.Operation` 3) (array of) tf.Tensor. Regular expressions matching
+ tensors must start with the comment `"(?#ts)"`, for instance:
+ `"(?#ts)^foo/.*"`.
+  **kwargs: 'graph': `tf.Graph` in which to perform the regex query. This is
+    required when using regex.
+    'positive_filter': an elem is selected only if `positive_filter(elem)` is
+ `True`. This is optional.
+ Returns:
+ A tuple `(ops, ts)` where:
+ `ops` is a list of `tf.Operation`, and
+ `ts` is a list of `tf.Tensor`
+ Raises:
+ TypeError: if the optional keyword argument graph is not a `tf.Graph`
+ or if an argument in args is not an (array of) `tf.Tensor`
+ or an (array of) `tf.Operation` or a string or a regular expression.
+ ValueError: if one of the keyword arguments is unexpected or if a regular
+ expression is used without passing a graph as a keyword argument.
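+  Example (illustrative sketch; the scope name and graph `g` are hypothetical):
+    # Select the ops and the tensors living under the "encoder/" name scope.
+    ops, ts = select_ops_and_ts("(?#ops)^encoder/.*", "(?#ts)^encoder/.*",
+                                graph=g)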
+ """
+ ops = select_ops(*args, restrict_ops_regex=False, **kwargs)
+ ts = select_ts(*args, restrict_ts_regex=True, **kwargs)
+ return ops, ts
diff --git a/tflex_graph_editor/subgraph.py b/tflex_graph_editor/subgraph.py
new file mode 100644
index 0000000..4bac8c5
--- /dev/null
+++ b/tflex_graph_editor/subgraph.py
@@ -0,0 +1,668 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SubGraphView: a subgraph view on an existing tf.Graph.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+import six
+from six import iteritems
+from six import StringIO
+
+from tflex_graph_editor import select
+from tflex_graph_editor import util
+from tensorflow.python.framework import ops as tf_ops
+
+__all__ = [
+ "SubGraphView",
+ "make_view",
+ "make_view_from_scope",
+]
+
+
+def _finalize_index(index_or_t, ts):
+ """Returns index as is or return index of tensor in `ts`."""
+ if isinstance(index_or_t, six.integer_types):
+ return index_or_t
+ else:
+ return ts.index(index_or_t)
+
+
+def _finalize_indices(list_of_index_or_t, ts):
+ """Returns index in `indices` as is or replace with tensor's index."""
+ return [_finalize_index(index_or_t, ts) for index_or_t in list_of_index_or_t]
+
+
+def _check_within_range(mapping, n, repetition):
+ """Check is the mapping is valid.
+
+ Args:
+    mapping: an iterable of integers.
+    n: defines the input domain as [0, n-1]. Note that the mapping can be
+      under-complete, that is, it can only contain a subset of the integers on
+      [0, n-1].
+    repetition: if True, repetitions are allowed (the mapping need not be
+      injective); otherwise repetitions are not allowed (the mapping must be
+      injective).
+ Raises:
+    ValueError: if the mapping is out of range or if repetition is False and
+ the mapping has some repetition.
+ """
+ for i in mapping:
+ if not 0 <= i < n:
+ raise ValueError("Out of [0, {}[ range: {}".format(n, i))
+ if not repetition and len(set(mapping)) != len(mapping):
+ raise ValueError("Found repetition in mapping: {}".format(mapping))
+
+
+class SubGraphView(object):
+ """A subgraph view on an existing `tf.Graph`.
+
+ An instance of this class is a subgraph view on an existing `tf.Graph`.
+ "subgraph" means that it can represent part of the whole `tf.Graph`.
+ "view" means that it only provides a passive observation and do not to act
+ on the `tf.Graph`. Note that in this documentation, the term "subgraph" is
+ often used as substitute to "subgraph view".
+
+ A subgraph contains:
+
+ * a list of input tensors, accessible via the `inputs` property.
+ * a list of output tensors, accessible via the `outputs` property.
+ * and the operations in between, accessible via the "ops" property.
+
+  A subgraph can be seen as a function F(i0, i1, ...) -> o0, o1, ... It is a
+ function which takes as input some input tensors and returns as output some
+ output tensors. The computation that the function performs is encoded in the
+ operations of the subgraph.
+
+ The tensors (input or output) can be of two kinds:
+
+ - connected: a connected tensor connects to at least one operation contained
+ in the subgraph. One example is a subgraph representing a single operation
+ and its inputs and outputs: all the input and output tensors of the op
+ are "connected".
+ - passthrough: a passthrough tensor does not connect to any operation
+ contained in the subgraph. One example is a subgraph representing a
+    single tensor: this tensor is passthrough. By default a passthrough tensor is
+    present in both the input and output tensors of the subgraph. It can however
+    be remapped to appear only as an input (or only as an output).
+
+  The input and output tensors can be remapped. For instance, some input tensors
+  can be omitted. For example, a subgraph representing an operation with two
+  inputs can be remapped to only take one input. Note that this does not change
+  the underlying `tf.Graph` at all (remember, it is a view). It means that
+  the other input is being ignored, or is being treated as "given".
+  The analogy with functions can be extended like this: F(x,y) is the original
+  function. Remapping the inputs from [x, y] to just [x] means that the subgraph
+  now represents the function F_y(x) (y is "given").
+
+  The output tensors can also be remapped. For instance, some output tensors can
+  be omitted. Other output tensors can be duplicated as well. As mentioned
+  before, this does not change the underlying `tf.Graph` at all.
+  The analogy with functions can be extended like this: F(...)->x,y is the
+  original function. Remapping the outputs from [x, y] to just [y,y] means that
+  the subgraph now represents the function M(F(...)) where M is the function
+  M(a,b)->b,b.
+
+  It is useful to describe a few other kinds of tensors:
+
+ * internal: an internal tensor is a tensor connecting operations contained
+ in the subgraph. One example in the subgraph representing the two
+ operations A and B connected sequentially: -> A -> B ->. The middle arrow
+ is an internal tensor.
+ * actual input: an input tensor of the subgraph, regardless of whether it is
+ listed in "inputs" or not (masked-out).
+ * actual output: an output tensor of the subgraph, regardless of whether it is
+ listed in "outputs" or not (masked-out).
+ * hidden input: an actual input which has been masked-out using an
+    input remapping. In other words, a hidden input is a non-internal tensor
+    not listed as an input tensor and one of whose consumers belongs to
+    the subgraph.
+  * hidden output: an actual output which has been masked-out using an output
+    remapping. In other words, a hidden output is a non-internal tensor
+ not listed as an output and one of whose generating operations belongs to
+ the subgraph.
+
+ Here are some useful guarantees about an instance of a SubGraphView:
+
+ * the input (or output) tensors are not internal.
+ * the input (or output) tensors are either "connected" or "passthrough".
+ * the passthrough tensors are not connected to any of the operation of
+ the subgraph.
+
+ Note that there is no guarantee that an operation in a subgraph contributes
+ at all to its inputs or outputs. For instance, remapping both the inputs and
+ outputs to empty lists will produce a subgraph which still contains all the
+ original operations. However, the remove_unused_ops function can be used to
+ make a new subgraph view whose operations are connected to at least one of
+ the input or output tensors.
+
+ An instance of this class is meant to be a lightweight object which is not
+ modified in-place by the user. Rather, the user can create new modified
+ instances of a given subgraph. In that sense, the class SubGraphView is meant
+ to be used like an immutable python object.
+
+ A common problem when using views is that they can get out-of-sync with the
+  data they observe (in this case, a `tf.Graph`). It is up to the user to
+  ensure that this doesn't happen. To stay on the safe side, it is recommended
+  that the lifetime of subgraph views be kept very short. One way to achieve
+ this is to use subgraphs within a "with make_sgv(...) as sgv:" Python context.
+
+  To alleviate the out-of-sync problem, some functions are granted the right to
+  modify subgraphs in place. This is typically the case of graph manipulation
+ functions which, given some subgraphs as arguments, can modify the underlying
+ `tf.Graph`. Since this modification is likely to render the subgraph view
+ invalid, those functions can modify the argument in place to reflect the
+  change. For instance, calling the function swap_inputs(sgv0, sgv1) will modify
+  sgv0 and sgv1 in place to reflect the fact that their inputs have now been
+  swapped.
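+
+  Example (illustrative sketch; the scope name is hypothetical):
+
+    # Build a view of the ops under the "encoder/" name scope, then keep only
+    # its first input. The underlying tf.Graph is left untouched.
+    sgv = make_view_from_scope("encoder", tf.compat.v1.get_default_graph())
+    sgv = sgv.remap_inputs([0])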
+ """
+
+ def __init__(self, inside_ops=(), passthrough_ts=()):
+ """Create a subgraph containing the given ops and the "passthrough" tensors.
+
+ Args:
+ inside_ops: an object convertible to a list of `tf.Operation`. This list
+ defines all the operations in the subgraph.
+ passthrough_ts: an object convertible to a list of `tf.Tensor`. This list
+        defines all the "passthrough" tensors. A passthrough tensor is a tensor
+        which goes directly from the input of the subgraph to its output, without
+        any intermediate operations. All the non-passthrough tensors are
+ silently ignored.
+ Raises:
+ TypeError: if inside_ops cannot be converted to a list of `tf.Operation`
+ or if `passthrough_ts` cannot be converted to a list of `tf.Tensor`.
+ """
+
+ inside_ops = util.make_list_of_op(inside_ops)
+ passthrough_ts = util.make_list_of_t(passthrough_ts)
+ ops_and_ts = inside_ops + passthrough_ts
+ if ops_and_ts:
+ self._graph = util.get_unique_graph(ops_and_ts)
+ self._ops = inside_ops
+
+ # Compute inside and outside tensor
+ inputs, outputs, insides = select.compute_boundary_ts(inside_ops)
+
+ # Compute passthrough tensors, silently ignoring the non-passthrough ones.
+ all_tensors = frozenset(inputs + outputs + list(insides))
+ self._passthrough_ts = [t for t in passthrough_ts if t not in all_tensors]
+
+ # Set inputs and outputs.
+ self._input_ts = inputs + self._passthrough_ts
+ self._output_ts = outputs + self._passthrough_ts
+ else:
+ self._graph = None
+ self._passthrough_ts = []
+ self._input_ts = []
+ self._output_ts = []
+ self._ops = []
+
+ def __copy__(self):
+ """Create a copy of this subgraph.
+
+    Note that this class is a "view", copying it only creates another view and
+ does not copy the underlying part of the `tf.Graph`.
+
+ Returns:
+ A new identical instance of the original subgraph view.
+ """
+ cls = self.__class__
+ result = cls.__new__(cls)
+ for k, v in iteritems(self.__dict__):
+ if k == "_graph":
+ setattr(result, k, v)
+ else:
+ setattr(result, k, list(v)) # copy the list
+ return result
+
+ def _assign_from(self, other):
+ """Assign other to itself.
+
+ Args:
+ other: another subgraph-view.
+ Returns:
+ A new instance identical to the original one.
+ Raises:
+ TypeError: if other is not an SubGraphView.
+ """
+ if not isinstance(other, SubGraphView):
+ raise TypeError("Expected SubGraphView, got: {}".format(type(other)))
+ # pylint: disable=protected-access
+ self._graph = other._graph
+ self._ops = list(other._ops)
+ self._passthrough_ts = list(other._passthrough_ts)
+ self._input_ts = list(other._input_ts)
+ self._output_ts = list(other._output_ts)
+ # pylint: enable=protected-access
+
+ def copy(self):
+ """Return a copy of itself.
+
+    Note that this class is a "view", copying it only creates another view and
+ does not copy the underlying part of the tf.Graph.
+
+ Returns:
+ A new instance identical to the original one.
+ """
+ return copy.copy(self)
+
+ def _remap_default(self, remove_input_map=True, remove_output_map=True):
+ """Remap in the place the inputs and/or outputs to the default mapping.
+
+ Args:
+ remove_input_map: if True the input map is reset to the default one.
+ remove_output_map: if True the output map is reset to the default one.
+ """
+ if not remove_input_map and not remove_output_map:
+ return
+
+ # Compute inside and outside tensor
+ inputs, outputs, _ = select.compute_boundary_ts(self._ops)
+ if remove_input_map:
+ self._input_ts = list(inputs) + self._passthrough_ts
+ if remove_output_map:
+ self._output_ts = list(outputs) + self._passthrough_ts
+
+ def remap_default(self, remove_input_map=True, remove_output_map=True):
+ """Remap the inputs and/or outputs to the default mapping.
+
+ Args:
+ remove_input_map: if True the input map is reset to the default one.
+ remove_output_map: if True the output map is reset to the default one.
+ Returns:
+ A new modified instance of the original subgraph view with its
+ input and/or output mapping reset to the default one.
+ """
+ res = self.copy()
+ res._remap_default(remove_input_map, remove_output_map) # pylint: disable=protected-access
+ return res
+
+ def _remap_inputs(self, new_input_indices):
+ """Remap the inputs of the subgraph in-place."""
+ new_input_indices = _finalize_indices(new_input_indices, self._input_ts)
+ _check_within_range(
+ new_input_indices, len(self._input_ts), repetition=False)
+ self._input_ts = [self._input_ts[i] for i in new_input_indices]
+
+ def _remap_outputs(self, new_output_indices):
+ """Remap the outputs of the subgraph in-place."""
+ new_output_indices = _finalize_indices(new_output_indices, self._output_ts)
+ _check_within_range(
+ new_output_indices, len(self._output_ts), repetition=True)
+ self._output_ts = [self._output_ts[i] for i in new_output_indices]
+
+ def _remap_outputs_make_unique(self):
+ """Remap the outputs in place so that all the tensors appears only once."""
+ output_ts = list(self._output_ts)
+ self._output_ts = []
+ util.concatenate_unique(self._output_ts, output_ts)
+
+ def _remap_outputs_to_consumers(self):
+ """Remap the outputs in place to match the number of consumers."""
+ self._remap_outputs_make_unique()
+ output_ts = list(self._output_ts)
+ self._output_ts = []
+ for t in output_ts:
+ self._output_ts += [t] * len(t.consumers())
+
+ def remap_outputs_make_unique(self):
+ """Remap the outputs so that all the tensors appears only once."""
+ res = copy.copy(self)
+ res._remap_outputs_make_unique() # pylint: disable=protected-access
+ return res
+
+ def remap_outputs_to_consumers(self):
+ """Remap the outputs to match the number of consumers."""
+ res = copy.copy(self)
+ res._remap_outputs_to_consumers() # pylint: disable=protected-access
+ return res
+
+ def _remove_unused_ops(self, control_inputs=True):
+ """Remove unused ops in place.
+
+ Args:
+ control_inputs: if True, control inputs are used to detect used ops.
+ Returns:
+ A new subgraph view which only contains used operations.
+ """
+ ops = select.get_walks_union_ops(
+ self.connected_inputs,
+ self.connected_outputs,
+ within_ops=self._ops,
+ control_inputs=control_inputs)
+ self._ops = [op for op in self._ops if op in ops]
+
+ def remove_unused_ops(self, control_inputs=True):
+ """Remove unused ops.
+
+ Args:
+ control_inputs: if True, control inputs are used to detect used ops.
+ Returns:
+ A new subgraph view which only contains used operations.
+ """
+ res = copy.copy(self)
+ res._remove_unused_ops(control_inputs) # pylint: disable=protected-access
+ return res
+
+ def remap_inputs(self, new_input_indices):
+ """Remap the inputs of the subgraph.
+
+ If the inputs of the original subgraph are [t0, t1, t2], remapping to [2,0]
+    will create a new instance whose inputs are [t2, t0].
+
+ Note that this is only modifying the view: the underlying `tf.Graph` is not
+ affected.
+
+ Args:
+ new_input_indices: an iterable of integers or tf.Tensors
+ representing a mapping between the old inputs and the new ones.
+ Integers must be positive and smaller than the number of old inputs.
+ tf.Tensors must belong to the old list of inputs.
+ This mapping can be under-complete and must be without repetitions.
+ Returns:
+ A new modified instance of the original subgraph view with remapped
+ inputs.
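+
+    Example (illustrative sketch):
+      sgv2 = sgv.remap_inputs([2, 0])  # inputs [t0, t1, t2] become [t2, t0]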
+ """
+ res = self.copy()
+ res._remap_inputs(new_input_indices) # pylint: disable=protected-access
+ return res
+
+ def remap_outputs(self, new_output_indices):
+ """Remap the output of the subgraph.
+
+    If the outputs of the original subgraph are [t0, t1, t2], remapping to
+    [1,1,0] will create a new instance whose outputs are [t1, t1, t0].
+
+ Note that this is only modifying the view: the underlying tf.Graph is not
+ affected.
+
+ Args:
+ new_output_indices: an iterable of integers or tf.Tensors
+ representing a mapping between the old outputs and the new ones.
+ Integers must be positive and smaller than the number of old outputs.
+ tf.Tensors must belong to the old list of outputs.
+ This mapping can be under-complete and can have repetitions.
+ Returns:
+ A new modified instance of the original subgraph view with remapped
+ outputs.
+ """
+ res = copy.copy(self)
+ res._remap_outputs(new_output_indices) # pylint: disable=protected-access
+ return res
+
+ def remap(self, new_input_indices=None, new_output_indices=None):
+ """Remap the inputs and outputs of the subgraph.
+
+ Note that this is only modifying the view: the underlying tf.Graph is not
+ affected.
+
+ Args:
+ new_input_indices: an iterable of integers or tf.Tensors
+ representing a mapping between the old inputs and the new ones.
+ Integers must be positive and smaller than the number of old inputs.
+ tf.Tensors must belong to the old list of inputs.
+ This mapping can be under-complete and must be without repetitions.
+ new_output_indices: an iterable of integers or tf.Tensors
+ representing a mapping between the old outputs and the new ones.
+ Integers must be positive and smaller than the number of old outputs.
+ tf.Tensors must belong to the old list of outputs.
+ This mapping can be under-complete and can have repetitions.
+ Returns:
+ A new modified instance of the original subgraph view with remapped
+ inputs and outputs.
+ """
+ res = copy.copy(self)
+ if new_input_indices is not None:
+ res._remap_inputs(new_input_indices) # pylint: disable=protected-access
+ if new_output_indices is not None:
+ res._remap_outputs(new_output_indices) # pylint: disable=protected-access
+ return res
+
+ def find_op_by_name(self, op_name):
+ """Return the op named op_name.
+
+ Args:
+ op_name: the name to search for
+ Returns:
+ The op named op_name.
+ Raises:
+ ValueError: if the op_name could not be found.
+      AssertionError: if the name was found multiple times.
+ """
+ res = [op for op in self._ops if op.name == op_name]
+ if not res:
+ raise ValueError("{} not in subgraph.".format(op_name))
+ if len(res) > 1:
+ raise AssertionError("More than 1 op named: {}!".format(op_name))
+ return res[0]
+
+ def __str__(self):
+ if not self:
+ return "SubGraphView: empty"
+
+ def op_name(op):
+ return op.name
+
+ def tensor_name(t):
+ if t in self._passthrough_ts:
+ return "{} *".format(t.name)
+ else:
+ return t.name
+
+ def print_list(name, iterable, get_name):
+ if iterable:
+ print("** {}[{}]:".format(name, len(iterable)), file=res)
+ print("\n".join([" {}".format(get_name(elem)) for elem in iterable]),
+ file=res)
+ else:
+ print("** {}: empty".format(name), file=res)
+
+ res = StringIO()
+ print("SubGraphView (graphid={}):".format(id(self.graph)), file=res)
+ print_list("ops", self._ops, op_name)
+ print_list("inputs", self._input_ts, tensor_name)
+ print_list("outputs", self._output_ts, tensor_name)
+ return res.getvalue()
+
+ @property
+ def graph(self):
+ """The underlying `tf.Graph`."""
+ return self._graph
+
+ @property
+ def ops(self):
+ """The operations in this subgraph view."""
+ return self._ops
+
+ @property
+ def inputs(self):
+ """The input tensors of this subgraph view."""
+ return util.ListView(self._input_ts)
+
+ @property
+ def connected_inputs(self):
+ """The connected input tensors of this subgraph view."""
+ return [t for t in self._input_ts if t not in self._passthrough_ts]
+
+ @property
+ def outputs(self):
+ """The output tensors of this subgraph view."""
+ return util.ListView(self._output_ts)
+
+ @property
+ def connected_outputs(self):
+ """The connected output tensors of this subgraph view."""
+ return [t for t in self._output_ts if t not in self._passthrough_ts]
+
+ @property
+ def passthroughs(self):
+ """The passthrough tensors, going straight from input to output."""
+ return util.ListView(self._passthrough_ts)
+
+ def __bool__(self):
+ """Allows for implicit boolean conversion."""
+ return self._graph is not None
+
+ # Python 3 wants __bool__, Python 2.7 wants __nonzero__
+ __nonzero__ = __bool__
+
+ def op(self, op_id):
+ """Get an op by its index."""
+ return self._ops[op_id]
+
+ def is_passthrough(self, t):
+ """Check whether a tensor is passthrough."""
+ return t in self._passthrough_ts
+
+ def __enter__(self):
+ """Allow Python context to minimize the life time of a subgraph view.
+
+ A subgraph view is meant to be a lightweight and transient object. A short
+ lifetime will alleviate the "out-of-sync" issue mentioned earlier. For that
+ reason, a SubGraphView instance can be used within a Python context. For
+ example:
+
+    import tflex_graph_editor as ge
+ with ge.make_sgv(...) as sgv:
+ print(sgv)
+
+ Returns:
+ Itself.
+ """
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ pass
+
+ def input_index(self, t):
+ """Find the input index corresponding to the given input tensor t.
+
+ Args:
+ t: the input tensor of this subgraph view.
+ Returns:
+ The index in the self.inputs list.
+ Raises:
+      ValueError: if t is not an input tensor.
+ """
+ try:
+ subgraph_id = self._input_ts.index(t)
+    except ValueError:
+      raise ValueError("Can't find {} in the inputs of this subgraph.".format(
+        t.name))
+ return subgraph_id
+
+ def output_index(self, t):
+ """Find the output index corresponding to given output tensor t.
+
+ Args:
+ t: the output tensor of this subgraph view.
+ Returns:
+ The index in the self.outputs list.
+ Raises:
+      ValueError: if t is not an output tensor.
+ """
+ try:
+ subgraph_id = self._output_ts.index(t)
+    except ValueError:
+      raise ValueError("Can't find {} in the outputs of this subgraph.".format(
+        t.name))
+ return subgraph_id
+
+ def consumers(self):
+ """Return a Python set of all the consumers of this subgraph view.
+
+ A consumer of a subgraph view is a tf.Operation which is a consumer
+ of one of the output tensors and is not in the subgraph.
+
+ Returns:
+ A list of `tf.Operation` which are the consumers of this subgraph view.
+ """
+ ops_set = frozenset(self._ops)
+ res = []
+ for output in self._output_ts:
+ consumers = [op for op in output.consumers() if op not in ops_set]
+ util.concatenate_unique(res, consumers)
+ return res
+
+
+def _check_graph(sgv, graph):
+ """Check if sgv belongs to the given graph.
+
+ Args:
+ sgv: a SubGraphView.
+ graph: a graph or None.
+ Returns:
+ The SubGraphView sgv.
+ Raises:
+ TypeError: if sgv is not a SubGraphView or if graph is not None and not
+ a tf.Graph.
+ ValueError: if the graph of sgv and the given graph are not None and
+ different.
+ """
+ if not isinstance(sgv, SubGraphView):
+ raise TypeError("Expected a SubGraphView, got: {}".format(type(graph)))
+ if graph is None or not sgv.graph:
+ return sgv
+ if not isinstance(graph, tf_ops.Graph):
+ raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
+ if sgv.graph is not graph:
+ raise ValueError("Graph mismatch.")
+ return sgv
+
+
+def make_view(*args, **kwargs):
+ """Create a SubGraphView from selected operations and passthrough tensors.
+
+ Args:
+ *args: list of 1) regular expressions (compiled or not) or 2) (array of)
+ `tf.Operation` 3) (array of) `tf.Tensor`. Those objects will be converted
+    into a list of operations and a list of candidates for passthrough tensors.
+ **kwargs: keyword graph is used 1) to check that the ops and ts are from
+ the correct graph 2) for regular expression query
+ Returns:
+ A subgraph view.
+ Raises:
+ TypeError: if the optional keyword argument graph is not a `tf.Graph`
+ or if an argument in args is not an (array of) `tf.Tensor`
+ or an (array of) `tf.Operation` or a string or a regular expression.
+ ValueError: if one of the keyword arguments is unexpected.
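+  Example (illustrative sketch; the scope name and graph `g` are hypothetical):
+    sgv = make_view("^decoder/.*", graph=g)  # view of the ops under "decoder/"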
+ """
+ # get keywords arguments
+ graph = kwargs["graph"] if "graph" in kwargs else None
+
+ # already a view?
+ if len(args) == 1 and isinstance(args[0], SubGraphView):
+ return _check_graph(args[0], graph)
+
+ ops, ts = select.select_ops_and_ts(*args, **kwargs)
+ sgv = SubGraphView(ops, ts)
+ return _check_graph(sgv, graph)
+
+
+def make_view_from_scope(scope, graph):
+ """Make a subgraph from a name scope.
+
+ Args:
+ scope: the name of the scope.
+ graph: the `tf.Graph`.
+ Returns:
+ A subgraph view representing the given scope.
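+  Example (illustrative sketch; the scope name is hypothetical):
+    sgv = make_view_from_scope("model/decoder", tf.compat.v1.get_default_graph())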
+ """
+ ops = select.get_name_scope_ops(graph, scope)
+ return SubGraphView(ops)
diff --git a/tflex_graph_editor/transform.py b/tflex_graph_editor/transform.py
new file mode 100644
index 0000000..e33223a
--- /dev/null
+++ b/tflex_graph_editor/transform.py
@@ -0,0 +1,752 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class to transform an subgraph into another.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from copy import deepcopy
+from functools import partial
+from six import iteritems
+from six import string_types
+from six import StringIO
+from tflex_graph_editor import reroute
+from tflex_graph_editor import select
+from tflex_graph_editor import subgraph
+from tflex_graph_editor import util
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.platform import tf_logging as logging
+
+
+__all__ = [
+ "replace_t_with_placeholder_handler",
+ "keep_t_if_possible_handler",
+ "assign_renamed_collections_handler",
+ "transform_op_if_inside_handler",
+ "copy_op_handler",
+ "Transformer",
+ "TransformerInfo",
+ "copy",
+ "copy_with_input_replacements",
+ "graph_replace",
+]
+
+
+def replace_t_with_placeholder_handler(info, t):
+ """Transform a tensor into a placeholder tensor.
+
+ This handler is typically used to transform a subgraph input tensor into a
+ placeholder.
+
+ Args:
+ info: Transform._TmpInfo instance.
+    t: tensor whose input must be transformed into a placeholder.
+  Returns:
+    The tensor generated by the newly created placeholder.
+ """
+ with info.graph_.as_default():
+ t_ = util.make_placeholder_from_tensor(t, scope=info.scope_)
+ return t_
+
+
+def keep_t_if_possible_handler(info, t):
+ """Transform a tensor into itself (identity) if possible.
+
+  This handler transforms a tensor into itself if the source and destination
+  graphs are the same. Otherwise it will create a placeholder.
+  This handler is typically used to transform hidden input tensors.
+
+ Args:
+ info: Transform._TmpInfo instance.
+    t: tensor whose input must be transformed into a placeholder.
+  Returns:
+    The tensor generated by the newly created placeholder.
+ """
+ if info.graph is info.graph_:
+ return t
+ else:
+ return replace_t_with_placeholder_handler(info, t)
+
+
+def assign_renamed_collections_handler(info, elem, elem_):
+ """Add the transformed elem to the (renamed) collections of elem.
+
+  A collection is renamed only if it is not a known key, as described in
+ `tf.compat.v1.GraphKeys`.
+
+ Args:
+ info: Transform._TmpInfo instance.
+ elem: the original element (`tf.Tensor` or `tf.Operation`)
+ elem_: the transformed element
+ """
+ known_collection_names = util.get_predefined_collection_names()
+ for name, collection in iteritems(info.collections):
+ if elem not in collection:
+ continue
+
+ if name in known_collection_names:
+ transformed_name = name
+ else:
+ transformed_name = info.new_name(name)
+ info.graph_.add_to_collection(transformed_name, elem_)
+
+
+def transform_op_if_inside_handler(info, op, keep_if_possible=True):
+ """Transform an optional op only if it is inside the subgraph.
+
+  This handler is typically used to handle original ops: it is fine to keep them
+ if they are inside the subgraph, otherwise they are just ignored.
+
+ Args:
+ info: Transform._TmpInfo instance.
+ op: the optional op to transform (or ignore).
+ keep_if_possible: re-attach to the original op if possible, that is,
+ if the source graph and the destination graph are the same.
+ Returns:
+ The transformed op or None.
+ """
+ if op in info.sgv.ops:
+ return info.transformed_ops[op]
+ else:
+ if keep_if_possible and info.graph is info.graph_:
+ return op
+ else:
+ return None
+
+
+def copy_op_handler(info, op, new_inputs, copy_shape=False, nodedef_fn=None):
+ """Copy a `tf.Operation`.
+
+ Args:
+ info: Transform._TmpInfo instance.
+ op: the `tf.Operation` to be copied.
+ new_inputs: The new inputs for this op.
+ copy_shape: also copy the shape of the tensor
+ nodedef_fn: If provided, a function that will be run on the NodeDef
+ and should return a mutated NodeDef before a new Operation is created.
+ This is useful as certain features cannot be set on the Operation and
+ must be modified in NodeDef.
+
+ Returns:
+ A `(op, op_outputs)` tuple containing the transformed op and its outputs.
+ """
+  # The `new_inputs` argument was added to this function. For compatibility
+  # reasons, let's raise an error if `new_inputs` is a boolean.
+ if isinstance(new_inputs, bool):
+ raise TypeError("the `new_inputs` argument must be an iterable.")
+
+ # pylint: disable=protected-access
+
+ # Clone the node def:
+ node_def_ = deepcopy(op.node_def)
+
+ # Transform name:
+ name_ = info.new_name(op.name)
+ name_ = info.graph_.unique_name(name_)
+ node_def_.name = name_
+
+ # Mutate NodeDef if requested:
+ if nodedef_fn is not None:
+ node_def_ = nodedef_fn(node_def_)
+
+ # Copy the other inputs needed for initialization
+ output_types_ = op._output_types[:]
+ input_types_ = op._input_types[:]
+
+ # Make a copy of the op_def too.
+  # It's unique to every _type_ of Operation.
+ op_def_ = deepcopy(op.op_def)
+
+ # Initialize a new Operation instance
+ op_ = tf_ops.Operation(node_def_, info.graph_, new_inputs, output_types_,
+ [], input_types_, None, op_def_)
+
+ # copy the shape over
+ if copy_shape:
+ for t, t_ in zip(op.outputs, op_.outputs):
+ t_.set_shape(t.get_shape())
+
+ # Original op cannot be finalised here yet. Because some ops require this
+ # attribute to exist, we will create a dummy original_op first and then
+ # later finalise it with the actual original_op when all the ops have
+ # been copied.
+ # TODO(fkp): Stop worrying about _original_op and remove this code?
+ if op._original_op:
+ op_._original_op = op._original_op
+
+ return op_, op_.outputs
+
+
+class TransformerInfo(object):
+ """"Contains information about the result of a transform operation."""
+
+ def __init__(self, info):
+ """Constructor.
+
+ Args:
+ info: an instance of Transformer._TmpInfo containing various internal
+ information about the transform operation.
+ """
+ self._graph = info.graph
+ self._scope = info.scope
+ self._graph_ = info.graph_
+ self._scope_ = info.scope_
+ self._transformed_ops = info.transformed_ops
+ self._transformed_ts = info.transformed_ts
+
+ def _get_transformed_map(self, top):
+ """Return the correct container depending on the type of `top`."""
+ if isinstance(top, tf_ops.Operation):
+ return self._transformed_ops
+ elif isinstance(top, tf_ops.Tensor):
+ return self._transformed_ts
+ else:
+ raise TypeError(
+ "Expected a tf.Tensor or a tf.Operation, got a {}".format(
+ type(top)))
+
+ def _transformed_elem(self, original_top, missing_fn=None):
+ """Return the transformed op/tensor corresponding to the original one.
+
+ Args:
+ original_top: the original tensor/operation.
+ missing_fn: function handling the case where the counterpart
+ cannot be found. By default, None is returned.
+ Returns:
+ the transformed tensor/operation (or None if no match is found).
+ """
+ transformed_map = self._get_transformed_map(original_top)
+ if isinstance(original_top, string_types):
+ for original, transformed in iteritems(transformed_map):
+ if original.name == original_top:
+ return transformed
+ return None if missing_fn is None else missing_fn(original_top)
+ else:
+ if original_top not in transformed_map:
+ return None if missing_fn is None else missing_fn(original_top)
+ return transformed_map[original_top]
+
+ def _original_elem(self, transformed_top, missing_fn=None):
+ """Return the original op/tensor corresponding to the transformed one.
+
+ Args:
+ transformed_top: the transformed tensor/operation.
+ missing_fn: function handling the case where the counterpart
+ cannot be found. By default, None is returned.
+ Returns:
+ the original tensor/operation (or None if no match is found).
+ """
+ transformed_map = self._get_transformed_map(transformed_top)
+ if isinstance(transformed_top, string_types):
+ finder = lambda transformed: transformed.name == transformed_top
+ else:
+ finder = lambda transformed: transformed == transformed_top
+ for original, transformed in iteritems(transformed_map):
+ if finder(transformed):
+ return original
+ return None if missing_fn is None else missing_fn(transformed_top)
+
+ def transformed(self, original, missing_fn=None):
+ """Return the transformed op/tensor corresponding to the original one.
+
+ Note that the output of this function mimics the hierarchy
+ of its input argument `original`.
+ Given an iterable, it returns a list. Given an operation or a tensor,
+ it will return an operation or a tensor.
+
+ Args:
+ original: the original tensor/operation.
+ missing_fn: function handling the case where the counterpart
+ cannot be found. By default, None is returned.
+ Returns:
+ the transformed tensor/operation (or None if no match is found).
+ """
+ transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn)
+ return util.transform_tree(original, transformed_elem)
+
+ def original(self, transformed, missing_fn=None):
+ """Return the original op/tensor corresponding to the transformed one.
+
+ Note that the output of this function mimics the hierarchy
+ of its input argument `transformed`.
+ Given an iterable, it returns a list. Given an operation or a tensor,
+ it will return an operation or a tensor.
+
+ Args:
+ transformed: the transformed tensor/operation.
+ missing_fn: function handling the case where the counterpart
+ cannot be found. By default, None is returned.
+ Returns:
+ the original tensor/operation (or None if no match is found).
+ """
+ original_elem = partial(self._original_elem, missing_fn=missing_fn)
+ return util.transform_tree(transformed, original_elem)
+
+ def __str__(self):
+ res = StringIO()
+ print("Transform result info:", file=res)
+ if self._graph == self._graph_:
+ in_place_str = "" if self._scope_ else " IN-PLACE"
+ print(" Within graph[{}]{}".format(
+ id(self._graph), in_place_str), file=res)
+ else:
+ print(" graph[{}] => graph[{}]".format(
+ id(self._graph), id(self._graph_)), file=res)
+ if self._scope:
+ print(" Relative to source scope: {}".format(self._scope), file=res)
+ if self._scope_:
+ print(" Scope destination: {}".format(self._scope_), file=res)
+ print("Operations mapping:", file=res)
+ for op, op_ in iteritems(self._transformed_ops):
+ print(" {} => {}".format(op.name, op_.name), file=res)
+ return res.getvalue()
+
+
+class _TmpInfo(object):
+ """Transformer temporary data.
+
+ An instance of this class holds all the information relevant to a call
+ to a transformer instance (that is, a call to __call__). An instance
+ is created for the life-time of the __call__ function and is passed as
+ argument to the handlers.
+ """
+
+ def __init__(self, sgv, dst_graph, dst_scope, src_scope):
+ self.sgv = sgv
+ self.sgv_inputs_set = frozenset(sgv.inputs)
+ self.ops = frozenset(sgv.ops)
+ self.control_outputs = util.ControlOutputs(sgv.graph)
+ self.graph = sgv.graph
+ self.scope = src_scope
+ self.graph_ = dst_graph
+ self.scope_ = dst_scope
+ self.transformed_ops = {}
+ self.transformed_ts = {}
+ self.collections = dict((key, self.graph.get_collection(key))
+ for key in self.graph.get_all_collection_keys())
+ self.cyclic_ops = []
+ self.transform_original_op_handler = transform_op_if_inside_handler
+ # The graph is transformed op by op, in the same order the original ops
+ # were created. However, this is sometimes not possible due to cycles
+ # (i.e. while loops). So when the transformer creates a new op whose
+ # inputs do not exist yet, temporary placeholders are created and stored
+ # in this `tmp_cyclic_ts` container. During a second pass,
+ # those temporary tensors are replaced by the proper transformed tensors
+ # (see the function `_finalize_cycles`).
+ self.tmp_cyclic_ts = []
+
+ def new_name(self, name):
+ """Compute a destination name from a source name.
+
+ Args:
+ name: the name to be "transformed".
+ Returns:
+ The transformed name.
+ Raises:
+ ValueError: if the source scope is used (that is, not an empty string)
+ and the source name does not belong to the source scope.
+ """
+ scope = self.scope
+ if not name.startswith(scope):
+ raise ValueError("{} does not belong to source scope: {}.".format(
+ name, scope))
+ rel_name = name[len(scope):]
+ name_ = self.scope_ + rel_name
+ return name_
+
+
+class Transformer(object):
+ """Transform a subgraph into another one.
+
+  By default, the constructor creates a transform which copies a subgraph and
+ replaces inputs with placeholders. This behavior can be modified by changing
+ the handlers.
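+
+  Example (illustrative sketch; assumes `sgv` is an existing SubGraphView):
+
+    transformer = Transformer()
+    sgv_, info = transformer(sgv, sgv.graph, dst_scope="copied")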
+ """
+
+ def __init__(self):
+ """Transformer constructor.
+
+ The following members can be modified:
+ transform_op_handler: handle the transformation of a `tf.Operation`.
+ This handler defaults to a simple copy.
+ assign_collections_handler: handle the assignment of collections.
+ This handler defaults to assigning new collections created under the
+ given name-scope.
+ transform_external_input_handler: handle the transform of the inputs to
+ the given subgraph. This handler defaults to creating placeholders
+ instead of the ops just before the input tensors of the subgraph.
+ transform_external_hidden_input_handler: handle the transform of the
+ hidden inputs of the subgraph, that is, the inputs which are not listed
+        in sgv.inputs. This handler defaults to a transform which keeps the same
+        input if the source and destination graphs are the same, and otherwise
+        uses placeholders.
+ transform_original_op_handler: handle the transform of original_op. This
+ handler defaults to transforming original_op only if they are in the
+ subgraph, otherwise they are ignored.
+ """
+
+ # handlers
+ self.transform_op_handler = copy_op_handler
+ self.transform_control_input_handler = transform_op_if_inside_handler
+ self.assign_collections_handler = assign_renamed_collections_handler
+ self.transform_external_input_handler = replace_t_with_placeholder_handler
+ self.transform_external_hidden_input_handler = keep_t_if_possible_handler
+ self.transform_original_op_handler = transform_op_if_inside_handler
+
+ def __call__(self,
+ sgv,
+ dst_graph,
+ dst_scope,
+ src_scope="",
+ reuse_dst_scope=False):
+ """Execute the transformation.
+
+ Args:
+ sgv: the source subgraph-view.
+ dst_graph: the destination graph.
+ dst_scope: the destination scope.
+      src_scope: the source scope, which specifies the path from which the
+        relative path of the transformed nodes is computed. For instance, if
+        src_scope is a/ and dst_scope is b/, then the node a/x/y will have a
+ relative path of x/y and will be transformed into b/x/y.
+ reuse_dst_scope: if True the dst_scope is re-used if it already exists.
+ Otherwise, the scope is given a unique name based on the one given
+ by appending an underscore followed by a digit (default).
+ Returns:
+ A tuple `(sgv, info)` where:
+ `sgv` is the transformed subgraph view;
+ `info` is an instance of TransformerInfo containing
+ information about the transform, including mapping between
+ original and transformed tensors and operations.
+ Raises:
+ ValueError: if the arguments are invalid.
+ """
+ sgv = subgraph.make_view(sgv)
+ if not isinstance(dst_graph, tf_ops.Graph):
+ raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))
+
+ src_scope = util.scope_finalize(src_scope)
+ dst_scope = util.scope_finalize(dst_scope)
+
+ # Potentially create new scope if reuse_dst_scope is False
+ if dst_scope and not reuse_dst_scope:
+ dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1]))
+
+ # Create temporary info used during this transform call
+ info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope)
+
+ self._copy_ops(info)
+ self._finalize_cycles(info)
+ self._connect_control_inputs(info)
+
+ # Compute information about the transformation
+ res_info = TransformerInfo(info)
+ sgv_ = self._transform_sgv(info, sgv)
+ return sgv_, res_info
+
+ def _copy_ops(self, info):
+ """Copy ops without connecting them."""
+ sorted_ops = sorted(info.sgv.ops, key=lambda op: op._id) # pylint: disable=protected-access
+ for op in sorted_ops:
+ new_inputs = [self._transformed_t(info, t, op) for t in op.inputs]
+ op_, op_outputs_ = self.transform_op_handler(info, op, new_inputs)
+ if op is op_:
+ raise ValueError("In-place transformation not allowed.")
+
+ # Process op.
+ info.transformed_ops[op] = op_
+ self.assign_collections_handler(info, op, op_)
+
+ # Process output tensors.
+ for op_output, op_output_ in zip(op.outputs, op_outputs_):
+ info.transformed_ts[op_output] = op_output_
+ self.assign_collections_handler(info, op_output, op_output_)
+
+ def _finalize_cycles(self, info):
+ """Reconnects the cyclic tensors."""
+ for t, tmp_t_, consumer_op in info.tmp_cyclic_ts:
+ if t not in info.transformed_ts:
+ raise ValueError("The tensor {} should be transformed by now.".format(
+ t.name))
+ if consumer_op not in info.transformed_ops:
+ raise ValueError("The op {} should be transformed by now.".format(
+ consumer_op.name))
+ t_ = info.transformed_ts[t]
+ consumer_op_ = info.transformed_ops[consumer_op]
+ t_index_ = list(consumer_op_.inputs).index(tmp_t_)
+ consumer_op_._update_input(t_index_, t_) # pylint: disable=protected-access
+
+ def _connect_control_inputs(self, info):
+ """Connect the previously copied ops."""
+ for op in info.sgv.ops:
+ logging.debug("Connecting control inputs of op: %s", op.name)
+ op_ = info.transformed_ops[op]
+
+ # Finalize original op.
+ # TODO(fkp): Stop worrying about _original_op and remove this code?
+ # pylint: disable=protected-access
+ if op._original_op:
+ original_op = self.transform_original_op_handler(info, op._original_op)
+ if original_op is None:
+ logging.debug("Could not find original op for: %s", op_.name)
+ else:
+ op_._original_op = original_op
+ # pylint: enable=protected-access
+
+ # Finalize control inputs:
+ control_inputs_ = [self.transform_control_input_handler(info, ci)
+ for ci in op.control_inputs]
+ control_inputs_ = [ci for ci in control_inputs_ if ci is not None]
+ reroute.add_control_inputs(op_, control_inputs_)
+
+ def _transform_sgv(self, info, sgv):
+ """Transform a subgraph view.
+
+ For convenience, a transform operation returns a subgraph view of the
+ transformed graph.
+
+ Args:
+      info: Temporary information for this transform call.
+ sgv: the subgraph to be transformed.
+ Returns:
+ The transformed subgraph.
+ """
+ ops_ = [op_ for _, op_ in iteritems(info.transformed_ops)]
+ sgv_ = subgraph.SubGraphView(ops_)
+ sgv_inputs_ = sgv_.inputs
+ sgv_outputs_ = sgv_.outputs
+
+ # re-order inputs
+ input_map_ = []
+ for input_t in sgv.inputs:
+ if input_t not in info.transformed_ts:
+ continue
+ input_t_ = info.transformed_ts[input_t]
+ if input_t_ not in sgv_inputs_:
+ continue
+ input_t_index_ = sgv_.input_index(input_t_)
+ input_map_.append(input_t_index_)
+
+ # re-order outputs
+ output_map_ = []
+ for output_t in sgv.outputs:
+ if output_t not in info.transformed_ts:
+ continue
+ output_t_ = info.transformed_ts[output_t]
+ if output_t_ not in sgv_outputs_:
+ continue
+ output_t_index_ = sgv_.output_index(output_t_)
+ output_map_.append(output_t_index_)
+
+ return sgv_.remap(input_map_, output_map_)
+
+ def _transformed_t(self, info, t, consumer_op):
+ """Return tre transformed tensor of `t`."""
+ if t in info.transformed_ts:
+ # If op is in the subgraph, just return its transformed counterpart.
+ return info.transformed_ts[t]
+
+ if t in info.sgv_inputs_set:
+ # `t` is an input of the subgraph.
+ return self.transform_external_input_handler(info, t)
+ elif t.op in info.ops:
+ # `t` is an internal tensor but is not transformed yet because it
+ # belongs to a graph cycle.
+ logging.debug("Cyclic tensor: t.name = %s", t.name)
+ # Try to find an existing tensor we can use for now,
+ # otherwise create one. We'll rewire this later.
+ if consumer_op.type == "Merge":
+ first_input = consumer_op.inputs[0]
+ tmp_t_ = self._transformed_t(info, first_input, consumer_op)
+ elif t.op.type == "Enter":
+ enter_input = t.op.inputs[0]
+ tmp_t_ = self._transformed_t(info, enter_input, consumer_op)
+ else:
+ with info.graph_.as_default():
+ tmp_t_ = util.make_placeholder_from_tensor(t, scope=info.scope_,
+ prefix="geph_tmp")
+ logging.debug("Created temporary placeholder: %s.", tmp_t_.name)
+ # Register as temporary and return.
+ info.tmp_cyclic_ts.append((t, tmp_t_, consumer_op))
+ return tmp_t_
+ else:
+ # `t` is a hidden input of the subgraph.
+ return self.transform_external_hidden_input_handler(info, t)
+
+
+def copy(sgv, dst_graph=None, dst_scope="", src_scope="",
+ reuse_dst_scope=False):
+ """Copy a subgraph.
+
+ Args:
+ sgv: the source subgraph-view. This argument is converted to a subgraph
+      using the same rules as the function subgraph.make_view.
+ dst_graph: the destination graph.
+ dst_scope: the destination scope.
+ src_scope: the source scope.
+ reuse_dst_scope: if True the dst_scope is re-used if it already exists.
+ Otherwise, the scope is given a unique name based on the one given
+ by appending an underscore followed by a digit (default).
+ Returns:
+ A tuple `(sgv, info)` where:
+ `sgv` is the transformed subgraph view;
+ `info` is an instance of TransformerInfo containing
+ information about the transform, including mapping between
+ original and transformed tensors and operations.
+ Raises:
+ TypeError: if `dst_graph` is not a `tf.Graph`.
+ StandardError: if sgv cannot be converted to a SubGraphView using
+      the same rules as the function subgraph.make_view.
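+  Example (illustrative sketch; assumes `sgv` is an existing SubGraphView):
+    copied_sgv, info = copy(sgv, dst_scope="copy_of_sgv")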
+ """
+ sgv = subgraph.make_view(sgv)
+ if dst_graph is None:
+ dst_graph = sgv.graph
+ if not isinstance(dst_graph, tf_ops.Graph):
+ raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))
+
+ copier = Transformer()
+ return copier(
+ sgv, dst_graph, dst_scope, src_scope, reuse_dst_scope=reuse_dst_scope)
+
+
+def copy_with_input_replacements(sgv, replacement_ts,
+ dst_graph=None, dst_scope="", src_scope="",
+ reuse_dst_scope=False):
+ """Copy a subgraph, replacing some of its inputs.
+
+ Note a replacement only happens if the tensor to be replaced
+ is an input of the given subgraph. The inputs of a subgraph can
+ be queried using sgv.inputs.
+
+ Args:
+ sgv: the source subgraph-view. This argument is converted to a subgraph
+ using the same rules as the function subgraph.make_view.
+ replacement_ts: dictionary mapping from original tensors to the
+ replaced one.
+ dst_graph: the destination graph.
+ dst_scope: the destination scope.
+ src_scope: the source scope.
+ reuse_dst_scope: if True the dst_scope is re-used if it already exists.
+ Otherwise, the scope is given a unique name based on the one given
+ by appending an underscore followed by a digit (default).
+ Returns:
+ A tuple `(sgv, info)` where:
+ `sgv` is the transformed subgraph view;
+ `info` is an instance of TransformerInfo containing
+ information about the transform, including mapping between
+ original and transformed tensors and operations.
+ Raises:
+ TypeError: if dst_graph is not a tf.Graph.
+ StandardError: if sgv cannot be converted to a SubGraphView using
+ the same rules as the function subgraph.make_view.
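+  Example (illustrative sketch; `sgv`, `x` and `new_x` are hypothetical, with
+    `x` among the inputs of `sgv`):
+    sgv_, info = copy_with_input_replacements(sgv, {x: new_x})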
+ """
+ sgv = subgraph.make_view(sgv)
+ if dst_graph is None:
+ dst_graph = sgv.graph
+ if not isinstance(dst_graph, tf_ops.Graph):
+ raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))
+
+ copier = Transformer()
+ # Replace tensor if possible.
+ def replace_t_with_replacement_handler(info, t):
+ if t in replacement_ts:
+ return replacement_ts[t]
+ else:
+ return keep_t_if_possible_handler(info, t)
+ copier.transform_external_input_handler = replace_t_with_replacement_handler
+ return copier(
+ sgv, dst_graph, dst_scope, src_scope, reuse_dst_scope=reuse_dst_scope)
+
+
+def _add_control_flow_ops(ops, control_ios):
+ """Complete `ops` so that the transformed graph is valid.
+
+ Partially copying a graph can lead to a malformed graph. For instance,
+ copying half of a while construct is likely to result in an invalid graph.
+  This function attempts to add missing ops so that the transformation results
+  in a valid graph.
+
+ Args:
+    ops: list of ops (modified in-place).
+ control_ios: object created by a call to `util.ControlOutputs`.
+ """
+ # Find while contexts.
+ control_flow_contexts = set()
+ for op in ops:
+ cfc = op._control_flow_context # pylint: disable=protected-access
+ if cfc:
+ control_flow_contexts.add(cfc)
+ # Find new ops.
+ new_ops = []
+ for cfc in control_flow_contexts:
+ if cfc.IsWhileContext():
+ new_ops += select.get_walks_intersection_ops(
+ [enter_t.op for enter_t in cfc.loop_enters],
+ [exit_t.op for exit_t in cfc.loop_exits],
+ control_ios=control_ios)
+ # Add new ops.
+ new_ops_set = set(new_ops)
+ ops_set = frozenset(ops)
+ for op in new_ops_set:
+ if op not in ops_set:
+ ops.append(op)
+
+
+def graph_replace(target_ts, replacement_ts, dst_scope="",
+ src_scope="", reuse_dst_scope=False):
+ """Create a new graph which compute the targets from the replaced Tensors.
+
+ Args:
+ target_ts: a single tf.Tensor or an iterable of tf.Tensor.
+ replacement_ts: dictionary mapping from original tensors to replaced tensors
+ dst_scope: the destination scope.
+ src_scope: the source scope.
+ reuse_dst_scope: if True the dst_scope is re-used if it already exists.
+ Otherwise, the scope is given a unique name based on the one given
+ by appending an underscore followed by a digit (default).
+ Returns:
+ A single tf.Tensor or a list of target tf.Tensor, depending on
+ the type of the input argument `target_ts`.
+ The returned tensors are recomputed using the tensors from replacement_ts.
+ Raises:
+ ValueError: if the targets are not connected to replacement_ts.
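+  Example (illustrative sketch; `loss`, `x` and `new_x` are hypothetical tensors
+    in the default graph):
+    # Build new ops that recompute `loss` from `new_x` instead of `x`.
+    new_loss = graph_replace(loss, {x: new_x})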
+ """
+ # Identify operations in the graph that will change.
+ # Start forward walk at Tensors that will be replaced, and
+ # backward walk at the target output Tensors.
+ flatten_target_ts = util.flatten_tree(target_ts)
+ # Construct the forward control dependencies edges so that
+ # the get_walks_intersection_ops can also traverse the
+ # control dependencies.
+ graph = util.get_unique_graph(flatten_target_ts, check_types=(tf_ops.Tensor))
+ control_ios = util.ControlOutputs(graph)
+ ops = select.get_walks_intersection_ops(
+ list(replacement_ts), flatten_target_ts, control_ios=control_ios)
+ if not ops:
+ raise ValueError("Targets and replacements are not connected!")
+
+ # Complete ops to avoid malformed control flow.
+ # TODO(fkp): Consider moving this function deeper (in the transformer?).
+ _add_control_flow_ops(ops, control_ios)
+
+ # Create a copy of the relevant subgraph
+ unused_sgv_, info = copy_with_input_replacements(
+ ops, replacement_ts, None, dst_scope, src_scope, reuse_dst_scope)
+ # Return the transformed targets but keep the original if the transformed
+ # counterpart cannot be found
+ missing_fn = lambda original_t: original_t
+ return info.transformed(target_ts, missing_fn)
diff --git a/tflex_graph_editor/util.py b/tflex_graph_editor/util.py
new file mode 100644
index 0000000..543c1da
--- /dev/null
+++ b/tflex_graph_editor/util.py
@@ -0,0 +1,566 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for the graph_editor.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+from six import iteritems
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.ops import array_ops as tf_array_ops
+from tensorflow.python.util.compat import collections_abc
+
+__all__ = [
+ "make_list_of_op",
+ "get_tensors",
+ "make_list_of_t",
+ "get_generating_ops",
+ "get_consuming_ops",
+ "ControlOutputs",
+ "placeholder_name",
+ "make_placeholder_from_tensor",
+ "make_placeholder_from_dtype_and_shape",
+]
+
+
+# The graph editor sometimes needs to create placeholders; they are named
+# "geph_*". "geph" stands for Graph-Editor PlaceHolder.
+_DEFAULT_PLACEHOLDER_PREFIX = "geph"
+
+
+def concatenate_unique(la, lb):
+ """Add all the elements of `lb` to `la` if they are not there already.
+
+ The elements added to `la` maintain ordering with respect to `lb`.
+
+ Args:
+ la: List of Python objects.
+ lb: List of Python objects.
+ Returns:
+ `la`: The list `la` with missing elements from `lb`.
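+  Example (illustrative sketch):
+    concatenate_unique([1, 2], [2, 3, 3])  # returns [1, 2, 3]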
+ """
+ la_set = set(la)
+ for l in lb:
+ if l not in la_set:
+ la.append(l)
+ la_set.add(l)
+ return la
+
+
+# TODO(fkp): very generic code, it should be moved in a more generic place.
+class ListView(object):
+ """Immutable list wrapper.
+
+ This class is strongly inspired by the one in tf.Operation.
+ """
+
+ def __init__(self, list_):
+ if not isinstance(list_, list):
+ raise TypeError("Expected a list, got: {}.".format(type(list_)))
+ self._list = list_
+
+ def __iter__(self):
+ return iter(self._list)
+
+ def __len__(self):
+ return len(self._list)
+
+ def __bool__(self):
+ return bool(self._list)
+
+ # Python 3 wants __bool__, Python 2.7 wants __nonzero__
+ __nonzero__ = __bool__
+
+ def __getitem__(self, i):
+ return self._list[i]
+
+ def __add__(self, other):
+ if not isinstance(other, list):
+ other = list(other)
+ return list(self) + other
+
+
+# TODO(fkp): very generic code, it should be moved in a more generic place.
+def is_iterable(obj):
+ """Return true if the object is iterable."""
+ if isinstance(obj, tf_ops.Tensor):
+ return False
+ try:
+ _ = iter(obj)
+ except Exception: # pylint: disable=broad-except
+ return False
+ return True
+
+
+def flatten_tree(tree, leaves=None):
+ """Flatten a tree into a list.
+
+ Args:
+    tree: iterable or not. If iterable, its elements (children) can also be
+ iterable or not.
+ leaves: list to which the tree leaves are appended (None by default).
+ Returns:
+ A list of all the leaves in the tree.
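+  Example (illustrative sketch):
+    flatten_tree([1, (2, [3]), {"a": 4}])  # returns [1, 2, 3, 4]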
+ """
+ if leaves is None:
+ leaves = []
+ if isinstance(tree, dict):
+ for _, child in iteritems(tree):
+ flatten_tree(child, leaves)
+ elif is_iterable(tree):
+ for child in tree:
+ flatten_tree(child, leaves)
+ else:
+ leaves.append(tree)
+ return leaves
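+
+# Illustrative usage (editor's note, not part of the original TF module):
+# dicts are traversed by value and all other iterables by element, so
+#   flatten_tree([1, (2, 3), {"a": 4}])  ->  [1, 2, 3, 4]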
+
+
+def transform_tree(tree, fn, iterable_type=tuple):
+ """Transform all the nodes of a tree.
+
+ Args:
+ tree: iterable or not. If iterable, its elements (children) can themselves
+ be iterable or not.
+ fn: function to apply to each leaf.
+ iterable_type: type used to construct the resulting tree for unknown
+ iterables, typically `list` or `tuple`.
+ Returns:
+ A tree whose leaves have been transformed by `fn`.
+ The hierarchy of the output tree mimics the one of the input tree.
+ """
+ if is_iterable(tree):
+ if isinstance(tree, dict):
+ res = tree.__new__(type(tree))
+ res.__init__(
+ (k, transform_tree(child, fn)) for k, child in iteritems(tree))
+ return res
+ elif isinstance(tree, tuple):
+ # NamedTuple?
+ if hasattr(tree, "_asdict"):
+ res = tree.__new__(type(tree), **transform_tree(tree._asdict(), fn))
+ else:
+ res = tree.__new__(type(tree),
+ (transform_tree(child, fn) for child in tree))
+ return res
+ elif isinstance(tree, collections_abc.Sequence):
+ res = tree.__new__(type(tree))
+ res.__init__(transform_tree(child, fn) for child in tree)
+ return res
+ else:
+ return iterable_type(transform_tree(child, fn) for child in tree)
+ else:
+ return fn(tree)
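+
+# Illustrative usage (editor's note, not part of the original TF module):
+# the container structure is preserved while `fn` is applied to every leaf, e.g.
+#   transform_tree((1, [2, {"a": 3}]), lambda x: x * 10)  ->  (10, [20, {"a": 30}])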
+
+
+def check_graphs(*args):
+ """Check that all the element in args belong to the same graph.
+
+ Args:
+ *args: a list of objects with an obj.graph property.
+ Raises:
+ ValueError: if all the elements do not belong to the same graph.
+ """
+ graph = None
+ for i, sgv in enumerate(args):
+ if graph is None and sgv.graph is not None:
+ graph = sgv.graph
+ elif sgv.graph is not None and sgv.graph is not graph:
+ raise ValueError("Argument[{}]: Wrong graph!".format(i))
+
+
+def get_unique_graph(tops, check_types=None, none_if_empty=False):
+ """Return the unique graph used by the all the elements in tops.
+
+ Args:
+ tops: list of elements to check (usually a list of tf.Operation and/or
+ tf.Tensor). Or a tf.Graph.
+ check_types: check that the elements in tops are of the given type(s). If None,
+ the types (tf.Operation, tf.Tensor) are used.
+ none_if_empty: don't raise an error if tops is an empty list, just return
+ None.
+ Returns:
+ The unique graph used by all the tops.
+ Raises:
+ TypeError: if tops is not an iterable of tf.Operation.
+ ValueError: if the graph is not unique.
+ """
+ if isinstance(tops, tf_ops.Graph):
+ return tops
+ if not is_iterable(tops):
+ raise TypeError("{} is not iterable".format(type(tops)))
+ if check_types is None:
+ check_types = (tf_ops.Operation, tf_ops.Tensor)
+ elif not is_iterable(check_types):
+ check_types = (check_types,)
+ g = None
+ for op in tops:
+ if not isinstance(op, check_types):
+ raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str(
+ t) for t in check_types]), type(op)))
+ if g is None:
+ g = op.graph
+ elif g is not op.graph:
+ raise ValueError("Operation {} does not belong to given graph".format(op))
+ if g is None and not none_if_empty:
+ raise ValueError("Can't find the unique graph of an empty list")
+ return g
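+
+# Illustrative usage (editor's note, not part of the original TF module):
+#   get_unique_graph([op_a, op_b])  # returns the common tf.Graph, or raises
+#                                   # ValueError if the ops span two graphs.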
+
+
+def make_list_of_op(ops, check_graph=True, allow_graph=True, ignore_ts=False):
+ """Convert ops to a list of `tf.Operation`.
+
+ Args:
+ ops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single
+ operation.
+ check_graph: if `True` check if all the operations belong to the same graph.
+ allow_graph: if `False` a `tf.Graph` cannot be converted.
+ ignore_ts: if True, silently ignore `tf.Tensor`.
+ Returns:
+ A newly created list of `tf.Operation`.
+ Raises:
+ TypeError: if ops cannot be converted to a list of `tf.Operation` or,
+ if `check_graph` is `True`, if all the ops do not belong to the
+ same graph.
+ """
+ if isinstance(ops, tf_ops.Graph):
+ if allow_graph:
+ return ops.get_operations()
+ else:
+ raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
+ else:
+ if not is_iterable(ops):
+ ops = [ops]
+ if not ops:
+ return []
+ if check_graph:
+ check_types = None if ignore_ts else tf_ops.Operation
+ get_unique_graph(ops, check_types=check_types)
+ return [op for op in ops if isinstance(op, tf_ops.Operation)]
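+
+# Illustrative usage (editor's note, not part of the original TF module):
+#   make_list_of_op(graph)  ->  graph.get_operations()
+#   make_list_of_op(op)     ->  [op]
+#   make_list_of_op([op, t], ignore_ts=True)  ->  [op]   (tensors are dropped)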
+
+
+# TODO(fkp): move this function in tf.Graph?
+def get_tensors(graph):
+ """get all the tensors which are input or output of an op in the graph.
+
+ Args:
+ graph: a `tf.Graph`.
+ Returns:
+ A list of `tf.Tensor`.
+ Raises:
+ TypeError: if graph is not a `tf.Graph`.
+ """
+ if not isinstance(graph, tf_ops.Graph):
+ raise TypeError("Expected a graph, got: {}".format(type(graph)))
+ ts = []
+ for op in graph.get_operations():
+ ts += op.outputs
+ return ts
+
+
+def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
+ """Convert ts to a list of `tf.Tensor`.
+
+ Args:
+ ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
+ check_graph: if `True` check if all the tensors belong to the same graph.
+ allow_graph: if `False` a `tf.Graph` cannot be converted.
+ ignore_ops: if `True`, silently ignore `tf.Operation`.
+ Returns:
+ A newly created list of `tf.Tensor`.
+ Raises:
+ TypeError: if `ts` cannot be converted to a list of `tf.Tensor` or,
+ if `check_graph` is `True`, if all the tensors do not belong to the same graph.
+ """
+ if isinstance(ts, tf_ops.Graph):
+ if allow_graph:
+ return get_tensors(ts)
+ else:
+ raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
+ else:
+ if not is_iterable(ts):
+ ts = [ts]
+ if not ts:
+ return []
+ if check_graph:
+ check_types = None if ignore_ops else tf_ops.Tensor
+ get_unique_graph(ts, check_types=check_types)
+ return [t for t in ts if isinstance(t, tf_ops.Tensor)]
+
+
+def get_generating_ops(ts):
+ """Return all the generating ops of the tensors in `ts`.
+
+ Args:
+ ts: a list of `tf.Tensor`
+ Returns:
+ A list of all the generating `tf.Operation` of the tensors in `ts`.
+ Raises:
+ TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
+ """
+ ts = make_list_of_t(ts, allow_graph=False)
+ return [t.op for t in ts]
+
+
+def get_consuming_ops(ts):
+ """Return all the consuming ops of the tensors in ts.
+
+ Args:
+ ts: a list of `tf.Tensor`
+ Returns:
+ A list of all the consuming `tf.Operation` of the tensors in `ts`.
+ Raises:
+ TypeError: if ts cannot be converted to a list of `tf.Tensor`.
+ """
+ ts = make_list_of_t(ts, allow_graph=False)
+ ops = []
+ for t in ts:
+ for op in t.consumers():
+ if op not in ops:
+ ops.append(op)
+ return ops
+
+
+class ControlOutputs(object):
+ """The control outputs topology."""
+
+ def __init__(self, graph):
+ """Create a dictionary of control-output dependencies.
+
+ Args:
+ graph: a `tf.Graph`.
+ Returns:
+ A dictionary where a key is a `tf.Operation` instance and the
+ corresponding value is a list of all the ops which have the key
+ as one of their control-input dependencies.
+ Raises:
+ TypeError: if graph is not a `tf.Graph`.
+ """
+ if not isinstance(graph, tf_ops.Graph):
+ raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
+ self._control_outputs = {}
+ self._graph = graph
+ self._version = None
+ self._build()
+
+ def update(self):
+ """Update the control outputs if the graph has changed."""
+ if self._version != self._graph.version:
+ self._build()
+ return self
+
+ def _build(self):
+ """Build the control outputs dictionary."""
+ self._control_outputs.clear()
+ ops = self._graph.get_operations()
+ for op in ops:
+ for control_input in op.control_inputs:
+ if control_input not in self._control_outputs:
+ self._control_outputs[control_input] = []
+ if op not in self._control_outputs[control_input]:
+ self._control_outputs[control_input].append(op)
+ self._version = self._graph.version
+
+ def get_all(self):
+ return self._control_outputs
+
+ def get(self, op):
+ """return the control outputs of op."""
+ if op in self._control_outputs:
+ return self._control_outputs[op]
+ else:
+ return ()
+
+ @property
+ def graph(self):
+ return self._graph
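+
+# Illustrative usage (editor's note, not part of the original TF module),
+# assuming `g` is a tf.Graph built with tf.control_dependencies:
+#   co = ControlOutputs(g)
+#   co.get(some_op)  # ops that list `some_op` among their control inputs
+#   co.update()      # rebuilds the mapping if the graph has changed since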
+
+
+def scope_finalize(scope):
+ if scope and scope[-1] != "/":
+ scope += "/"
+ return scope
+
+
+def scope_dirname(scope):
+ slash = scope.rfind("/")
+ if slash == -1:
+ return ""
+ return scope[:slash + 1]
+
+
+def scope_basename(scope):
+ slash = scope.rfind("/")
+ if slash == -1:
+ return scope
+ return scope[slash + 1:]
+
+
+def placeholder_name(t=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX):
+ """Create placeholder name for the graph editor.
+
+ Args:
+ t: optional tensor on which the placeholder operation's name will be
+ based.
+ scope: absolute scope with which to prefix the placeholder's name. None
+ means that the scope of t is preserved. "" means the root scope.
+ prefix: placeholder name prefix.
+ Returns:
+ A new placeholder name prefixed by "geph". Note that "geph" stands for
+ Graph Editor PlaceHolder. This convention makes it easy to identify the
+ placeholders generated by the Graph Editor.
+ Raises:
+ TypeError: if t is neither None nor a tf.Tensor.
+ """
+ if scope is not None:
+ scope = scope_finalize(scope)
+ if t is not None:
+ if not isinstance(t, tf_ops.Tensor):
+ raise TypeError("Expected a tf.Tenfor, got: {}".format(type(t)))
+ op_dirname = scope_dirname(t.op.name)
+ op_basename = scope_basename(t.op.name)
+ if scope is None:
+ scope = op_dirname
+
+ if op_basename.startswith("{}__".format(prefix)):
+ ph_name = op_basename
+ else:
+ ph_name = "{}__{}_{}".format(prefix, op_basename, t.value_index)
+
+ return scope + ph_name
+ else:
+ if scope is None:
+ scope = ""
+ return "{}{}".format(scope, prefix)
+
+
+def make_placeholder_from_tensor(t, scope=None,
+ prefix=_DEFAULT_PLACEHOLDER_PREFIX):
+ """Create a `tf.compat.v1.placeholder` for the Graph Editor.
+
+ Note that the correct graph scope must be set by the calling function.
+
+ Args:
+ t: a `tf.Tensor` whose name will be used to create the placeholder (see
+ function placeholder_name).
+ scope: absolute scope within which to create the placeholder. None means
+ that the scope of `t` is preserved. `""` means the root scope.
+ prefix: placeholder name prefix.
+
+ Returns:
+ A newly created `tf.compat.v1.placeholder`.
+ Raises:
+ TypeError: if `t` is neither `None` nor a `tf.Tensor`.
+ """
+ return tf_array_ops.placeholder(
+ dtype=t.dtype, shape=t.get_shape(),
+ name=placeholder_name(t, scope=scope, prefix=prefix))
+
+
+def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None,
+ prefix=_DEFAULT_PLACEHOLDER_PREFIX):
+ """Create a tf.compat.v1.placeholder for the Graph Editor.
+
+ Note that the correct graph scope must be set by the calling function.
+ The placeholder is named using the function placeholder_name (with no
+ tensor argument).
+
+ Args:
+ dtype: the tensor type.
+ shape: the tensor shape (optional).
+ scope: absolute scope within which to create the placeholder. None means
+ that the scope of t is preserved. "" means the root scope.
+ prefix: placeholder name prefix.
+
+ Returns:
+ A newly created tf.compat.v1.placeholder.
+ """
+ return tf_array_ops.placeholder(
+ dtype=dtype, shape=shape,
+ name=placeholder_name(scope=scope, prefix=prefix))
+
+
+_INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$")
+
+
+def get_predefined_collection_names():
+ """Return all the predefined collection names."""
+ return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys)
+ if not _INTERNAL_VARIABLE_RE.match(key)]
+
+
+def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""):
+ """Find corresponding op/tensor in a different graph.
+
+ Args:
+ target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph.
+ dst_graph: The graph in which the corresponding graph element must be found.
+ dst_scope: A scope which is prepended to the name to look for.
+ src_scope: A scope which is removed from the original name of `target`.
+
+ Returns:
+ The corresponding `tf.Tensor` or `tf.Operation`.
+
+ Raises:
+ ValueError: if `src_name` does not start with `src_scope`.
+ TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation`
+ KeyError: If the corresponding graph element cannot be found.
+ """
+ src_name = target.name
+ if src_scope:
+ src_scope = scope_finalize(src_scope)
+ if not src_name.startswith(src_scope):
+ raise ValueError("{} does not start with {}".format(src_name, src_scope))
+ src_name = src_name[len(src_scope):]
+
+ dst_name = src_name
+ if dst_scope:
+ dst_scope = scope_finalize(dst_scope)
+ dst_name = dst_scope + dst_name
+
+ if isinstance(target, tf_ops.Tensor):
+ return dst_graph.get_tensor_by_name(dst_name)
+ if isinstance(target, tf_ops.Operation):
+ return dst_graph.get_operation_by_name(dst_name)
+ raise TypeError("Expected tf.Tensor or tf.Operation, got: {}", type(target))
+
+
+def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""):
+ """Find corresponding ops/tensors in a different graph.
+
+ `targets` is a Python tree, that is, a nested structure of iterables
+ (list, tuple, dictionary) whose leaves are instances of
+ `tf.Tensor` or `tf.Operation`.
+
+ Args:
+ targets: A Python tree containing `tf.Tensor` or `tf.Operation`
+ belonging to the original graph.
+ dst_graph: The graph in which the corresponding graph element must be found.
+ dst_scope: A scope which is prepended to the name to look for.
+ src_scope: A scope which is removed from the original name of `top`.
+
+ Returns:
+ A Python tree containing the corresponding `tf.Tensor` or `tf.Operation`.
+
+ Raises:
+ ValueError: if `src_name` does not start with `src_scope`.
+ TypeError: if `top` is not a `tf.Tensor` or a `tf.Operation`
+ KeyError: If the corresponding graph element cannot be found.
+ """
+ def func(top):
+ return find_corresponding_elem(top, dst_graph, dst_scope, src_scope)
+ return transform_tree(targets, func)
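+
+# Illustrative usage (editor's note, not part of the original TF module),
+# assuming `copied_graph` contains copies of the original ops under scope "copy":
+#   find_corresponding({"logits": logits_t}, copied_graph, dst_scope="copy")
+# returns a dict mapping "logits" to the tensor named "copy/<logits_t.name>"
+# looked up in `copied_graph`.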
diff --git a/train.py b/train.py
index d497440..bfba032 100755
--- a/train.py
+++ b/train.py
@@ -12,8 +12,10 @@ import tensorflow as tf
import time
import tqdm
from tensorflow.core.protobuf import rewriter_config_pb2
-from tensorflow.contrib import tpu
-from tensorflow.contrib.cluster_resolver import TPUClusterResolver
+from tensorflow.compat.v1 import tpu
+#import tensorflow.distribute
+#TPUClusterResolver = tf.distribute.TPUClusterResolver
+from tensorflow.compat.v1.distribute.cluster_resolver import TPUClusterResolver
from tensorflow.python import pywrap_tensorflow
import model, sample, encoder
@@ -189,18 +191,18 @@ def main(tpu_cluster=None):
if args.optimizer == 'adam':
args.only_train_transformer_layers = True
- config = tf.ConfigProto()
+ config = tf.compat.v1.ConfigProto()
if args.allow_growth:
config.gpu_options.allow_growth = True
if args.disable_layout_optimizer:
config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
- with tf.Session(tpu_cluster, config=config) as sess:
+ with tf.compat.v1.Session(tpu_cluster, config=config) as sess:
if tpu_cluster and args.init_tpu:
print("initializing TPU system...")
sess.run(tpu.initialize_system())
if tpu_cluster:
print("Using TPU %s" % tpu_cluster)
- context = tf.placeholder(tf.int32, [args.batch_size, None])
+ context = tf.compat.v1.placeholder(tf.int32, [args.batch_size, None])
context_in = randomize(context, hparams, args.noise)
output = model.model(hparams=hparams, X=context_in)
loss = tf.reduce_mean(
@@ -208,12 +210,12 @@ def main(tpu_cluster=None):
labels=context[:, 1:], logits=output['logits'][:, :-1]))
if args.val_every > 0:
- val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
+ val_context = tf.compat.v1.placeholder(tf.int32, [args.val_batch_size, None])
val_output = model.model(hparams=hparams, X=val_context)
val_loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
- val_loss_summary = tf.summary.scalar('val_loss', val_loss)
+ val_loss_summary = tf.compat.v1.summary.scalar('val_loss', val_loss)
tf_sample = sample.sample_sequence(
@@ -226,14 +228,14 @@ def main(tpu_cluster=None):
top_p=args.top_p,
epsilon=epsilon)
- all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
+ all_vars = [v for v in tf.compat.v1.trainable_variables() if 'model' in v.name]
train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars
parameter_count = sum([np.prod(v.shape.as_list()) for v in train_vars])
print("This model is using %d parameters (%.2fM)" % (parameter_count, parameter_count/(1024.0*1024.0)))
- with tf.variable_scope(tf.get_variable_scope().name, reuse=tf.AUTO_REUSE):
- global_step = tflex.get_variable('global_step') or tf.get_variable('global_step', shape=(), dtype=tf.int32, trainable=False)
+ with tflex.variable_scope(reuse=tf.compat.v1.AUTO_REUSE):
+ global_step = model.init_variable('global_step', shape=(), dtype=tf.int32, trainable=False)
current_step = args.learning_rate_initial_step
global_step.load(current_step, session=sess)
if args.learning_rate_cos:
@@ -243,9 +245,9 @@ def main(tpu_cluster=None):
lr = tf.constant(args.learning_rate)
if args.optimizer == 'adam':
- opt = tf.train.AdamOptimizer(learning_rate=lr)
+ opt = tf.compat.v1.train.AdamOptimizer(learning_rate=lr)
elif args.optimizer == 'sgd':
- opt = tf.train.GradientDescentOptimizer(learning_rate=lr)
+ opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=lr)
elif args.optimizer == 'ada':
import tensor2tensor.utils.optimize
from tensor2tensor.utils import hparam
@@ -259,8 +261,9 @@ def main(tpu_cluster=None):
exit('Bad optimizer:', args.optimizer)
if tpu_cluster:
+ pass
# https://pulsejet.github.io/blog/posts/tpu-without-estimator/
- from tensorflow.contrib.tpu.python.tpu import tpu_function
+ #from tensorflow.contrib.tpu.python.tpu import tpu_function
#tpu_function.get_tpu_context().set_number_of_shards(8)
#opt = tf.contrib.tpu.CrossShardOptimizer(opt)
@@ -273,7 +276,7 @@ def main(tpu_cluster=None):
opt_reset = opt.reset()
opt_compute = opt.compute_gradients(loss)
opt_apply = opt.apply_gradients()
- summary_loss = tf.summary.scalar('loss', opt_apply)
+ summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
else:
if args.memory_saving_gradients:
opt_grads = memory_saving_gradients.gradients(loss, train_vars)
@@ -281,12 +284,12 @@ def main(tpu_cluster=None):
opt_grads = tf.gradients(loss, train_vars)
opt_grads = list(zip(opt_grads, train_vars))
opt_apply = opt.apply_gradients(opt_grads)
- summary_loss = tf.summary.scalar('loss', loss)
+ summary_loss = tf.compat.v1.summary.scalar('loss', loss)
- summary_lr = tf.summary.scalar('learning_rate', lr)
- summaries = tf.summary.merge([summary_lr, summary_loss])
+ summary_lr = tf.compat.v1.summary.scalar('learning_rate', lr)
+ summaries = tf.compat.v1.summary.merge([summary_lr, summary_loss])
- summary_log = tf.summary.FileWriter(
+ summary_log = tf.compat.v1.summary.FileWriter(
os.path.join(CHECKPOINT_DIR, args.run_name))
if args.save_graph:
@@ -297,7 +300,7 @@ def main(tpu_cluster=None):
max_to_keep=args.max_to_keep,
keep_checkpoint_every_n_hours=2,
reshape=args.truncate_weights)
- sess.run(tf.global_variables_initializer())
+ sess.run(tf.compat.v1.global_variables_initializer())
if args.restore_from == 'latest':
ckpt = tflex.latest_checkpoint(
@@ -528,7 +531,7 @@ def main(tpu_cluster=None):
print('trainable variables:')
print('name/shape/parameter_count')
param_count = 0
- for x in tf.trainable_variables():
+ for x in tf.compat.v1.trainable_variables():
shape = x.shape.as_list()
count = np.prod(shape)
print(x.name, shape, count)