diff --git a/requirements.txt b/requirements.txt | |
index 0d2556c..a573910 100644 | |
--- a/requirements.txt | |
+++ b/requirements.txt | |
@@ -5,3 +5,4 @@ tqdm==4.31.1 | |
toposort==1.5 | |
tensor2tensor>=1.14.1 | |
h5py | |
+tensorflow_addons>=0.6.0 | |
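The new dependency provides tfa.activations.gelu, which the model.py hunk below references in a comment as a possible replacement for the hand-rolled gelu(). A quick sanity check, assuming TF 2.x and tensorflow_addons>=0.6.0 are installed:

import tensorflow as tf
import tensorflow_addons as tfa

x = tf.constant([-1.0, 0.0, 1.0])
# Should closely match model.py's tanh-approximation gelu().
print(tfa.activations.gelu(x))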
diff --git a/src/generate_unconditional_samples.py b/src/generate_unconditional_samples.py | |
index 87e5132..176cccd 100755 | |
--- a/src/generate_unconditional_samples.py | |
+++ b/src/generate_unconditional_samples.py | |
@@ -54,9 +54,9 @@ def sample_model( | |
elif length > hparams.n_ctx: | |
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) | |
- with tf.Session(graph=tf.Graph()) as sess: | |
+ with tf.compat.v1.Session(graph=tf.Graph()) as sess: | |
np.random.seed(seed) | |
- tf.set_random_seed(seed) | |
+ tf.compat.v1.set_random_seed(seed) | |
output = sample.sample_sequence( | |
hparams=hparams, length=length, | |
@@ -71,6 +71,7 @@ def sample_model( | |
generated = 0 | |
while nsamples == 0 or generated < nsamples: | |
+ print('Generating...') | |
out = sess.run(output) | |
for i in range(batch_size): | |
generated += 1 | |
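The compat.v1 replacements above are the pattern this patch applies throughout: TF1 symbols that moved under tf.compat.v1 in TensorFlow 2.x are rewritten to their compat paths. A minimal standalone sketch of the pattern (not part of the patch; assumes TensorFlow 2.x):

import tensorflow as tf

# Entering a compat.v1 Session also installs its graph as the default graph,
# so TF1-style graph building keeps working under TF 2.x.
with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    tf.compat.v1.set_random_seed(0)            # was: tf.set_random_seed(0)
    x = tf.compat.v1.placeholder(tf.int32)     # was: tf.placeholder(...)
    print(sess.run(x * 2, feed_dict={x: 21}))  # prints 42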
diff --git a/src/hparam.py b/src/hparam.py | |
new file mode 100644 | |
index 0000000..6495838 | |
--- /dev/null | |
+++ b/src/hparam.py | |
@@ -0,0 +1,747 @@ | |
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved. | |
+# | |
+# Licensed under the Apache License, Version 2.0 (the "License"); | |
+# you may not use this file except in compliance with the License. | |
+# You may obtain a copy of the License at | |
+# | |
+# http://www.apache.org/licenses/LICENSE-2.0 | |
+# | |
+# Unless required by applicable law or agreed to in writing, software | |
+# distributed under the License is distributed on an "AS IS" BASIS, | |
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+# See the License for the specific language governing permissions and | |
+# limitations under the License. | |
+# ============================================================================== | |
+"""Hyperparameter values.""" | |
+from __future__ import absolute_import | |
+from __future__ import division | |
+from __future__ import print_function | |
+ | |
+import json | |
+import numbers | |
+import re | |
+ | |
+import six | |
+ | |
+import hparam_pb2 | |
+from tensorflow.python.framework import ops | |
+from tensorflow.python.util import compat | |
+from tensorflow.python.util import deprecation | |
+ | |
+# Define the regular expression for parsing a single clause of the input | |
+# (delimited by commas). A legal clause looks like: | |
+# <variable name>[<index>]? = <rhs> | |
+# where <rhs> is either a single token or [] enclosed list of tokens. | |
+# For example: "var[1] = a" or "x = [1,2,3]" | |
+PARAM_RE = re.compile(r""" | |
+ (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x" | |
+ (\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None | |
+ \s*=\s* | |
+ ((?P<val>[^,\[]*) # single value: "a" or None | |
+ | | |
+ \[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3" | |
+ ($|,\s*)""", re.VERBOSE) | |
+ | |
+ | |
+def _parse_fail(name, var_type, value, values): | |
+ """Helper function for raising a value error for bad assignment.""" | |
+ raise ValueError( | |
+ 'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' % | |
+ (name, var_type.__name__, value, values)) | |
+ | |
+ | |
+def _reuse_fail(name, values): | |
+ """Helper function for raising a value error for reuse of name.""" | |
+ raise ValueError('Multiple assignments to variable \'%s\' in %s' % (name, | |
+ values)) | |
+ | |
+ | |
+def _process_scalar_value(name, parse_fn, var_type, m_dict, values, | |
+ results_dictionary): | |
+ """Update results_dictionary with a scalar value. | |
+ | |
+ Used to update the results_dictionary to be returned by parse_values when | |
+ encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".) | |
+ | |
+ Mutates results_dictionary. | |
+ | |
+ Args: | |
+ name: Name of variable in assignment ("s" or "arr"). | |
+ parse_fn: Function for parsing the actual value. | |
+ var_type: Type of named variable. | |
+ m_dict: Dictionary constructed from regex parsing. | |
+ m_dict['val']: RHS value (scalar) | |
+ m_dict['index']: List index value (or None) | |
+ values: Full expression being parsed | |
+ results_dictionary: The dictionary being updated for return by the parsing | |
+ function. | |
+ | |
+ Raises: | |
+ ValueError: If the name has already been used. | |
+ """ | |
+ try: | |
+ parsed_value = parse_fn(m_dict['val']) | |
+ except ValueError: | |
+ _parse_fail(name, var_type, m_dict['val'], values) | |
+ | |
+ # If no index is provided | |
+ if not m_dict['index']: | |
+ if name in results_dictionary: | |
+ _reuse_fail(name, values) | |
+ results_dictionary[name] = parsed_value | |
+ else: | |
+ if name in results_dictionary: | |
+ # If the name has already been used as a scalar, it will | |
+ # be in this dictionary and map to a non-dictionary. | |
+ if not isinstance(results_dictionary.get(name), dict): | |
+ _reuse_fail(name, values) | |
+ else: | |
+ results_dictionary[name] = {} | |
+ | |
+ index = int(m_dict['index']) | |
+ # Make sure the index position hasn't already been assigned a value. | |
+ if index in results_dictionary[name]: | |
+ _reuse_fail('{}[{}]'.format(name, index), values) | |
+ results_dictionary[name][index] = parsed_value | |
+ | |
+ | |
+def _process_list_value(name, parse_fn, var_type, m_dict, values, | |
+ results_dictionary): | |
+ """Update results_dictionary from a list of values. | |
+ | |
+ Used to update results_dictionary to be returned by parse_values when | |
+ encountering a clause with a list RHS (e.g. "arr=[1,2,3]".) | |
+ | |
+ Mutates results_dictionary. | |
+ | |
+ Args: | |
+ name: Name of variable in assignment ("arr"). | |
+ parse_fn: Function for parsing individual values. | |
+ var_type: Type of named variable. | |
+ m_dict: Dictionary constructed from regex parsing. | |
+ m_dict['vals']: RHS values (list) | |
+ values: Full expression being parsed | |
+ results_dictionary: The dictionary being updated for return by the parsing | |
+ function. | |
+ | |
+ Raises: | |
+ ValueError: If the name has an index or the values cannot be parsed. | |
+ """ | |
+ if m_dict['index'] is not None: | |
+ raise ValueError('Assignment of a list to a list index.') | |
+ elements = filter(None, re.split('[ ,]', m_dict['vals'])) | |
+ # Make sure the name hasn't already been assigned a value | |
+ if name in results_dictionary: | |
+ _reuse_fail(name, values) | |
+ try: | |
+ results_dictionary[name] = [parse_fn(e) for e in elements] | |
+ except ValueError: | |
+ _parse_fail(name, var_type, m_dict['vals'], values) | |
+ | |
+ | |
+def _cast_to_type_if_compatible(name, param_type, value): | |
+ """Cast hparam to the provided type, if compatible. | |
+ | |
+ Args: | |
+ name: Name of the hparam to be cast. | |
+ param_type: The type of the hparam. | |
+ value: The value to be cast, if compatible. | |
+ | |
+ Returns: | |
+ The result of casting `value` to `param_type`. | |
+ | |
+ Raises: | |
+ ValueError: If the type of `value` is not compatible with param_type. | |
+ * If `param_type` is a string type, but `value` is not. | |
+ * If `param_type` is a boolean, but `value` is not, or vice versa. | |
+ * If `param_type` is an integer type, but `value` is not. | |
+ * If `param_type` is a float type, but `value` is not a numeric type. | |
+ """ | |
+ fail_msg = ( | |
+ "Could not cast hparam '%s' of type '%s' from value %r" % | |
+ (name, param_type, value)) | |
+ | |
+ # If `value` is already of type `param_type`, return it directly. | |
+ # `isinstance` is too weak (e.g. isinstance(True, int) == True). | |
+ if type(value) == param_type: # pylint: disable=unidiomatic-typecheck | |
+ return value | |
+ | |
+ # Some callers use None, for which we can't do any casting/checking. :( | |
+ if issubclass(param_type, type(None)): | |
+ return value | |
+ | |
+ # Avoid converting a non-string type to a string. | |
+ if (issubclass(param_type, (six.string_types, six.binary_type)) and | |
+ not isinstance(value, (six.string_types, six.binary_type))): | |
+ raise ValueError(fail_msg) | |
+ | |
+ # Avoid converting a number or string type to a boolean or vice versa. | |
+ if issubclass(param_type, bool) != isinstance(value, bool): | |
+ raise ValueError(fail_msg) | |
+ | |
+ # Avoid converting float to an integer (the reverse is fine). | |
+ if (issubclass(param_type, numbers.Integral) and | |
+ not isinstance(value, numbers.Integral)): | |
+ raise ValueError(fail_msg) | |
+ | |
+ # Avoid converting a non-numeric type to a numeric type. | |
+ if (issubclass(param_type, numbers.Number) and | |
+ not isinstance(value, numbers.Number)): | |
+ raise ValueError(fail_msg) | |
+ | |
+ return param_type(value) | |
+ | |
+ | |
+def parse_values(values, type_map, ignore_unknown=False): | |
+ """Parses hyperparameter values from a string into a python map. | |
+ | |
+ `values` is a string containing comma-separated `name=value` pairs. | |
+ For each pair, the value of the hyperparameter named `name` is set to | |
+ `value`. | |
+ | |
+ If a hyperparameter name appears multiple times in `values`, a ValueError | |
+ is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2'). | |
+ | |
+ If a hyperparameter name appears in both an index assignment and a scalar | |
+ assignment, a ValueError is raised (e.g. 'a=[1,2,3],a[0]=1'). | |
+ | |
+ The hyperparameter name may contain '.' symbols, which will result in an | |
+ attribute name that is only accessible through the getattr and setattr | |
+ functions (and it must first be explicitly added through add_hparam). | |
+ | |
+ WARNING: Use of '.' in your variable names is allowed, but is not well | |
+ supported and not recommended. | |
+ | |
+ The `value` in `name=value` must follow the syntax appropriate to the | |
+ type of the parameter: | |
+ | |
+ * Scalar integer: A Python-parsable integer value. E.g.: 1, | |
+ 100, -12. | |
+ * Scalar float: A Python-parsable floating point value. E.g.: 1.0, | |
+ -.54e89. | |
+ * Boolean: Either true or false. | |
+ * Scalar string: A non-empty sequence of characters, excluding commas, | |
+ spaces, and square brackets. E.g.: foo, bar_1. | |
+ * List: A comma separated list of scalar values of the parameter type | |
+ enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low]. | |
+ | |
+ When index assignment is used, the corresponding type_map key should be the | |
+ list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not | |
+ "arr[1]"). | |
+ | |
+ Args: | |
+ values: String. Comma separated list of `name=value` pairs where | |
+ 'value' must follow the syntax described above. | |
+ type_map: A dictionary mapping hyperparameter names to types. Note every | |
+ parameter name in values must be a key in type_map. The values must | |
+ conform to the types indicated, where a value V is said to conform to a | |
+ type T if either V has type T, or V is a list of elements of type T. | |
+ Hence, for a multidimensional parameter 'x' taking float values, | |
+ 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float. | |
+ ignore_unknown: Bool. Whether values that are missing a type in type_map | |
+ should be ignored. If set to True, a ValueError will not be raised for | |
+ unknown hyperparameter type. | |
+ | |
+ Returns: | |
+ A python map mapping each name to either: | |
+ * A scalar value. | |
+ * A list of scalar values. | |
+ * A dictionary mapping index numbers to scalar values. | |
+ (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}") | |
+ | |
+ Raises: | |
+ ValueError: If there is a problem with input. | |
+ * If `values` cannot be parsed. | |
+ * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]'). | |
+ * If the same name is assigned two different values (e.g. 'a=1,a=2', | |
+ 'a[1]=1,a[1]=2', or 'a=1,a=[1]') | |
+ """ | |
+ results_dictionary = {} | |
+ pos = 0 | |
+ while pos < len(values): | |
+ m = PARAM_RE.match(values, pos) | |
+ if not m: | |
+ raise ValueError('Malformed hyperparameter value: %s' % values[pos:]) | |
+ # Check that there is a comma between parameters and move past it. | |
+ pos = m.end() | |
+ # Parse the values. | |
+ m_dict = m.groupdict() | |
+ name = m_dict['name'] | |
+ if name not in type_map: | |
+ if ignore_unknown: | |
+ continue | |
+ raise ValueError('Unknown hyperparameter type for %s' % name) | |
+ type_ = type_map[name] | |
+ | |
+ # Set up correct parsing function (depending on whether type_ is a bool) | |
+ if type_ == bool: | |
+ | |
+ def parse_bool(value): | |
+ if value in ['true', 'True']: | |
+ return True | |
+ elif value in ['false', 'False']: | |
+ return False | |
+ else: | |
+ try: | |
+ return bool(int(value)) | |
+ except ValueError: | |
+ _parse_fail(name, type_, value, values) | |
+ | |
+ parse = parse_bool | |
+ else: | |
+ parse = type_ | |
+ | |
+ # If a single value is provided | |
+ if m_dict['val'] is not None: | |
+ _process_scalar_value(name, parse, type_, m_dict, values, | |
+ results_dictionary) | |
+ | |
+ # If the assigned value is a list: | |
+ elif m_dict['vals'] is not None: | |
+ _process_list_value(name, parse, type_, m_dict, values, | |
+ results_dictionary) | |
+ | |
+ else: # Not assigned a list or value | |
+ _parse_fail(name, type_, '', values) | |
+ | |
+ return results_dictionary | |
+ | |
+ | |
+class HParams(object): | |
+ """Class to hold a set of hyperparameters as name-value pairs. | |
+ | |
+ A `HParams` object holds hyperparameters used to build and train a model, | |
+ such as the number of hidden units in a neural net layer or the learning rate | |
+ to use when training. | |
+ | |
+ You first create a `HParams` object by specifying the names and values of the | |
+ hyperparameters. | |
+ | |
+ To make them easily accessible the parameter names are added as direct | |
+ attributes of the class. A typical usage is as follows: | |
+ | |
+ ```python | |
+ # Create a HParams object specifying names and values of the model | |
+ # hyperparameters: | |
+ hparams = HParams(learning_rate=0.1, num_hidden_units=100) | |
+ | |
+ # The hyperparameters are available as attributes of the HParams object: | |
+ hparams.learning_rate ==> 0.1 | |
+ hparams.num_hidden_units ==> 100 | |
+ ``` | |
+ | |
+ Hyperparameters have a type, which is inferred from the type of the value | |
+ passed at construction time. The currently supported types are: integer, | |
+ float, boolean, string, and list of integer, float, boolean, or string. | |
+ | |
+ You can override hyperparameter values by calling the | |
+ [`parse()`](#HParams.parse) method, passing a string of comma separated | |
+ `name=value` pairs. This is intended to make it possible to override | |
+ any hyperparameter values from a single command-line flag to which | |
+ the user passes 'hyper-param=value' pairs. It avoids having to define | |
+ one flag for each hyperparameter. | |
+ | |
+ The syntax expected for each value depends on the type of the parameter. | |
+ See `parse()` for a description of the syntax. | |
+ | |
+ Example: | |
+ | |
+ ```python | |
+ # Define a command line flag to pass name=value pairs. | |
+ # For example using argparse: | |
+ import argparse | |
+ parser = argparse.ArgumentParser(description='Train my model.') | |
+ parser.add_argument('--hparams', type=str, | |
+ help='Comma separated list of "name=value" pairs.') | |
+ args = parser.parse_args() | |
+ ... | |
+ def my_program(): | |
+ # Create a HParams object specifying the names and values of the | |
+ # model hyperparameters: | |
+ hparams = tf.contrib.training.HParams( | |
+ learning_rate=0.1, | |
+ num_hidden_units=100, | |
+ activations=['relu', 'tanh']) | |
+ | |
+ # Override hyperparameters values by parsing the command line | |
+ hparams.parse(args.hparams) | |
+ | |
+ # If the user passed `--hparams=learning_rate=0.3` on the command line | |
+ # then 'hparams' has the following attributes: | |
+ hparams.learning_rate ==> 0.3 | |
+ hparams.num_hidden_units ==> 100 | |
+ hparams.activations ==> ['relu', 'tanh'] | |
+ | |
+ # If the hyperparameters are in json format use parse_json: | |
+ hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}') | |
+ ``` | |
+ """ | |
+ | |
+ _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks. | |
+ | |
+ def __init__(self, hparam_def=None, model_structure=None, **kwargs): | |
+ """Create an instance of `HParams` from keyword arguments. | |
+ | |
+ The keyword arguments specify name-value pairs for the hyperparameters. | |
+ The parameter types are inferred from the type of the values passed. | |
+ | |
+ The parameter names are added as attributes of the `HParams` object, so they | |
+ can be accessed directly with the dot notation `hparams.<name>`. | |
+ | |
+ Example: | |
+ | |
+ ```python | |
+ # Define 3 hyperparameters: 'learning_rate' is a float parameter, | |
+ # 'num_hidden_units' an integer parameter, and 'activation' a string | |
+ # parameter. | |
+ hparams = tf.contrib.training.HParams( | |
+ learning_rate=0.1, num_hidden_units=100, activation='relu') | |
+ | |
+ hparams.activation ==> 'relu' | |
+ ``` | |
+ | |
+ Note that a few names are reserved and cannot be used as hyperparameter | |
+ names. If you use one of the reserved names, the constructor raises a | |
+ `ValueError`. | |
+ | |
+ Args: | |
+ hparam_def: Serialized hyperparameters, encoded as a hparam_pb2.HParamDef | |
+ protocol buffer. If provided, this object is initialized by | |
+ deserializing hparam_def. Otherwise **kwargs is used. | |
+ model_structure: An instance of ModelStructure, defining the feature | |
+ crosses to be used in the Trial. | |
+ **kwargs: Key-value pairs where the key is the hyperparameter name and | |
+ the value is the value for the parameter. | |
+ | |
+ Raises: | |
+ ValueError: If both `hparam_def` and initialization values are provided, | |
+ or if one of the arguments is invalid. | |
+ | |
+ """ | |
+ # Register the hyperparameters and their type in _hparam_types. | |
+ # This simplifies the implementation of parse(). | |
+ # _hparam_types maps the parameter name to a tuple (type, bool). | |
+ # The type value is the type of the parameter for scalar hyperparameters, | |
+ # or the type of the list elements for multidimensional hyperparameters. | |
+ # The bool value is True if the value is a list, False otherwise. | |
+ self._hparam_types = {} | |
+ self._model_structure = model_structure | |
+ if hparam_def: | |
+ self._init_from_proto(hparam_def) | |
+ if kwargs: | |
+ raise ValueError('hparam_def and initialization values are ' | |
+ 'mutually exclusive') | |
+ else: | |
+ for name, value in six.iteritems(kwargs): | |
+ self.add_hparam(name, value) | |
+ | |
+ def _init_from_proto(self, hparam_def): | |
+ """Creates a new HParams from `HParamDef` protocol buffer. | |
+ | |
+ Args: | |
+ hparam_def: `HParamDef` protocol buffer. | |
+ """ | |
+ assert isinstance(hparam_def, hparam_pb2.HParamDef) | |
+ for name, value in hparam_def.hparam.items(): | |
+ kind = value.WhichOneof('kind') | |
+ if kind.endswith('_value'): | |
+ # Single value. | |
+ if kind.startswith('int64'): | |
+ # Setting attribute value to be 'int' to ensure the type is compatible | |
+ # with both Python2 and Python3. | |
+ self.add_hparam(name, int(getattr(value, kind))) | |
+ elif kind.startswith('bytes'): | |
+ # Setting attribute value to be 'str' to ensure the type is compatible | |
+ # with both Python2 and Python3. UTF-8 encoding is assumed. | |
+ self.add_hparam(name, compat.as_str(getattr(value, kind))) | |
+ else: | |
+ self.add_hparam(name, getattr(value, kind)) | |
+ else: | |
+ # List of values. | |
+ if kind.startswith('int64'): | |
+ # Setting attribute value to be 'int' to ensure the type is compatible | |
+ # with both Python2 and Python3. | |
+ self.add_hparam(name, [int(v) for v in getattr(value, kind).value]) | |
+ elif kind.startswith('bytes'): | |
+ # Setting attribute value to be 'str' to ensure the type is compatible | |
+ # with both Python2 and Python3. UTF-8 encoding is assumed. | |
+ self.add_hparam( | |
+ name, [compat.as_str(v) for v in getattr(value, kind).value]) | |
+ else: | |
+ self.add_hparam(name, [v for v in getattr(value, kind).value]) | |
+ | |
+ def add_hparam(self, name, value): | |
+ """Adds {name, value} pair to hyperparameters. | |
+ | |
+ Args: | |
+ name: Name of the hyperparameter. | |
+ value: Value of the hyperparameter. Can be one of the following types: | |
+ int, float, string, int list, float list, or string list. | |
+ | |
+ Raises: | |
+ ValueError: if one of the arguments is invalid. | |
+ """ | |
+ # Keys in kwargs are unique, but 'name' could be the name of a pre-existing | |
+ # attribute of this object. In that case we refuse to use it as a | |
+ # hyperparameter name. | |
+ if getattr(self, name, None) is not None: | |
+ raise ValueError('Hyperparameter name is reserved: %s' % name) | |
+ if isinstance(value, (list, tuple)): | |
+ if not value: | |
+ raise ValueError( | |
+ 'Multi-valued hyperparameters cannot be empty: %s' % name) | |
+ self._hparam_types[name] = (type(value[0]), True) | |
+ else: | |
+ self._hparam_types[name] = (type(value), False) | |
+ setattr(self, name, value) | |
+ | |
+ def set_hparam(self, name, value): | |
+ """Set the value of an existing hyperparameter. | |
+ | |
+ This function verifies that the type of the value matches the type of the | |
+ existing hyperparameter. | |
+ | |
+ Args: | |
+ name: Name of the hyperparameter. | |
+ value: New value of the hyperparameter. | |
+ | |
+ Raises: | |
+ KeyError: If the hyperparameter doesn't exist. | |
+ ValueError: If there is a type mismatch. | |
+ """ | |
+ param_type, is_list = self._hparam_types[name] | |
+ if isinstance(value, list): | |
+ if not is_list: | |
+ raise ValueError( | |
+ 'Must not pass a list for single-valued parameter: %s' % name) | |
+ setattr(self, name, [ | |
+ _cast_to_type_if_compatible(name, param_type, v) for v in value]) | |
+ else: | |
+ if is_list: | |
+ raise ValueError( | |
+ 'Must pass a list for multi-valued parameter: %s.' % name) | |
+ setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) | |
+ | |
+ def del_hparam(self, name): | |
+ """Removes the hyperparameter with key 'name'. | |
+ | |
+ Does nothing if it isn't present. | |
+ | |
+ Args: | |
+ name: Name of the hyperparameter. | |
+ """ | |
+ if hasattr(self, name): | |
+ delattr(self, name) | |
+ del self._hparam_types[name] | |
+ | |
+ def parse(self, values): | |
+ """Override existing hyperparameter values, parsing new values from a string. | |
+ | |
+ See parse_values for more detail on the allowed format for values. | |
+ | |
+ Args: | |
+ values: String. Comma separated list of `name=value` pairs where 'value' | |
+ must follow the syntax described above. | |
+ | |
+ Returns: | |
+ The `HParams` instance. | |
+ | |
+ Raises: | |
+ ValueError: If `values` cannot be parsed or a hyperparameter in `values` | |
+ doesn't exist. | |
+ """ | |
+ type_map = {} | |
+ for name, t in self._hparam_types.items(): | |
+ param_type, _ = t | |
+ type_map[name] = param_type | |
+ | |
+ values_map = parse_values(values, type_map) | |
+ return self.override_from_dict(values_map) | |
+ | |
+ def override_from_dict(self, values_dict): | |
+ """Override existing hyperparameter values, parsing new values from a dictionary. | |
+ | |
+ Args: | |
+ values_dict: Dictionary of name:value pairs. | |
+ | |
+ Returns: | |
+ The `HParams` instance. | |
+ | |
+ Raises: | |
+ KeyError: If a hyperparameter in `values_dict` doesn't exist. | |
+ ValueError: If `values_dict` cannot be parsed. | |
+ """ | |
+ for name, value in values_dict.items(): | |
+ self.set_hparam(name, value) | |
+ return self | |
+ | |
+ @deprecation.deprecated(None, 'Use `override_from_dict`.') | |
+ def set_from_map(self, values_map): | |
+ """DEPRECATED. Use override_from_dict.""" | |
+ return self.override_from_dict(values_dict=values_map) | |
+ | |
+ def set_model_structure(self, model_structure): | |
+ self._model_structure = model_structure | |
+ | |
+ def get_model_structure(self): | |
+ return self._model_structure | |
+ | |
+ def to_json(self, indent=None, separators=None, sort_keys=False): | |
+ """Serializes the hyperparameters into JSON. | |
+ | |
+ Args: | |
+ indent: If a non-negative integer, JSON array elements and object members | |
+ will be pretty-printed with that indent level. An indent level of 0, or | |
+ negative, will only insert newlines. `None` (the default) selects the | |
+ most compact representation. | |
+ separators: Optional `(item_separator, key_separator)` tuple. Default is | |
+ `(', ', ': ')`. | |
+ sort_keys: If `True`, the output dictionaries will be sorted by key. | |
+ | |
+ Returns: | |
+ A JSON string. | |
+ """ | |
+ return json.dumps( | |
+ self.values(), | |
+ indent=indent, | |
+ separators=separators, | |
+ sort_keys=sort_keys) | |
+ | |
+ def parse_json(self, values_json): | |
+ """Override existing hyperparameter values, parsing new values from a json object. | |
+ | |
+ Args: | |
+ values_json: String containing a json object of name:value pairs. | |
+ | |
+ Returns: | |
+ The `HParams` instance. | |
+ | |
+ Raises: | |
+ KeyError: If a hyperparameter in `values_json` doesn't exist. | |
+ ValueError: If `values_json` cannot be parsed. | |
+ """ | |
+ values_map = json.loads(values_json) | |
+ return self.override_from_dict(values_map) | |
+ | |
+ def values(self): | |
+ """Return the hyperparameter values as a Python dictionary. | |
+ | |
+ Returns: | |
+ A dictionary with hyperparameter names as keys. The values are the | |
+ hyperparameter values. | |
+ """ | |
+ return {n: getattr(self, n) for n in self._hparam_types.keys()} | |
+ | |
+ def get(self, key, default=None): | |
+ """Returns the value of `key` if it exists, else `default`.""" | |
+ if key in self._hparam_types: | |
+ # Ensure that default is compatible with the parameter type. | |
+ if default is not None: | |
+ param_type, is_param_list = self._hparam_types[key] | |
+ type_str = 'list<%s>' % param_type if is_param_list else str(param_type) | |
+ fail_msg = ("Hparam '%s' of type '%s' is incompatible with " | |
+ 'default=%s' % (key, type_str, default)) | |
+ | |
+ is_default_list = isinstance(default, list) | |
+ if is_param_list != is_default_list: | |
+ raise ValueError(fail_msg) | |
+ | |
+ try: | |
+ if is_default_list: | |
+ for value in default: | |
+ _cast_to_type_if_compatible(key, param_type, value) | |
+ else: | |
+ _cast_to_type_if_compatible(key, param_type, default) | |
+ except ValueError as e: | |
+ raise ValueError('%s. %s' % (fail_msg, e)) | |
+ | |
+ return getattr(self, key) | |
+ | |
+ return default | |
+ | |
+ def __contains__(self, key): | |
+ return key in self._hparam_types | |
+ | |
+ def __str__(self): | |
+ hpdict = self.values() | |
+ output_list = ['{}={}'.format(key, hpdict[key]) for key in hpdict] | |
+ return ','.join(output_list) | |
+ | |
+ def __repr__(self): | |
+ strval = str(sorted(self.values().items())) | |
+ return '%s(%s)' % (type(self).__name__, strval) | |
+ | |
+ @staticmethod | |
+ def _get_kind_name(param_type, is_list): | |
+ """Returns the field name given parameter type and is_list. | |
+ | |
+ Args: | |
+ param_type: Data type of the hparam. | |
+ is_list: Whether this is a list. | |
+ | |
+ Returns: | |
+ A string representation of the field name. | |
+ | |
+ Raises: | |
+ ValueError: If parameter type is not recognized. | |
+ """ | |
+ if issubclass(param_type, bool): | |
+ # This check must happen before issubclass(param_type, six.integer_types), | |
+ # since Python considers bool to be a subclass of int. | |
+ typename = 'bool' | |
+ elif issubclass(param_type, six.integer_types): | |
+ # Setting 'int' and 'long' types to be 'int64' to ensure the type is | |
+ # compatible with both Python2 and Python3. | |
+ typename = 'int64' | |
+ elif issubclass(param_type, (six.string_types, six.binary_type)): | |
+ # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is | |
+ # compatible with both Python2 and Python3. | |
+ typename = 'bytes' | |
+ elif issubclass(param_type, float): | |
+ typename = 'float' | |
+ else: | |
+ raise ValueError('Unsupported parameter type: %s' % str(param_type)) | |
+ | |
+ suffix = 'list' if is_list else 'value' | |
+ return '_'.join([typename, suffix]) | |
+ | |
+ def to_proto(self, export_scope=None): # pylint: disable=unused-argument | |
+ """Converts a `HParams` object to a `HParamDef` protocol buffer. | |
+ | |
+ Args: | |
+ export_scope: Optional `string`. Name scope to remove. | |
+ | |
+ Returns: | |
+ A `HParamDef` protocol buffer. | |
+ """ | |
+ hparam_proto = hparam_pb2.HParamDef() | |
+ for name in self._hparam_types: | |
+ # Parse the values. | |
+ param_type, is_list = self._hparam_types.get(name, (None, None)) | |
+ kind = HParams._get_kind_name(param_type, is_list) | |
+ | |
+ if is_list: | |
+ if kind.startswith('bytes'): | |
+ v_list = [compat.as_bytes(v) for v in getattr(self, name)] | |
+ else: | |
+ v_list = [v for v in getattr(self, name)] | |
+ getattr(hparam_proto.hparam[name], kind).value.extend(v_list) | |
+ else: | |
+ v = getattr(self, name) | |
+ if kind.startswith('bytes'): | |
+ v = compat.as_bytes(getattr(self, name)) | |
+ setattr(hparam_proto.hparam[name], kind, v) | |
+ | |
+ return hparam_proto | |
+ | |
+ @staticmethod | |
+ def from_proto(hparam_def, import_scope=None): # pylint: disable=unused-argument | |
+ return HParams(hparam_def=hparam_def) | |
+ | |
+ | |
+ops.register_proto_function( | |
+ 'hparams', | |
+ proto_type=hparam_pb2.HParamDef, | |
+ to_proto=HParams.to_proto, | |
+ from_proto=HParams.from_proto) | |
+ | |
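With hparam.py vendored as above, model.py below can drop tensorflow.contrib.training entirely. A minimal usage sketch following the docstrings in this file (hyperparameter names are illustrative; assumes src/ is on the import path):

import hparam

hp = hparam.HParams(learning_rate=0.1, num_hidden_units=100,
                    activations=['relu', 'tanh'])
hp.parse('learning_rate=0.3,activations=[relu]')  # comma-separated overrides
assert hp.learning_rate == 0.3
assert hp.activations == ['relu']
print(hp.to_json())
# hp.parse('learning_rate=x') would raise ValueError: 'x' is not a float.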
diff --git a/src/hparam_pb2.py b/src/hparam_pb2.py | |
new file mode 100644 | |
index 0000000..ba10c30 | |
--- /dev/null | |
+++ b/src/hparam_pb2.py | |
@@ -0,0 +1,399 @@ | |
+# -*- coding: utf-8 -*- | |
+# Generated by the protocol buffer compiler. DO NOT EDIT! | |
+# source: tensorflow/contrib/training/python/training/hparam.proto | |
+ | |
+import sys | |
+_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) | |
+from google.protobuf import descriptor as _descriptor | |
+from google.protobuf import message as _message | |
+from google.protobuf import reflection as _reflection | |
+from google.protobuf import symbol_database as _symbol_database | |
+# @@protoc_insertion_point(imports) | |
+ | |
+_sym_db = _symbol_database.Default() | |
+ | |
+ | |
+ | |
+ | |
+DESCRIPTOR = _descriptor.FileDescriptor( | |
+ name='tensorflow/contrib/training/python/training/hparam.proto', | |
+ package='tensorflow', | |
+ syntax='proto3', | |
+ serialized_options=_b('\370\001\001'), | |
+ serialized_pb=_b('\n8tensorflow/contrib/training/python/training/hparam.proto\x12\ntensorflow\"\xd6\x04\n\tHParamDef\x12\x31\n\x06hparam\x18\x01 \x03(\x0b\x32!.tensorflow.HParamDef.HparamEntry\x1a\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c\x1a\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01\x1a\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01\x1a\x1d\n\x08\x42oolList\x12\x11\n\x05value\x18\x01 \x03(\x08\x42\x02\x10\x01\x1a\xc9\x02\n\nHParamType\x12\x15\n\x0bint64_value\x18\x01 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x15\n\x0b\x62ytes_value\x18\x03 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x07 \x01(\x08H\x00\x12\x35\n\nint64_list\x18\x04 \x01(\x0b\x32\x1f.tensorflow.HParamDef.Int64ListH\x00\x12\x35\n\nfloat_list\x18\x05 \x01(\x0b\x32\x1f.tensorflow.HParamDef.FloatListH\x00\x12\x35\n\nbytes_list\x18\x06 \x01(\x0b\x32\x1f.tensorflow.HParamDef.BytesListH\x00\x12\x33\n\tbool_list\x18\x08 \x01(\x0b\x32\x1e.tensorflow.HParamDef.BoolListH\x00\x42\x06\n\x04kind\x1aO\n\x0bHparamEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .tensorflow.HParamDef.HParamType:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3') | |
+) | |
+ | |
+ | |
+ | |
+ | |
+_HPARAMDEF_BYTESLIST = _descriptor.Descriptor( | |
+ name='BytesList', | |
+ full_name='tensorflow.HParamDef.BytesList', | |
+ filename=None, | |
+ file=DESCRIPTOR, | |
+ containing_type=None, | |
+ fields=[ | |
+ _descriptor.FieldDescriptor( | |
+ name='value', full_name='tensorflow.HParamDef.BytesList.value', index=0, | |
+ number=1, type=12, cpp_type=9, label=3, | |
+ has_default_value=False, default_value=[], | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=None, file=DESCRIPTOR), | |
+ ], | |
+ extensions=[ | |
+ ], | |
+ nested_types=[], | |
+ enum_types=[ | |
+ ], | |
+ serialized_options=None, | |
+ is_extendable=False, | |
+ syntax='proto3', | |
+ extension_ranges=[], | |
+ oneofs=[ | |
+ ], | |
+ serialized_start=137, | |
+ serialized_end=163, | |
+) | |
+ | |
+_HPARAMDEF_FLOATLIST = _descriptor.Descriptor( | |
+ name='FloatList', | |
+ full_name='tensorflow.HParamDef.FloatList', | |
+ filename=None, | |
+ file=DESCRIPTOR, | |
+ containing_type=None, | |
+ fields=[ | |
+ _descriptor.FieldDescriptor( | |
+ name='value', full_name='tensorflow.HParamDef.FloatList.value', index=0, | |
+ number=1, type=2, cpp_type=6, label=3, | |
+ has_default_value=False, default_value=[], | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=_b('\020\001'), file=DESCRIPTOR), | |
+ ], | |
+ extensions=[ | |
+ ], | |
+ nested_types=[], | |
+ enum_types=[ | |
+ ], | |
+ serialized_options=None, | |
+ is_extendable=False, | |
+ syntax='proto3', | |
+ extension_ranges=[], | |
+ oneofs=[ | |
+ ], | |
+ serialized_start=165, | |
+ serialized_end=195, | |
+) | |
+ | |
+_HPARAMDEF_INT64LIST = _descriptor.Descriptor( | |
+ name='Int64List', | |
+ full_name='tensorflow.HParamDef.Int64List', | |
+ filename=None, | |
+ file=DESCRIPTOR, | |
+ containing_type=None, | |
+ fields=[ | |
+ _descriptor.FieldDescriptor( | |
+ name='value', full_name='tensorflow.HParamDef.Int64List.value', index=0, | |
+ number=1, type=3, cpp_type=2, label=3, | |
+ has_default_value=False, default_value=[], | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=_b('\020\001'), file=DESCRIPTOR), | |
+ ], | |
+ extensions=[ | |
+ ], | |
+ nested_types=[], | |
+ enum_types=[ | |
+ ], | |
+ serialized_options=None, | |
+ is_extendable=False, | |
+ syntax='proto3', | |
+ extension_ranges=[], | |
+ oneofs=[ | |
+ ], | |
+ serialized_start=197, | |
+ serialized_end=227, | |
+) | |
+ | |
+_HPARAMDEF_BOOLLIST = _descriptor.Descriptor( | |
+ name='BoolList', | |
+ full_name='tensorflow.HParamDef.BoolList', | |
+ filename=None, | |
+ file=DESCRIPTOR, | |
+ containing_type=None, | |
+ fields=[ | |
+ _descriptor.FieldDescriptor( | |
+ name='value', full_name='tensorflow.HParamDef.BoolList.value', index=0, | |
+ number=1, type=8, cpp_type=7, label=3, | |
+ has_default_value=False, default_value=[], | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=_b('\020\001'), file=DESCRIPTOR), | |
+ ], | |
+ extensions=[ | |
+ ], | |
+ nested_types=[], | |
+ enum_types=[ | |
+ ], | |
+ serialized_options=None, | |
+ is_extendable=False, | |
+ syntax='proto3', | |
+ extension_ranges=[], | |
+ oneofs=[ | |
+ ], | |
+ serialized_start=229, | |
+ serialized_end=258, | |
+) | |
+ | |
+_HPARAMDEF_HPARAMTYPE = _descriptor.Descriptor( | |
+ name='HParamType', | |
+ full_name='tensorflow.HParamDef.HParamType', | |
+ filename=None, | |
+ file=DESCRIPTOR, | |
+ containing_type=None, | |
+ fields=[ | |
+ _descriptor.FieldDescriptor( | |
+ name='int64_value', full_name='tensorflow.HParamDef.HParamType.int64_value', index=0, | |
+ number=1, type=3, cpp_type=2, label=1, | |
+ has_default_value=False, default_value=0, | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=None, file=DESCRIPTOR), | |
+ _descriptor.FieldDescriptor( | |
+ name='float_value', full_name='tensorflow.HParamDef.HParamType.float_value', index=1, | |
+ number=2, type=2, cpp_type=6, label=1, | |
+ has_default_value=False, default_value=float(0), | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=None, file=DESCRIPTOR), | |
+ _descriptor.FieldDescriptor( | |
+ name='bytes_value', full_name='tensorflow.HParamDef.HParamType.bytes_value', index=2, | |
+ number=3, type=12, cpp_type=9, label=1, | |
+ has_default_value=False, default_value=_b(""), | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=None, file=DESCRIPTOR), | |
+ _descriptor.FieldDescriptor( | |
+ name='bool_value', full_name='tensorflow.HParamDef.HParamType.bool_value', index=3, | |
+ number=7, type=8, cpp_type=7, label=1, | |
+ has_default_value=False, default_value=False, | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=None, file=DESCRIPTOR), | |
+ _descriptor.FieldDescriptor( | |
+ name='int64_list', full_name='tensorflow.HParamDef.HParamType.int64_list', index=4, | |
+ number=4, type=11, cpp_type=10, label=1, | |
+ has_default_value=False, default_value=None, | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=None, file=DESCRIPTOR), | |
+ _descriptor.FieldDescriptor( | |
+ name='float_list', full_name='tensorflow.HParamDef.HParamType.float_list', index=5, | |
+ number=5, type=11, cpp_type=10, label=1, | |
+ has_default_value=False, default_value=None, | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=None, file=DESCRIPTOR), | |
+ _descriptor.FieldDescriptor( | |
+ name='bytes_list', full_name='tensorflow.HParamDef.HParamType.bytes_list', index=6, | |
+ number=6, type=11, cpp_type=10, label=1, | |
+ has_default_value=False, default_value=None, | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=None, file=DESCRIPTOR), | |
+ _descriptor.FieldDescriptor( | |
+ name='bool_list', full_name='tensorflow.HParamDef.HParamType.bool_list', index=7, | |
+ number=8, type=11, cpp_type=10, label=1, | |
+ has_default_value=False, default_value=None, | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=None, file=DESCRIPTOR), | |
+ ], | |
+ extensions=[ | |
+ ], | |
+ nested_types=[], | |
+ enum_types=[ | |
+ ], | |
+ serialized_options=None, | |
+ is_extendable=False, | |
+ syntax='proto3', | |
+ extension_ranges=[], | |
+ oneofs=[ | |
+ _descriptor.OneofDescriptor( | |
+ name='kind', full_name='tensorflow.HParamDef.HParamType.kind', | |
+ index=0, containing_type=None, fields=[]), | |
+ ], | |
+ serialized_start=261, | |
+ serialized_end=590, | |
+) | |
+ | |
+_HPARAMDEF_HPARAMENTRY = _descriptor.Descriptor( | |
+ name='HparamEntry', | |
+ full_name='tensorflow.HParamDef.HparamEntry', | |
+ filename=None, | |
+ file=DESCRIPTOR, | |
+ containing_type=None, | |
+ fields=[ | |
+ _descriptor.FieldDescriptor( | |
+ name='key', full_name='tensorflow.HParamDef.HparamEntry.key', index=0, | |
+ number=1, type=9, cpp_type=9, label=1, | |
+ has_default_value=False, default_value=_b("").decode('utf-8'), | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=None, file=DESCRIPTOR), | |
+ _descriptor.FieldDescriptor( | |
+ name='value', full_name='tensorflow.HParamDef.HparamEntry.value', index=1, | |
+ number=2, type=11, cpp_type=10, label=1, | |
+ has_default_value=False, default_value=None, | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=None, file=DESCRIPTOR), | |
+ ], | |
+ extensions=[ | |
+ ], | |
+ nested_types=[], | |
+ enum_types=[ | |
+ ], | |
+ serialized_options=_b('8\001'), | |
+ is_extendable=False, | |
+ syntax='proto3', | |
+ extension_ranges=[], | |
+ oneofs=[ | |
+ ], | |
+ serialized_start=592, | |
+ serialized_end=671, | |
+) | |
+ | |
+_HPARAMDEF = _descriptor.Descriptor( | |
+ name='HParamDef', | |
+ full_name='tensorflow.HParamDef', | |
+ filename=None, | |
+ file=DESCRIPTOR, | |
+ containing_type=None, | |
+ fields=[ | |
+ _descriptor.FieldDescriptor( | |
+ name='hparam', full_name='tensorflow.HParamDef.hparam', index=0, | |
+ number=1, type=11, cpp_type=10, label=3, | |
+ has_default_value=False, default_value=[], | |
+ message_type=None, enum_type=None, containing_type=None, | |
+ is_extension=False, extension_scope=None, | |
+ serialized_options=None, file=DESCRIPTOR), | |
+ ], | |
+ extensions=[ | |
+ ], | |
+ nested_types=[_HPARAMDEF_BYTESLIST, _HPARAMDEF_FLOATLIST, _HPARAMDEF_INT64LIST, _HPARAMDEF_BOOLLIST, _HPARAMDEF_HPARAMTYPE, _HPARAMDEF_HPARAMENTRY, ], | |
+ enum_types=[ | |
+ ], | |
+ serialized_options=None, | |
+ is_extendable=False, | |
+ syntax='proto3', | |
+ extension_ranges=[], | |
+ oneofs=[ | |
+ ], | |
+ serialized_start=73, | |
+ serialized_end=671, | |
+) | |
+ | |
+_HPARAMDEF_BYTESLIST.containing_type = _HPARAMDEF | |
+_HPARAMDEF_FLOATLIST.containing_type = _HPARAMDEF | |
+_HPARAMDEF_INT64LIST.containing_type = _HPARAMDEF | |
+_HPARAMDEF_BOOLLIST.containing_type = _HPARAMDEF | |
+_HPARAMDEF_HPARAMTYPE.fields_by_name['int64_list'].message_type = _HPARAMDEF_INT64LIST | |
+_HPARAMDEF_HPARAMTYPE.fields_by_name['float_list'].message_type = _HPARAMDEF_FLOATLIST | |
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_list'].message_type = _HPARAMDEF_BYTESLIST | |
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bool_list'].message_type = _HPARAMDEF_BOOLLIST | |
+_HPARAMDEF_HPARAMTYPE.containing_type = _HPARAMDEF | |
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append( | |
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['int64_value']) | |
+_HPARAMDEF_HPARAMTYPE.fields_by_name['int64_value'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'] | |
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append( | |
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['float_value']) | |
+_HPARAMDEF_HPARAMTYPE.fields_by_name['float_value'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'] | |
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append( | |
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_value']) | |
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_value'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'] | |
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append( | |
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['bool_value']) | |
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bool_value'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'] | |
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append( | |
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['int64_list']) | |
+_HPARAMDEF_HPARAMTYPE.fields_by_name['int64_list'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'] | |
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append( | |
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['float_list']) | |
+_HPARAMDEF_HPARAMTYPE.fields_by_name['float_list'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'] | |
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append( | |
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_list']) | |
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_list'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'] | |
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append( | |
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['bool_list']) | |
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bool_list'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'] | |
+_HPARAMDEF_HPARAMENTRY.fields_by_name['value'].message_type = _HPARAMDEF_HPARAMTYPE | |
+_HPARAMDEF_HPARAMENTRY.containing_type = _HPARAMDEF | |
+_HPARAMDEF.fields_by_name['hparam'].message_type = _HPARAMDEF_HPARAMENTRY | |
+DESCRIPTOR.message_types_by_name['HParamDef'] = _HPARAMDEF | |
+_sym_db.RegisterFileDescriptor(DESCRIPTOR) | |
+ | |
+HParamDef = _reflection.GeneratedProtocolMessageType('HParamDef', (_message.Message,), dict( | |
+ | |
+ BytesList = _reflection.GeneratedProtocolMessageType('BytesList', (_message.Message,), dict( | |
+ DESCRIPTOR = _HPARAMDEF_BYTESLIST, | |
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2' | |
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.BytesList) | |
+ )) | |
+ , | |
+ | |
+ FloatList = _reflection.GeneratedProtocolMessageType('FloatList', (_message.Message,), dict( | |
+ DESCRIPTOR = _HPARAMDEF_FLOATLIST, | |
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2' | |
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.FloatList) | |
+ )) | |
+ , | |
+ | |
+ Int64List = _reflection.GeneratedProtocolMessageType('Int64List', (_message.Message,), dict( | |
+ DESCRIPTOR = _HPARAMDEF_INT64LIST, | |
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2' | |
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.Int64List) | |
+ )) | |
+ , | |
+ | |
+ BoolList = _reflection.GeneratedProtocolMessageType('BoolList', (_message.Message,), dict( | |
+ DESCRIPTOR = _HPARAMDEF_BOOLLIST, | |
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2' | |
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.BoolList) | |
+ )) | |
+ , | |
+ | |
+ HParamType = _reflection.GeneratedProtocolMessageType('HParamType', (_message.Message,), dict( | |
+ DESCRIPTOR = _HPARAMDEF_HPARAMTYPE, | |
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2' | |
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.HParamType) | |
+ )) | |
+ , | |
+ | |
+ HparamEntry = _reflection.GeneratedProtocolMessageType('HparamEntry', (_message.Message,), dict( | |
+ DESCRIPTOR = _HPARAMDEF_HPARAMENTRY, | |
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2' | |
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.HparamEntry) | |
+ )) | |
+ , | |
+ DESCRIPTOR = _HPARAMDEF, | |
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2' | |
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef) | |
+ )) | |
+_sym_db.RegisterMessage(HParamDef) | |
+_sym_db.RegisterMessage(HParamDef.BytesList) | |
+_sym_db.RegisterMessage(HParamDef.FloatList) | |
+_sym_db.RegisterMessage(HParamDef.Int64List) | |
+_sym_db.RegisterMessage(HParamDef.BoolList) | |
+_sym_db.RegisterMessage(HParamDef.HParamType) | |
+_sym_db.RegisterMessage(HParamDef.HparamEntry) | |
+ | |
+ | |
+DESCRIPTOR._options = None | |
+_HPARAMDEF_FLOATLIST.fields_by_name['value']._options = None | |
+_HPARAMDEF_INT64LIST.fields_by_name['value']._options = None | |
+_HPARAMDEF_BOOLLIST.fields_by_name['value']._options = None | |
+_HPARAMDEF_HPARAMENTRY._options = None | |
+# @@protoc_insertion_point(module_scope) | |
+ | |
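hparam_pb2.py is the protoc output for the HParamDef message that HParams.to_proto() and from_proto() rely on. A round-trip sketch under the same assumptions as the previous example:

import hparam

hp = hparam.HParams(n_ctx=1024, res_dropout=0.5, dtype_name='float32')
proto = hp.to_proto()                        # a tensorflow.HParamDef message
restored = hparam.HParams.from_proto(proto)
assert restored.values() == hp.values()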
diff --git a/src/interactive_conditional_samples.py b/src/interactive_conditional_samples.py | |
index 26390e6..0bde72e 100755 | |
--- a/src/interactive_conditional_samples.py | |
+++ b/src/interactive_conditional_samples.py | |
@@ -14,6 +14,23 @@ import tflex | |
import model, sample, encoder | |
+from load_dataset import load_dataset, Sampler | |
+ | |
+def generate_samples(context, enc, sampler, data_sampler, session, batch_size, count=1): | |
+ print('Generating samples...') | |
+ context_tokens = data_sampler.sample(1) | |
+ index = 0 | |
+ while index < count: | |
+ out = session.run( | |
+ sampler, | |
+ feed_dict={context: batch_size * [context_tokens]}) | |
+ for i in range(len(out)): | |
+ text = enc.decode(out[i]) | |
+ text = '======== SAMPLE {} ========\n{}\n'.format( | |
+ index + 1, text) | |
+ print(text) | |
+ index += 1 | |
+ | |
def interact_model( | |
model_name='117M', | |
seed=None, | |
@@ -22,7 +39,9 @@ def interact_model( | |
length=None, | |
temperature=1, | |
top_k=0, | |
- top_p=0.0 | |
+ top_p=0.0, | |
+ dataset=None, | |
+ combine=50000 | |
): | |
""" | |
Interactively run the model | |
@@ -58,10 +77,10 @@ def interact_model( | |
elif length > hparams.n_ctx: | |
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) | |
- with tf.Session(graph=tf.Graph()) as sess: | |
- context = tf.placeholder(tf.int32, [batch_size, None]) | |
+ with tf.compat.v1.Session(graph=tf.Graph()) as sess: | |
+ context = tf.compat.v1.placeholder(tf.int32, [batch_size, None]) | |
np.random.seed(seed) | |
- tf.set_random_seed(seed) | |
+ tf.compat.v1.set_random_seed(seed) | |
output = sample.sample_sequence( | |
hparams=hparams, length=length, | |
context=context, | |
@@ -73,11 +92,21 @@ def interact_model( | |
ckpt = tflex.latest_checkpoint(os.path.join('models', model_name)) | |
saver.restore(sess, ckpt) | |
+ print('Loading dataset...') | |
+ chunks = load_dataset(enc, dataset, combine=combine) | |
+ data_sampler = Sampler(chunks, seed=seed) | |
+ print('dataset has', data_sampler.total_size, 'tokens', len(chunks), 'chunks') | |
+ | |
+ import pdb | |
+ pdb.set_trace() | |
+ | |
while True: | |
raw_text = input("Model prompt >>> ") | |
while not raw_text: | |
print('Prompt should not be empty!') | |
raw_text = input("Model prompt >>> ") | |
+ raw_text = raw_text.rstrip() | |
+ print(repr(raw_text)) | |
context_tokens = enc.encode(raw_text) | |
generated = 0 | |
for _ in range(nsamples // batch_size): | |
diff --git a/src/memory_saving_gradients.py b/src/memory_saving_gradients.py | |
index 659691f..f2cd20e 100644 | |
--- a/src/memory_saving_gradients.py | |
+++ b/src/memory_saving_gradients.py | |
@@ -2,7 +2,7 @@ from toposort import toposort | |
import contextlib | |
import numpy as np | |
import tensorflow as tf | |
-import tensorflow.contrib.graph_editor as ge | |
+import tflex_graph_editor as ge | |
import time | |
import sys | |
sys.setrecursionlimit(10000) | |
@@ -10,7 +10,7 @@ sys.setrecursionlimit(10000) | |
util = sys.modules[__name__] | |
# getting rid of "WARNING:tensorflow:VARIABLES collection name is deprecated" | |
-setattr(tf.GraphKeys, "VARIABLES", "variables") | |
+setattr(tf.compat.v1.GraphKeys, "VARIABLES", "variables") | |
# save original gradients since tf.gradient could be monkey-patched to point | |
# to our version | |
diff --git a/src/model.py b/src/model.py | |
index d0cde2a..dda5b58 100644 | |
--- a/src/model.py | |
+++ b/src/model.py | |
@@ -1,9 +1,14 @@ | |
import numpy as np | |
import tensorflow as tf | |
-from tensorflow.contrib.training import HParams | |
+import tensorflow_addons as tfa | |
+#from tensorflow.contrib.training import HParams | |
+import hparam | |
+import collections | |
+import tflex | |
+ | |
def default_hparams(): | |
- return HParams( | |
+ return hparam.HParams( | |
n_vocab=50257, | |
n_ctx=1024, | |
n_embd=768, | |
@@ -16,12 +21,22 @@ def default_hparams(): | |
import os | |
-def get_variable(name): | |
- name = os.path.join(tf.get_variable_scope().name, name) | |
- vs = tf.trainable_variables() | |
+def get_variable(name, unset=None): | |
+ fqn = tflex.join_variable_scope(name) | |
+ vs = tf.compat.v1.trainable_variables() | |
for x in vs: | |
- if x.name.startswith(name + ':'): | |
+ # should this next test be this instead? | |
+ # if x.name == fqn + ':0' | |
+ # | |
+ if x.name.startswith(fqn + ':'): | |
return x | |
+ if callable(unset): | |
+ return unset() | |
+ | |
+def init_variable(name, *args, **kws): | |
+ def unset(): | |
+ return tf.compat.v1.get_variable(name, *args, **kws) | |
+ return get_variable(name, unset=unset) | |
def shape_list(x): | |
"""Deal with dynamic shape in tensorflow cleanly.""" | |
@@ -29,24 +44,26 @@ def shape_list(x): | |
dynamic = tf.shape(x) | |
return [dynamic[i] if s is None else s for i, s in enumerate(static)] | |
-def softmax(x, axis=-1): | |
+def softmax(x, axis=-1, name=None): | |
x = x - tf.reduce_max(x, axis=axis, keepdims=True) | |
ex = tf.exp(x) | |
return ex / tf.reduce_sum(ex, axis=axis, keepdims=True) | |
+ #return tf.nn.softmax(x, axis=axis, name=name) | |
def gelu(x): | |
return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3)))) | |
+ #return tfa.activations.gelu(x) | |
def norm(x, scope, *, axis=-1, epsilon=1e-5, hparams=None): | |
"""Normalize to mean = 0, std = 1, then do a diagonal affine transform.""" | |
dtype = hparams.dtype if hparams else tf.float32 | |
- with tf.variable_scope(scope, dtype=dtype): | |
- n_state = x.shape[-1].value | |
- g = get_variable('g') or tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1, dtype=dtype)) | |
- b = get_variable('b') or tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0, dtype=dtype)) | |
+ with tflex.variable_scope(scope, dtype=dtype): | |
+ n_state = shape_list(x)[-1] | |
+ g = init_variable('g', [n_state], initializer=tf.compat.v1.constant_initializer(1, dtype=dtype)) | |
+ b = init_variable('b', [n_state], initializer=tf.compat.v1.constant_initializer(0, dtype=dtype)) | |
u = tf.reduce_mean(x, axis=axis, keepdims=True) | |
s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True) | |
- x = (x - u) * tf.rsqrt(s + epsilon) | |
+ x = (x - u) * tf.compat.v1.rsqrt(s + epsilon) | |
x = x*g + b | |
return x | |
@@ -62,10 +79,10 @@ def merge_states(x): | |
def conv1d(x, scope, nf, *, w_init_stdev=0.02, hparams=None): | |
dtype = hparams.dtype if hparams else tf.float32 | |
- with tf.variable_scope(scope, dtype=dtype): | |
+ with tflex.variable_scope(scope, dtype=dtype): | |
*start, nx = shape_list(x) | |
- w = get_variable('w') or tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev, dtype=dtype)) | |
- b = get_variable('b') or tf.get_variable('b', [nf], initializer=tf.constant_initializer(0, dtype=dtype)) | |
+ w = init_variable('w', [1, nx, nf], initializer=tf.compat.v1.random_normal_initializer(stddev=w_init_stdev, dtype=dtype)) | |
+ b = init_variable('b', [nf], initializer=tf.compat.v1.constant_initializer(0, dtype=dtype)) | |
c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf]) | |
return c | |
@@ -105,7 +122,7 @@ def attn(x, scope, n_state, *, past, hparams): | |
def multihead_attn(q, k, v): | |
# q, k, v have shape [batch, heads, sequence, features] | |
w = tf.matmul(q, k, transpose_b=True) | |
- w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype)) | |
+ w = w * tf.compat.v1.rsqrt(tf.cast(v.shape[-1], w.dtype)) | |
w = mask_attn_weights(w) | |
w = softmax(w) | |
@@ -114,7 +131,7 @@ def attn(x, scope, n_state, *, past, hparams): | |
return a | |
dtype = hparams.dtype if hparams else tf.float32 | |
- with tf.variable_scope(scope, dtype=dtype): | |
+ with tflex.variable_scope(scope, dtype=dtype): | |
c = conv1d(x, 'c_attn', n_state*3, hparams=hparams) | |
q, k, v = map(split_heads, tf.split(c, 3, axis=2)) | |
present = tf.stack([k, v], axis=1) | |
@@ -131,9 +148,10 @@ def attn(x, scope, n_state, *, past, hparams): | |
def mlp(x, scope, n_state, *, hparams): | |
dtype = hparams.dtype if hparams else tf.float32 | |
- with tf.variable_scope(scope, dtype=dtype): | |
- nx = x.shape[-1].value | |
- h = gelu(conv1d(x, 'c_fc', n_state, hparams=hparams)) | |
+ with tflex.variable_scope(scope, dtype=dtype): | |
+ nx = shape_list(x)[-1] | |
+ h0 = conv1d(x, 'c_fc', n_state, hparams=hparams) | |
+ h = gelu(h0) | |
h2 = conv1d(h, 'c_proj', nx, hparams=hparams) | |
h2 = dropout(h2, hparams.res_dropout) | |
return h2 | |
@@ -145,8 +163,8 @@ def dropout(x, pdrop=0.1, train=True): | |
def block(x, scope, *, past, hparams): | |
dtype = hparams.dtype if hparams else tf.float32 | |
- with tf.variable_scope(scope, dtype=dtype): | |
- nx = x.shape[-1].value | |
+ with tflex.variable_scope(scope, dtype=dtype): | |
+ nx = shape_list(x)[-1] | |
a, present = attn(norm(x, 'ln_1', hparams=hparams), 'attn', nx, past=past, hparams=hparams) | |
x = x + a | |
m = mlp(norm(x, 'ln_2', hparams=hparams), 'mlp', nx*4, hparams=hparams) | |
@@ -168,16 +186,16 @@ def positions_for(tokens, past_length): | |
return expand_tile(past_length + tf.range(nsteps), batch_size) | |
-def model(hparams, X, past=None, scope='model', reuse=tf.AUTO_REUSE): | |
+def model(hparams, X, past=None, scope='model', reuse=tf.compat.v1.AUTO_REUSE): | |
dtype = hparams.dtype if hparams else tf.float32 | |
- with tf.variable_scope(scope, reuse=reuse, dtype=dtype): | |
+ with tflex.variable_scope(scope, reuse=reuse, dtype=dtype): | |
results = {} | |
batch, sequence = shape_list(X) | |
- wpe = get_variable('wpe') or tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd], | |
- initializer=tf.random_normal_initializer(stddev=0.01, dtype=dtype)) | |
- wte = get_variable('wte') or tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd], | |
- initializer=tf.random_normal_initializer(stddev=0.02, dtype=dtype)) | |
+ wpe = init_variable('wpe', [hparams.n_ctx, hparams.n_embd], | |
+ initializer=tf.compat.v1.random_normal_initializer(stddev=0.01, dtype=dtype)) | |
+ wte = init_variable('wte', [hparams.n_vocab, hparams.n_embd], | |
+ initializer=tf.compat.v1.random_normal_initializer(stddev=0.02, dtype=dtype)) | |
past_length = 0 if past is None else tf.shape(past)[-2] | |
h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length)) | |
@@ -188,7 +206,7 @@ def model(hparams, X, past=None, scope='model', reuse=tf.AUTO_REUSE): | |
for layer, past in enumerate(pasts): | |
h, present = block(h, 'h%d' % layer, past=past, hparams=hparams) | |
if layer == 10: | |
- tf.add_to_collection('checkpoints', h) | |
+ tf.compat.v1.add_to_collection('checkpoints', h) | |
presents.append(present) | |
results['present'] = tf.stack(presents, axis=1) | |
h = norm(h, 'ln_f', hparams=hparams) | |
diff --git a/src/sample.py b/src/sample.py | |
index ff517be..fb4ad80 100644 | |
--- a/src/sample.py | |
+++ b/src/sample.py | |
@@ -44,7 +44,7 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte | |
context = tf.fill([batch_size, 1], start_token) | |
def step(hparams, tokens, past=None): | |
- lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE) | |
+ lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.compat.v1.AUTO_REUSE) | |
if hparams.dtype != tf.float32: | |
lm_output["logits"] = tf.cast(lm_output["logits"], tf.float32) | |
@@ -64,12 +64,12 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte | |
def body(past, prev, output): | |
next_outputs = step(hparams, prev[:, tf.newaxis], past=past) | |
- logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature) | |
+ logits = next_outputs['logits'][:, -1, :] / tf.compat.v1.to_float(temperature) | |
if top_p > 0.0: | |
logits = top_p_logits(logits, p=top_p, epsilon=epsilon) | |
else: | |
logits = top_k_logits(logits, k=top_k, epsilon=epsilon) | |
- samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32) | |
+ samples = tf.compat.v1.multinomial(logits, num_samples=1, output_dtype=tf.int32) | |
return [ | |
tf.concat([past, next_outputs['presents']], axis=-2), | |
tf.squeeze(samples, axis=[1]), | |
diff --git a/tflex.py b/tflex.py | |
index 7c76c44..4ec65b4 100644 | |
--- a/tflex.py | |
+++ b/tflex.py | |
@@ -8,6 +8,8 @@ import tqdm | |
import h5py | |
import shutil | |
import tempfile | |
+import sys | |
+import traceback | |
def split_by_params(vs, n=200e6, f=None): | |
if f is None: | |
@@ -25,7 +26,7 @@ def split_by_params(vs, n=200e6, f=None): | |
yield xs | |
def latest_checkpoint(checkpoint_dir, latest_filename=None): | |
- ctrs = np.array([[int(y) for y in re.findall(r'model-([0-9]+)(?:-[0-9]+)?[.](?:npy|hdf5)', x)] for x in glob(os.path.join(checkpoint_dir, 'model-*.*'))]).flatten() | |
+ ctrs = np.array([[int(y) for y in re.findall(r'model-([0-9]+)(?:-[0-9]+)?[.](?:npy|hdf5)', x)] for x in glob(os.path.join(checkpoint_dir, 'model-*.*')) if not x.endswith('.tmp')]).flatten() | |
if len(ctrs) <= 0: | |
ckpt = tf.train.latest_checkpoint(checkpoint_dir, latest_filename=latest_filename) | |
return ckpt | |
@@ -65,23 +66,55 @@ def assign_values(variables, values, session=None): | |
def load_snapshot(ckpt, session=None, var_list=None, reshape=False): | |
session = session or tf.get_default_session() | |
reader = pywrap_tensorflow.NewCheckpointReader(ckpt) | |
- vs = var_list or tf.trainable_variables() | |
+ vs = var_list or tf.compat.v1.trainable_variables() | |
for variables in tqdm.tqdm(list(split_by_params(vs))): | |
values = [value for variable, value in grab_values(variables, reader, reshape=reshape)] | |
assign_values(variables, values, session=session) | |
+def current_variable_scope(): | |
+ current = tf.compat.v1.get_variable_scope().name | |
+ return current | |
+ | |
+def fix_variable_scope(scope): | |
+ current = current_variable_scope() | |
+ print('TKTK', 'current', current, 'scope', scope) | |
+ if current == scope: | |
+ return current | |
+ if current.endswith('/' + scope): | |
+ return current | |
+ return scope | |
+ | |
+def join_variable_scope(name, base=None): | |
+ if base is None: | |
+ base = current_variable_scope() | |
+ if len(base) > 0 and name.startswith(base): | |
+ import pdb | |
+ pdb.set_trace() | |
+ if len(name) > 0 and base.startswith(name): | |
+ import pdb | |
+ pdb.set_trace() | |
+ return os.path.join(base, name) | |
+ | |
+def variable_scope(name=None, **kws): | |
+ if name is None: | |
+ name = current_variable_scope() | |
+ scope = fix_variable_scope(name) | |
+ #import pdb | |
+ #pdb.set_trace() | |
+ return tf.compat.v1.variable_scope(scope, **kws) | |
+ | |
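+# Note (not part of the original patch): fix_variable_scope() compares the | |
+# requested name against tf.compat.v1.get_variable_scope() and appears | |
+# intended to keep re-entrant model-building code from nesting a scope | |
+# inside itself (e.g. producing 'model/model'); the 'TKTK' print is debug | |
+# output left in by the author. | |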
def get_variable(name, var_list=None): | |
name, num = name.split(':') if ':' in name else (name, '0') | |
num = int(num) | |
- name = os.path.join(tf.get_variable_scope().name, name) | |
- vs = var_list or tf.trainable_variables() | |
+ name = join_variable_scope(name) | |
+ vs = var_list or tf.compat.v1.trainable_variables() | |
for x in vs: | |
if x.name.startswith(name + ':%d' % num): | |
return x | |
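+# Illustrative usage, not part of the original patch. get_variable() resolves | |
+# a bare name against the current variable scope and returns the matching | |
+# trainable variable, or None if absent: | |
+# | |
+#   with tflex.variable_scope('model'): | |
+#       wte = get_variable('wte')  # finds 'model/wte:0' if it already exists | |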
def load_weights(ckpt, session=None, var_list=None, reshape=False): | |
session = session or tf.get_default_session() | |
- vs = var_list or tf.trainable_variables() | |
+ vs = var_list or tf.compat.v1.trainable_variables() | |
files = list(sorted(glob(ckpt + '-*.npy'))) | |
for out in tqdm.tqdm(files): | |
for name, value in np.load(out, allow_pickle=True): | |
@@ -94,10 +127,11 @@ def load_weights(ckpt, session=None, var_list=None, reshape=False): | |
def load_variables(ckpt, session=None, var_list=None, reshape=False): | |
session = session or tf.get_default_session() | |
- vs = var_list or tf.trainable_variables() | |
+ vs = var_list or tf.compat.v1.trainable_variables() | |
with h5py.File(ckpt) as f: | |
for variables in tqdm.tqdm(list(split_by_params(vs))): | |
- values = [truncate_value(x, f[x.name], reshape=reshape) for x in variables] | |
+ variables = [x for x in variables if x.name in f] | |
+ values = [truncate_value(x, f[x.name], reshape=reshape) for x in variables] | |
assign_values(variables, values, session=session) | |
def maketree(path): | |
@@ -107,21 +140,31 @@ def maketree(path): | |
pass | |
def save_variables(ckpt, session=None, var_list=None): | |
- session = session or tf.get_default_session() | |
- vs = var_list or tf.trainable_variables() | |
- maketree(os.path.dirname(ckpt)) | |
- fname = ckpt+'.tmp' | |
- with h5py.File(fname, "w") as f: | |
- for variables in tqdm.tqdm(list(split_by_params(vs))): | |
- values = session.run(variables) | |
- for value, variable in zip(values, variables): | |
- name = variable.name | |
- shape = variable.shape.as_list() | |
- dtype = variable.dtype | |
- dset = f.create_dataset(name, shape, dtype=np.float32) | |
- dset[:] = value | |
- print('Writing snapshot %s' % ckpt) | |
- os.rename(ckpt+'.tmp', ckpt) | |
+ while True: | |
+ try: | |
+ session = session or tf.compat.v1.get_default_session() | |
+ vs = var_list or tf.compat.v1.trainable_variables() | |
+ maketree(os.path.dirname(ckpt)) | |
+ fname = ckpt+'.tmp' | |
+ with h5py.File(fname, "w") as f: | |
+ for variables in tqdm.tqdm(list(split_by_params(vs))): | |
+ values = session.run(variables) | |
+ for value, variable in zip(values, variables): | |
+ name = variable.name | |
+ shape = variable.shape.as_list() | |
+ dtype = variable.dtype | |
+ dset = f.create_dataset(name, shape, dtype=np.float32) | |
+ dset[:] = value | |
+ print('Writing snapshot %s' % ckpt) | |
+ os.rename(ckpt+'.tmp', ckpt) | |
+ break | |
+ except Exception: | |
+ traceback.print_exc(file=sys.stderr) | |
+ print('Exception while saving checkpoint. To try again, press "c" then enter.') | |
+ print('(Is your disk full?)') | |
+ print('Attaching debugger...') | |
+ import pdb | |
+ pdb.set_trace() | |
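+# Note (not part of the original patch): the retry loop above writes the | |
+# HDF5 snapshot to '<ckpt>.tmp' and only renames it into place on success, | |
+# so a failed write (e.g. a full disk) cannot clobber the previous | |
+# checkpoint; the pdb prompt lets the operator free space and continue. | |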
class Saver(object): | |
def __init__( | |
@@ -137,7 +180,7 @@ class Saver(object): | |
builder=None, | |
defer_build=False, | |
allow_empty=False, | |
- write_version=tf.train.SaverDef.V2, | |
+ write_version=tf.compat.v1.train.SaverDef.V2, | |
pad_step_number=False, | |
save_relative_paths=False, | |
filename=None): | |
@@ -159,6 +202,7 @@ class Saver(object): | |
self.checkpoints = [] | |
def restore(self, sess, save_path): | |
+ print('Restoring from {}...'.format(save_path)) | |
if save_path.endswith('.ckpt'): | |
load_snapshot(save_path, session=sess, var_list=self.var_list, reshape=self.reshape) | |
elif save_path.endswith('.hdf5'): | |
diff --git a/tflex_graph_editor/__init__.py b/tflex_graph_editor/__init__.py | |
new file mode 100644 | |
index 0000000..5434d4f | |
--- /dev/null | |
+++ b/tflex_graph_editor/__init__.py | |
@@ -0,0 +1,41 @@ | |
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |
+# | |
+# Licensed under the Apache License, Version 2.0 (the "License"); | |
+# you may not use this file except in compliance with the License. | |
+# You may obtain a copy of the License at | |
+# | |
+# http://www.apache.org/licenses/LICENSE-2.0 | |
+# | |
+# Unless required by applicable law or agreed to in writing, software | |
+# distributed under the License is distributed on an "AS IS" BASIS, | |
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+# See the License for the specific language governing permissions and | |
+# limitations under the License. | |
+# ============================================================================== | |
+"""TensorFlow Graph Editor.""" | |
+ | |
+from __future__ import absolute_import | |
+from __future__ import division | |
+from __future__ import print_function | |
+ | |
+# pylint: disable=wildcard-import | |
+from tflex_graph_editor.edit import * | |
+from tflex_graph_editor.reroute import * | |
+from tflex_graph_editor.select import * | |
+from tflex_graph_editor.subgraph import * | |
+from tflex_graph_editor.transform import * | |
+from tflex_graph_editor.util import * | |
+# pylint: enable=wildcard-import | |
+ | |
+# some useful aliases | |
+# pylint: disable=g-bad-import-order | |
+from tflex_graph_editor import subgraph as _subgraph | |
+from tflex_graph_editor import util as _util | |
+# pylint: enable=g-bad-import-order | |
+ph = _util.make_placeholder_from_dtype_and_shape | |
+sgv = _subgraph.make_view | |
+sgv_scope = _subgraph.make_view_from_scope | |
+ | |
+del absolute_import | |
+del division | |
+del print_function | |
diff --git a/tflex_graph_editor/edit.py b/tflex_graph_editor/edit.py | |
new file mode 100644 | |
index 0000000..411a86a | |
--- /dev/null | |
+++ b/tflex_graph_editor/edit.py | |
@@ -0,0 +1,221 @@ | |
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |
+# | |
+# Licensed under the Apache License, Version 2.0 (the "License"); | |
+# you may not use this file except in compliance with the License. | |
+# You may obtain a copy of the License at | |
+# | |
+# http://www.apache.org/licenses/LICENSE-2.0 | |
+# | |
+# Unless required by applicable law or agreed to in writing, software | |
+# distributed under the License is distributed on an "AS IS" BASIS, | |
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+# See the License for the specific language governing permissions and | |
+# limitations under the License. | |
+# ============================================================================== | |
+"""Various function for graph editing.""" | |
+ | |
+from __future__ import absolute_import | |
+from __future__ import division | |
+from __future__ import print_function | |
+ | |
+from tflex_graph_editor import reroute | |
+from tflex_graph_editor import select | |
+from tflex_graph_editor import subgraph | |
+from tflex_graph_editor import util | |
+from tensorflow.python.ops import array_ops as tf_array_ops | |
+ | |
+__all__ = [ | |
+ "detach_control_inputs", | |
+ "detach_control_outputs", | |
+ "detach_inputs", | |
+ "detach_outputs", | |
+ "detach", | |
+ "connect", | |
+ "bypass", | |
+] | |
+ | |
+ | |
+def detach_control_inputs(sgv): | |
+ """Detach all the external control inputs of the subgraph sgv. | |
+ | |
+ Args: | |
+ sgv: the subgraph view to be detached. This argument is converted to a | |
+ subgraph using the same rules as the function subgraph.make_view. | |
+ """ | |
+ sgv = subgraph.make_view(sgv) | |
+ for op in sgv.ops: | |
+ cops = [cop for cop in op.control_inputs if cop not in sgv.ops] | |
+ reroute.remove_control_inputs(op, cops) | |
+ | |
+ | |
+def detach_control_outputs(sgv, control_outputs): | |
+ """Detach all the external control outputs of the subgraph sgv. | |
+ | |
+ Args: | |
+ sgv: the subgraph view to be detached. This argument is converted to a | |
+ subgraph using the same rules as the function subgraph.make_view. | |
+ control_outputs: a util.ControlOutputs instance. | |
+ """ | |
+ if not isinstance(control_outputs, util.ControlOutputs): | |
+ raise TypeError("Expected a util.ControlOutputs, got: {}", | |
+ type(control_outputs)) | |
+ control_outputs.update() | |
+ sgv = subgraph.make_view(sgv) | |
+ for op in sgv.ops: | |
+ for cop in control_outputs.get(op): | |
+ if cop not in sgv.ops: | |
+ reroute.remove_control_inputs(cop, op) | |
+ | |
+ | |
+def detach_inputs(sgv, control_inputs=False): | |
+ """Detach the inputs of a subgraph view. | |
+ | |
+ Args: | |
+ sgv: the subgraph view to be detached. This argument is converted to a | |
+ subgraph using the same rules as the function subgraph.make_view. | |
+ Note that sgv is modified in place. | |
+ control_inputs: if True control_inputs are also detached. | |
+ Returns: | |
+ A tuple `(sgv, input_placeholders)` where | |
+ `sgv` is a new subgraph view of the detached subgraph; | |
+ `input_placeholders` is a list of the created input placeholders. | |
+ Raises: | |
+ StandardError: if sgv cannot be converted to a SubGraphView using | |
+ the same rules as the function subgraph.make_view. | |
+ """ | |
+ sgv = subgraph.make_view(sgv) | |
+ | |
+ with sgv.graph.as_default(): | |
+ input_placeholders = [ | |
+ tf_array_ops.placeholder( | |
+ dtype=input_t.dtype, name=util.placeholder_name(input_t)) | |
+ for input_t in sgv.inputs | |
+ ] | |
+ | |
+ reroute.swap_inputs(sgv, input_placeholders) | |
+ if control_inputs: | |
+ detach_control_inputs(sgv) | |
+ return sgv, input_placeholders | |
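+# Illustrative usage, not part of the original patch; the variable names are | |
+# hypothetical. Each external input tensor is replaced by a fresh placeholder | |
+# of the same dtype: | |
+# | |
+#   sgv, input_placeholders = detach_inputs(my_sgv) | |
+#   # input_placeholders[i] now feeds sgv's i-th (formerly external) input. | |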
+ | |
+ | |
+def detach_outputs(sgv, control_outputs=None): | |
+ """Detach the output of a subgraph view. | |
+ | |
+ Args: | |
+ sgv: the subgraph view to be detached. This argument is converted to a | |
+ subgraph using the same rules as the function subgraph.make_view. | |
+ Note that sgv is modified in place. | |
+ control_outputs: a util.ControlOutputs instance or None. If not None the | |
+ control outputs are also detached. | |
+ Returns: | |
+ A tuple `(sgv, output_placeholders)` where | |
+ `sgv` is a new subgraph view of the detached subgraph; | |
+ `output_placeholders` is a list of the created output placeholders. | |
+ Raises: | |
+ StandardError: if sgv cannot be converted to a SubGraphView using | |
+ the same rules as the function subgraph.make_view. | |
+ """ | |
+ sgv = subgraph.make_view(sgv) | |
+ # only select outputs with consumers | |
+ sgv_ = sgv.remap_outputs([output_id | |
+ for output_id, output_t in enumerate(sgv.outputs) | |
+ if output_t.consumers()]) | |
+ # create consumer subgraph and remap | |
+ consumers_sgv = subgraph.SubGraphView(sgv_.consumers()) | |
+ consumers_sgv = consumers_sgv.remap_inputs( | |
+ [input_id for input_id, input_t in enumerate(consumers_sgv.inputs) | |
+ if input_t in sgv_.outputs]) | |
+ | |
+ with sgv_.graph.as_default(): | |
+ output_placeholders = [ | |
+ util.make_placeholder_from_tensor(input_t) | |
+ for input_t in consumers_sgv.inputs | |
+ ] | |
+ | |
+ reroute.swap_outputs(sgv_, output_placeholders) | |
+ if control_outputs is not None: | |
+ detach_control_outputs(sgv_, control_outputs) | |
+ return sgv_, output_placeholders | |
+ | |
+ | |
+def detach(sgv, control_inputs=False, control_outputs=None, control_ios=None): | |
+ """Detach both the inputs and the outputs of a subgraph view. | |
+ | |
+ Args: | |
+ sgv: the subgraph view to be detached. This argument is converted to a | |
+ subgraph using the same rules as the function subgraph.make_view. | |
+ Note that sgv is modified in place. | |
+ control_inputs: A boolean indicating whether control inputs are enabled. | |
+ control_outputs: An instance of util.ControlOutputs or None. If not None, | |
+ control outputs are enabled. | |
+ control_ios: An instance of util.ControlOutputs or None. If not None, both | |
+ control inputs and control outputs are enabled. This is equivalent to setting | |
+ control_inputs to True and control_outputs to the util.ControlOutputs | |
+ instance. | |
+ Returns: | |
+ A tuple `(sgv, detached_inputs, detached_outputs)` where: | |
+ `sgv` is a new subgraph view of the detached subgraph; | |
+ `detached_inputs` is a list of the created input placeholders; | |
+ `detached_outputs` is a list of the created output placeholders. | |
+ Raises: | |
+ StandardError: if sgv cannot be converted to a SubGraphView using | |
+ the same rules as the function subgraph.make_view. | |
+ """ | |
+ control_inputs, control_outputs = select.check_cios(control_inputs, | |
+ control_outputs, | |
+ control_ios) | |
+ _, detached_inputs = detach_inputs(sgv, control_inputs) | |
+ _, detached_outputs = detach_outputs(sgv, control_outputs) | |
+ return sgv, detached_inputs, detached_outputs | |
+ | |
+ | |
+def connect(sgv0, sgv1, disconnect_first=False): | |
+ """Connect the outputs of sgv0 to the inputs of sgv1. | |
+ | |
+ Args: | |
+ sgv0: the first subgraph, whose outputs are connected. This argument is | |
+ converted to a subgraph using the same rules as the function | |
+ subgraph.make_view. | |
+ Note that sgv0 is modified in place. | |
+ sgv1: the second subgraph, whose inputs are connected. This argument is | |
+ converted to a subgraph using the same rules as the function | |
+ subgraph.make_view. | |
+ Note that sgv1 is modified in place. | |
+ disconnect_first: if True the current outputs of sgv0 are disconnected. | |
+ Returns: | |
+ A tuple `(sgv0, sgv1)` of the now connected subgraphs. | |
+ Raises: | |
+ StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using | |
+ the same rules as the function subgraph.make_view. | |
+ """ | |
+ sgv0 = subgraph.make_view(sgv0) | |
+ sgv1 = subgraph.make_view(sgv1) | |
+ util.check_graphs(sgv0, sgv1) | |
+ if disconnect_first: | |
+ detach_outputs(sgv0) | |
+ sgv0_outputs = subgraph.SubGraphView(passthrough_ts=sgv0.outputs) | |
+ reroute.reroute_inputs(sgv0_outputs, sgv1) | |
+ return sgv0, sgv1 | |
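+# Illustrative usage, not part of the original patch; encoder_ops and | |
+# decoder_ops are hypothetical lists of ops. Feeding one subgraph's outputs | |
+# into another: | |
+# | |
+#   import tflex_graph_editor as ge | |
+#   ge.connect(ge.sgv(encoder_ops), ge.sgv(decoder_ops)) | |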
+ | |
+ | |
+def bypass(sgv): | |
+ """Bypass the given subgraph by connecting its inputs to its outputs. | |
+ | |
+ Args: | |
+ sgv: the subgraph view to be bypassed. This argument is converted to a | |
+ subgraph using the same rules as the function subgraph.make_view. | |
+ Note that sgv is modified in place. | |
+ Returns: | |
+ A tuple `(sgv, detached_inputs)` where: | |
+ `sgv` is a new subgraph view of the bypassed subgraph; | |
+ `detached_inputs` is a list of the created input placeholders. | |
+ Raises: | |
+ StandardError: if sgv cannot be converted to a SubGraphView using | |
+ the same rules as the function subgraph.make_view. | |
+ """ | |
+ # TODO(fkp): allows to plug sgv.inputs to individual sgv.outputs consumers | |
+ sgv = subgraph.make_view(sgv) | |
+ sgv_inputs = list(sgv.inputs) | |
+ sgv, detached_inputs = detach_inputs(sgv) | |
+ reroute.reroute_ts(sgv_inputs, sgv.outputs) | |
+ return sgv, detached_inputs | |
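+# Illustrative usage, not part of the original patch; the scope name is | |
+# hypothetical. Bypassing makes a subgraph's former consumers read its | |
+# inputs directly, e.g. to splice out an identity block: | |
+# | |
+#   sgv, detached = bypass(ge.sgv_scope('debug_block', graph)) | |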
diff --git a/tflex_graph_editor/reroute.py b/tflex_graph_editor/reroute.py | |
new file mode 100644 | |
index 0000000..998991f | |
--- /dev/null | |
+++ b/tflex_graph_editor/reroute.py | |
@@ -0,0 +1,502 @@ | |
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |
+# | |
+# Licensed under the Apache License, Version 2.0 (the "License"); | |
+# you may not use this file except in compliance with the License. | |
+# You may obtain a copy of the License at | |
+# | |
+# http://www.apache.org/licenses/LICENSE-2.0 | |
+# | |
+# Unless required by applicable law or agreed to in writing, software | |
+# distributed under the License is distributed on an "AS IS" BASIS, | |
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+# See the License for the specific language governing permissions and | |
+# limitations under the License. | |
+# ============================================================================== | |
+"""Various function for graph rerouting.""" | |
+ | |
+from __future__ import absolute_import | |
+from __future__ import division | |
+from __future__ import print_function | |
+ | |
+from tflex_graph_editor import subgraph as _subgraph | |
+from tflex_graph_editor import util as _util | |
+from tensorflow.python.framework import ops as _tf_ops | |
+ | |
+from tensorflow.python.util.all_util import remove_undocumented | |
+ | |
+_allowed_symbols = [ | |
+ "swap_ts", | |
+ "reroute_ts", | |
+ "swap_inputs", | |
+ "reroute_inputs", | |
+ "swap_outputs", | |
+ "reroute_outputs", | |
+ "swap_ios", | |
+ "reroute_ios", | |
+ "remove_control_inputs", | |
+ "add_control_inputs", | |
+] | |
+ | |
+ | |
+def _check_ts_compatibility(ts0, ts1): | |
+ """Make sure the shape and dtype of the two tensor's lists are compatible. | |
+ | |
+ Args: | |
+ ts0: an object convertible to a list of `tf.Tensor`. | |
+ ts1: an object convertible to a list of `tf.Tensor`. | |
+ Raises: | |
+ ValueError: if any pair of tensors (same index in ts0 and ts1) have | |
+ a dtype or a shape which is not compatible. | |
+ """ | |
+ ts0 = _util.make_list_of_t(ts0) | |
+ ts1 = _util.make_list_of_t(ts1) | |
+ if len(ts0) != len(ts1): | |
+ raise ValueError("ts0 and ts1 have different sizes: {} != {}".format( | |
+ len(ts0), len(ts1))) | |
+ for t0, t1 in zip(ts0, ts1): | |
+ # check dtype | |
+ dtype0, dtype1 = t0.dtype, t1.dtype | |
+ if not dtype0.is_compatible_with(dtype1): | |
+ raise ValueError("Dtypes {} and {} are not compatible.".format(dtype0, | |
+ dtype1)) | |
+ # check shape | |
+ shape0, shape1 = t0.get_shape(), t1.get_shape() | |
+ if not shape0.is_compatible_with(shape1): | |
+ raise ValueError("Shapes {} and {} are not compatible.".format(shape0, | |
+ shape1)) | |
+ | |
+ | |
+class _RerouteMode(object): | |
+ """Enums for reroute's mode. | |
+ | |
+ swap: the ends of tensors a and b are swapped. | |
+ a2b: the end of tensor b is rerouted to the end of tensor a | |
+ (the end of b is left dangling). | |
+ b2a: the end of tensor a is rerouted to the end of tensor b | |
+ (the end of a is left dangling). | |
+ """ | |
+ swap, a2b, b2a = range(3) | |
+ | |
+ @classmethod | |
+ def check(cls, mode): | |
+ """Check swap mode. | |
+ | |
+ Args: | |
+ mode: an integer representing one of the modes. | |
+ Returns: | |
+ A tuple `(a2b, b2a)` of booleans indicating what rerouting needs doing. | |
+ Raises: | |
+ ValueError: if mode is outside the enum range. | |
+ """ | |
+ if mode == cls.swap: | |
+ return True, True | |
+ elif mode == cls.b2a: | |
+ return False, True | |
+ elif mode == cls.a2b: | |
+ return True, False | |
+ else: | |
+ raise ValueError("Unknown _RerouteMode: {}".format(mode)) | |
+ | |
+ | |
+def _reroute_t(t0, t1, consumers1, can_modify=None, cannot_modify=None): | |
+ """Reroute the end of the tensors (t0,t1). | |
+ | |
+ Warning: this function is directly manipulating the internals of the | |
+ `tf.Graph`. | |
+ | |
+ Args: | |
+ t0: a tf.Tensor. | |
+ t1: a tf.Tensor. | |
+ consumers1: The consumers of t1 which need to be rerouted. | |
+ can_modify: iterable of operations which can be modified. Any operation | |
+ outside can_modify will be left untouched by this function. | |
+ cannot_modify: iterable of operations which cannot be modified. | |
+ Any operation within cannot_modify will be left untouched by this | |
+ function. | |
+ Returns: | |
+ The number of individual modifications made by the function. | |
+ """ | |
+ nb_update_inputs = 0 | |
+ if can_modify is not None: | |
+ consumers1 &= can_modify | |
+ if cannot_modify is not None: | |
+ consumers1 -= cannot_modify | |
+ consumers1_indices = {} | |
+ for consumer1 in consumers1: | |
+ consumers1_indices[consumer1] = [i for i, t in enumerate(consumer1.inputs) | |
+ if t is t1] | |
+ for consumer1 in consumers1: | |
+ for i in consumers1_indices[consumer1]: | |
+ consumer1._update_input(i, t0) # pylint: disable=protected-access | |
+ nb_update_inputs += 1 | |
+ return nb_update_inputs | |
+ | |
+ | |
+def _reroute_ts(ts0, ts1, mode, can_modify=None, cannot_modify=None): | |
+ """Reroute the end of the tensors in each pair (t0,t1) in ts0 x ts1. | |
+ | |
+ This function is the backbone of the Graph-Editor. It is essentially a thin | |
+ wrapper on top of tf.Operation._update_input. | |
+ | |
+ Given a pair of tensors t0, t1 in ts0 x ts1, this function re-routes the end | |
+ of t0 and t1 in three possible ways: | |
+ 1) The reroute mode is "a<->b" or "b<->a": the ends of the tensors are | |
+ swapped. After this operation, the previous consumers of t0 are now consumers | |
+ of t1 and vice-versa. | |
+ 2) The reroute mode is "a->b": the end of the tensor t1 is re-routed to the | |
+ end of the tensor t0 (t1 is left dangling). After this operation, the | |
+ previous consumers of t0 are still consuming t0, and the previous consumers | |
+ of t1 are now also consuming t0. The tensor t1 has no consumer. | |
+ 3) The reroute mode is "b->a": this mode is the symmetric of the "a->b" mode. | |
+ | |
+ Note that this function is re-routing the end of two tensors, not the start. | |
+ Re-routing the start of two tensors is not supported by this library. The | |
+ reason for that is the following: TensorFlow, by design, creates a strong bond | |
+ between an op and its output tensor. This Graph editor follows this design and | |
+ treats an operation A and its generating tensors {t_i} as an entity which | |
+ cannot be broken. In other words, an op cannot be detached from any of its | |
+ output tensors, ever. But it is possible to detach an op from its input | |
+ tensors, which is what this function concerns itself with. | |
+ | |
+ Warning: this function is directly manipulating the internals of the tf.Graph. | |
+ | |
+ Args: | |
+ ts0: an object convertible to a list of `tf.Tensor`. | |
+ ts1: an object convertible to a list of `tf.Tensor`. | |
+ mode: what to do with those tensors: "a->b" or "b<->a" for swaping and | |
+ "a->b" or "b->a" for one direction re-routing. | |
+ can_modify: iterable of operations which can be modified. Any operation | |
+ outside can_modify will be left untouched by this function. | |
+ cannot_modify: iterable of operations which cannot be modified. | |
+ Any operation within cannot_modify will be left untouched by this | |
+ function. | |
+ Returns: | |
+ The number of individual modifications made by the function. | |
+ Raises: | |
+ TypeError: if `ts0` or `ts1` cannot be converted to a list of `tf.Tensor`. | |
+ TypeError: if `can_modify` or `cannot_modify` is not `None` and cannot be | |
+ converted to a list of `tf.Operation`. | |
+ """ | |
+ a2b, b2a = _RerouteMode.check(mode) | |
+ ts0 = _util.make_list_of_t(ts0) | |
+ ts1 = _util.make_list_of_t(ts1) | |
+ _check_ts_compatibility(ts0, ts1) | |
+ if cannot_modify is not None: | |
+ cannot_modify = frozenset(_util.make_list_of_op(cannot_modify)) | |
+ if can_modify is not None: | |
+ can_modify = frozenset(_util.make_list_of_op(can_modify)) | |
+ nb_update_inputs = 0 | |
+ precomputed_consumers = [] | |
+ # precompute consumers to avoid issue with repeated tensors: | |
+ for t0, t1 in zip(ts0, ts1): | |
+ consumers0 = set(t0.consumers()) | |
+ consumers1 = set(t1.consumers()) | |
+ precomputed_consumers.append((consumers0, consumers1)) | |
+ for t0, t1, consumers in zip(ts0, ts1, precomputed_consumers): | |
+ if t0 is t1: | |
+ continue # Silently ignore identical tensors. | |
+ consumers0, consumers1 = consumers | |
+ if a2b: | |
+ nb_update_inputs += _reroute_t(t0, t1, consumers1, can_modify, | |
+ cannot_modify) | |
+ if b2a: | |
+ nb_update_inputs += _reroute_t(t1, t0, consumers0, can_modify, | |
+ cannot_modify) | |
+ return nb_update_inputs | |
+ | |
+ | |
+def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None): | |
+ """For each tensor's pair, swap the end of (t0,t1). | |
+ | |
+ B0 B1 B0 B1 | |
+ | | => X | |
+ A0 A1 A0 A1 | |
+ | |
+ Args: | |
+ ts0: an object convertible to a list of `tf.Tensor`. | |
+ ts1: an object convertible to a list of `tf.Tensor`. | |
+ can_modify: iterable of operations which can be modified. Any operation | |
+ outside can_modify will be left untouched by this function. | |
+ cannot_modify: iterable of operations which cannot be modified. | |
+ Any operation within cannot_modify will be left untouched by this | |
+ function. | |
+ Returns: | |
+ The number of individual modifications made by the function. | |
+ Raises: | |
+ TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor. | |
+ TypeError: if can_modify or cannot_modify is not None and cannot be | |
+ converted to a list of tf.Operation. | |
+ """ | |
+ return _reroute_ts(ts0, ts1, _RerouteMode.swap, can_modify, cannot_modify) | |
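+# Illustrative usage, not part of the original patch; a0 and a1 are | |
+# hypothetical tensors. After swapping, each consumer of a0 reads a1 and | |
+# vice-versa: | |
+# | |
+#   b0 = tf.identity(a0) | |
+#   b1 = tf.identity(a1) | |
+#   swap_ts([a0], [a1])  # now b0 reads a1 and b1 reads a0 | |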
+ | |
+ | |
+def reroute_ts(ts0, ts1, can_modify=None, cannot_modify=None): | |
+ """For each tensor's pair, replace the end of t1 by the end of t0. | |
+ | |
+ B0 B1 B0 B1 | |
+ | | => |/ | |
+ A0 A1 A0 A1 | |
+ | |
+ The ends of the tensors in ts1 are left dangling. | |
+ | |
+ Args: | |
+ ts0: an object convertible to a list of `tf.Tensor`. | |
+ ts1: an object convertible to a list of `tf.Tensor`. | |
+ can_modify: iterable of operations which can be modified. Any operation | |
+ outside can_modify will be left untouched by this function. | |
+ cannot_modify: iterable of operations which cannot be modified. Any | |
+ operation within cannot_modify will be left untouched by this function. | |
+ Returns: | |
+ The number of individual modifications made by the function. | |
+ Raises: | |
+ TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor. | |
+ TypeError: if can_modify or cannot_modify is not None and cannot be | |
+ converted to a list of tf.Operation. | |
+ """ | |
+ return _reroute_ts(ts0, ts1, _RerouteMode.a2b, can_modify, cannot_modify) | |
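+# Note (not part of the original patch): unlike swap_ts, this is | |
+# one-directional. After reroute_ts([a0], [a1]) the former consumers of a1 | |
+# read a0, the consumers of a0 are untouched, and a1 is left dangling. | |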
+ | |
+ | |
+def _reroute_sgv_remap(sgv0, sgv1, mode): | |
+ """Remap in place the inputs of two subgraph views to mimic the reroute. | |
+ | |
+ This function is meant to be used by reroute_inputs only. | |
+ | |
+ Args: | |
+ sgv0: the first subgraph to have its inputs remapped. | |
+ sgv1: the second subgraph to have its inputs remapped. | |
+ mode: reroute mode, see _reroute_ts(...). | |
+ Raises: | |
+ TypeError: if sgv0 or sgv1 is not a SubGraphView. | |
+ ValueError: if sgv0 and sgv1 do not belong to the same graph. | |
+ """ | |
+ a2b, b2a = _RerouteMode.check(mode) | |
+ if not isinstance(sgv0, _subgraph.SubGraphView): | |
+ raise TypeError("Expected a SubGraphView, got {}".format(type(sgv0))) | |
+ if not isinstance(sgv1, _subgraph.SubGraphView): | |
+ raise TypeError("Expected a SubGraphView, got {}".format(type(sgv1))) | |
+ _util.check_graphs(sgv0, sgv1) | |
+ sgv0_ = sgv0.copy() | |
+ sgv1_ = sgv1.copy() | |
+ # pylint: disable=protected-access | |
+ if a2b and b2a: | |
+ (sgv0_._input_ts, sgv1_._input_ts) = (sgv1_._input_ts, sgv0_._input_ts) | |
+ (sgv0_._passthrough_ts, sgv1_._passthrough_ts) = (sgv1_._passthrough_ts, | |
+ sgv0_._passthrough_ts) | |
+ elif a2b: | |
+ sgv1_._input_ts = sgv0_._input_ts[:] | |
+ sgv1_._passthrough_ts = sgv0_._passthrough_ts[:] | |
+ elif b2a: | |
+ sgv0_._input_ts = sgv1_._input_ts[:] | |
+ sgv0_._passthrough_ts = sgv1_._passthrough_ts[:] | |
+ # pylint: enable=protected-access | |
+ | |
+ # Update the passthrough outputs as well. | |
+ def update_passthrough_outputs(a, b): | |
+ # pylint: disable=protected-access | |
+ for i, t in enumerate(b._output_ts): | |
+ if t in a._passthrough_ts: | |
+ ii = a._input_ts.index(t) | |
+ b._output_ts[i] = b._input_ts[ii] | |
+ # pylint: enable=protected-access | |
+ | |
+ if a2b: | |
+ update_passthrough_outputs(sgv0_, sgv1_) | |
+ if b2a: | |
+ update_passthrough_outputs(sgv1_, sgv0_) | |
+ | |
+ # in-place | |
+ # pylint: disable=protected-access | |
+ sgv0._assign_from(sgv0_) | |
+ sgv1._assign_from(sgv1_) | |
+ # pylint: enable=protected-access | |
+ | |
+ | |
+def _reroute_sgv_inputs(sgv0, sgv1, mode): | |
+ """Re-route all the inputs of two subgraphs. | |
+ | |
+ Args: | |
+ sgv0: the first subgraph to have its inputs swapped. This argument is | |
+ converted to a subgraph using the same rules as the function | |
+ subgraph.make_view. | |
+ sgv1: the second subgraph to have its inputs swapped. This argument is | |
+ converted to a subgraph using the same rules as the function | |
+ subgraph.make_view. | |
+ mode: reroute mode, see _reroute_ts(...). | |
+ Returns: | |
+ A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped. | |
+ Note that the function argument sgv0 and sgv1 are also modified in place. | |
+ Raises: | |
+ StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using | |
+ the same rules as the function subgraph.make_view. | |
+ """ | |
+ sgv0 = _subgraph.make_view(sgv0) | |
+ sgv1 = _subgraph.make_view(sgv1) | |
+ _util.check_graphs(sgv0, sgv1) | |
+ can_modify = sgv0.ops + sgv1.ops | |
+ # also allow consumers of passthrough to be modified: | |
+ can_modify += _util.get_consuming_ops(sgv0.passthroughs) | |
+ can_modify += _util.get_consuming_ops(sgv1.passthroughs) | |
+ _reroute_ts(sgv0.inputs, sgv1.inputs, mode, can_modify=can_modify) | |
+ _reroute_sgv_remap(sgv0, sgv1, mode) | |
+ return sgv0, sgv1 | |
+ | |
+ | |
+def _reroute_sgv_outputs(sgv0, sgv1, mode): | |
+ """Re-route all the outputs of two operations. | |
+ | |
+ Args: | |
+ sgv0: the first subgraph to have its outputs swapped. This argument is | |
+ converted to a subgraph using the same rules as the function | |
+ subgraph.make_view. | |
+ sgv1: the second subgraph to have its outputs swapped. This argument is | |
+ converted to a subgraph using the same rules as the function | |
+ subgraph.make_view. | |
+ mode: reroute mode, see _reroute_ts(...). | |
+ Returns: | |
+ A tuple `(sgv0, sgv1)` of subgraph views with their outputs swapped. | |
+ Note that the function argument sgv0 and sgv1 are also modified in place. | |
+ Raises: | |
+ StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using | |
+ the same rules as the function subgraph.make_view. | |
+ """ | |
+ sgv0 = _subgraph.make_view(sgv0) | |
+ sgv1 = _subgraph.make_view(sgv1) | |
+ _util.check_graphs(sgv0, sgv1) | |
+ cannot_modify = sgv0.ops + sgv1.ops | |
+ _reroute_ts(sgv0.outputs, sgv1.outputs, mode, cannot_modify=cannot_modify) | |
+ return sgv0, sgv1 | |
+ | |
+ | |
+def _reroute_sgv(sgv0, sgv1, mode): | |
+ """Re-route both the inputs and the outputs of the two subgraph views. | |
+ | |
+ This involves swapping all the inputs/outputs of the two subgraph views. | |
+ | |
+ Args: | |
+ sgv0: the first subgraph to be swapped. This argument is converted to a | |
+ subgraph using the same rules as the function subgraph.make_view. | |
+ sgv1: the second subgraph to be swapped. This argument is converted to a | |
+ subgraph using the same rules as the function subgraph.make_view. | |
+ mode: reroute mode, see _reroute_ts(...). | |
+ Returns: | |
+ A tuple `(sgv0, sgv1)` of subgraph views with their outputs and inputs | |
+ swapped. | |
+ Note that the function argument sgv0 and sgv1 are also modified in place. | |
+ Raises: | |
+ StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using | |
+ the same rules as the function subgraph.make_view. | |
+ """ | |
+ _reroute_sgv_outputs(sgv0, sgv1, mode) | |
+ _reroute_sgv_inputs(sgv0, sgv1, mode) | |
+ return sgv0, sgv1 | |
+ | |
+ | |
+def swap_inputs(sgv0, sgv1): | |
+ """Swap all the inputs of sgv0 and sgv1 (see reroute_inputs).""" | |
+ return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.swap) | |
+ | |
+ | |
+def reroute_inputs(sgv0, sgv1): | |
+ """Re-route all the inputs of two subgraphs. | |
+ | |
+ Args: | |
+ sgv0: the first subgraph to have its inputs swapped. This argument is | |
+ converted to a subgraph using the same rules as the function | |
+ subgraph.make_view. | |
+ sgv1: the second subgraph to have its inputs swapped. This argument is | |
+ converted to a subgraph using the same rules as the function | |
+ subgraph.make_view. | |
+ Returns: | |
+ A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped. | |
+ Note that the function argument sgv0 and sgv1 are also modified in place. | |
+ Raises: | |
+ StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using | |
+ the same rules as the function subgraph.make_view. | |
+ """ | |
+ return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.a2b) | |
+ | |
+ | |
+def swap_outputs(sgv0, sgv1): | |
+ """Swap all the outputs of sgv0 and sgv1 (see reroute_outputs).""" | |
+ return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.swap) | |
+ | |
+ | |
+def reroute_outputs(sgv0, sgv1): | |
+ """Re-route all the outputs of two operations. | |
+ | |
+ Args: | |
+ sgv0: the first subgraph to have its outputs swapped. This argument is | |
+ converted to a subgraph using the same rules as the function | |
+ subgraph.make_view. | |
+ sgv1: the second subgraph to have its outputs swapped. This argument is | |
+ converted to a subgraph using the same rules as the function | |
+ subgraph.make_view. | |
+ Returns: | |
+ A tuple `(sgv0, sgv1)` of subgraph views with their outputs swapped. | |
+ Note that the function argument sgv0 and sgv1 are also modified in place. | |
+ Raises: | |
+ StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using | |
+ the same rules as the function subgraph.make_view. | |
+ """ | |
+ return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.a2b) | |
+ | |
+ | |
+def swap_ios(sgv0, sgv1): | |
+ """Swap the inputs and outputs of sgv1 to sgv0 (see _reroute_sgv).""" | |
+ return _reroute_sgv(sgv0, sgv1, _RerouteMode.swap) | |
+ | |
+ | |
+def reroute_ios(sgv0, sgv1): | |
+ """Re-route the inputs and outputs of sgv0 to sgv1 (see _reroute_sgv).""" | |
+ return _reroute_sgv(sgv0, sgv1, _RerouteMode.a2b) | |
+ | |
+ | |
+def remove_control_inputs(op, cops): | |
+ """Remove the control inputs cops from co. | |
+ | |
+ Warning: this function is directly manipulating the internals of the | |
+ `tf.Graph`. | |
+ | |
+ Args: | |
+ op: a `tf.Operation` from which to remove the control inputs. | |
+ cops: an object convertible to a list of `tf.Operation`. | |
+ Raises: | |
+ TypeError: if op is not a `tf.Operation`. | |
+ ValueError: if any cop in cops is not a control input of op. | |
+ """ | |
+ if not isinstance(op, _tf_ops.Operation): | |
+ raise TypeError("Expected a tf.Operation, got: {}", type(op)) | |
+ cops = _util.make_list_of_op(cops, allow_graph=False) | |
+ for cop in cops: | |
+ if cop not in op.control_inputs: | |
+ raise ValueError("{} is not a control_input of {}".format(op.name, | |
+ cop.name)) | |
+ control_inputs = [cop for cop in op.control_inputs if cop not in cops] | |
+ # pylint: disable=protected-access | |
+ op._remove_all_control_inputs() | |
+ op._add_control_inputs(control_inputs) | |
+ # pylint: enable=protected-access | |
+ | |
+ | |
+def add_control_inputs(op, cops): | |
+ """Add the control inputs cops to op. | |
+ | |
+ Warning: this function is directly manipulating the internals of the tf.Graph. | |
+ | |
+ Args: | |
+ op: a tf.Operation to which the control inputs are added. | |
+ cops: an object convertible to a list of `tf.Operation`. | |
+ Raises: | |
+ TypeError: if op is not a tf.Operation | |
+ ValueError: if any cop in cops is already a control input of op. | |
+ """ | |
+ if not isinstance(op, _tf_ops.Operation): | |
+ raise TypeError("Expected a tf.Operation, got: {}", type(op)) | |
+ cops = _util.make_list_of_op(cops, allow_graph=False) | |
+ for cop in cops: | |
+ if cop in op.control_inputs: | |
+ raise ValueError("{} is already a control_input of {}".format(cop.name, | |
+ op.name)) | |
+ op._add_control_inputs(cops) # pylint: disable=protected-access | |
+ | |
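+# Illustrative usage, not part of the original patch; train_op and update_op | |
+# are hypothetical. Control edges order execution without a data dependency: | |
+# | |
+#   add_control_inputs(train_op, [update_op])     # update_op must run first | |
+#   remove_control_inputs(train_op, [update_op])  # undo the ordering | |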
+remove_undocumented(__name__, _allowed_symbols) | |
diff --git a/tflex_graph_editor/select.py b/tflex_graph_editor/select.py | |
new file mode 100644 | |
index 0000000..30d83de | |
--- /dev/null | |
+++ b/tflex_graph_editor/select.py | |
@@ -0,0 +1,773 @@ | |
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |
+# | |
+# Licensed under the Apache License, Version 2.0 (the "License"); | |
+# you may not use this file except in compliance with the License. | |
+# You may obtain a copy of the License at | |
+# | |
+# http://www.apache.org/licenses/LICENSE-2.0 | |
+# | |
+# Unless required by applicable law or agreed to in writing, software | |
+# distributed under the License is distributed on an "AS IS" BASIS, | |
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+# See the License for the specific language governing permissions and | |
+# limitations under the License. | |
+# ============================================================================== | |
+"""Various ways of selecting operations and tensors in a graph.""" | |
+ | |
+from __future__ import absolute_import | |
+from __future__ import division | |
+from __future__ import print_function | |
+ | |
+import re | |
+ | |
+from six import iteritems | |
+from six import string_types | |
+ | |
+from tflex_graph_editor import util | |
+from tensorflow.python.ops import op_selector | |
+from tensorflow.python.framework import ops as tf_ops | |
+from tensorflow.python.util import deprecation | |
+ | |
+ | |
+ | |
+__all__ = [ | |
+ "can_be_regex", | |
+ "make_regex", | |
+ "filter_ts", | |
+ "filter_ts_from_regex", | |
+ "filter_ops", | |
+ "filter_ops_from_regex", | |
+ "get_name_scope_ops", | |
+ "check_cios", | |
+ "get_ops_ios", | |
+ "compute_boundary_ts", | |
+ "get_within_boundary_ops", | |
+ "get_forward_walk_ops", | |
+ "get_backward_walk_ops", | |
+ "get_walks_intersection_ops", | |
+ "get_walks_union_ops", | |
+ "select_ops", | |
+ "select_ts", | |
+ "select_ops_and_ts", | |
+] | |
+ | |
+_RE_TYPE = type(re.compile("")) | |
+ | |
+ | |
+def can_be_regex(obj): | |
+ """Return True if obj can be turned into a regular expression.""" | |
+ return isinstance(obj, string_types + (_RE_TYPE,)) | |
+ | |
+ | |
+def make_regex(obj): | |
+ """Return a compiled regular expression. | |
+ | |
+ Args: | |
+ obj: a string or a regular expression. | |
+ Returns: | |
+ A compiled regular expression. | |
+ Raises: | |
+ ValueError: if obj could not be converted to a regular expression. | |
+ """ | |
+ if not can_be_regex(obj): | |
+ raise ValueError("Expected a string or a regex, got: {}".format(type(obj))) | |
+ | |
+ if isinstance(obj, string_types): | |
+ return re.compile(obj) | |
+ else: | |
+ return obj | |
+ | |
+ | |
+def _get_input_ts(ops): | |
+ """Compute the list of unique input tensors of all the op in ops. | |
+ | |
+ Args: | |
+ ops: an object convertible to a list of `tf.Operation`. | |
+ Returns: | |
+ The list of unique input tensors of all the ops in ops. | |
+ Raises: | |
+ TypeError: if ops cannot be converted to a list of `tf.Operation`. | |
+ """ | |
+ ops = util.make_list_of_op(ops) | |
+ ts = [] | |
+ ts_set = set() | |
+ for op in ops: | |
+ for t in op.inputs: | |
+ if t not in ts_set: | |
+ ts.append(t) | |
+ ts_set.add(t) | |
+ return ts | |
+ | |
+ | |
+def _get_output_ts(ops): | |
+ """Compute the list of unique output tensors of all the op in ops. | |
+ | |
+ Args: | |
+ ops: an object convertible to a list of tf.Operation. | |
+ Returns: | |
+ The list of unique output tensors of all the ops in ops. | |
+ Raises: | |
+ TypeError: if ops cannot be converted to a list of tf.Operation. | |
+ """ | |
+ ops = util.make_list_of_op(ops) | |
+ ts = [] | |
+ for op in ops: | |
+ ts += op.outputs | |
+ return ts | |
+ | |
+ | |
+def filter_ts(ops, positive_filter): | |
+ """Get all the tensors which are input or output of an op in ops. | |
+ | |
+ Args: | |
+ ops: an object convertible to a list of `tf.Operation`. | |
+ positive_filter: a function deciding whether to keep a tensor or not. | |
+ If `True`, all the tensors are returned. | |
+ Returns: | |
+ A list of `tf.Tensor`. | |
+ Raises: | |
+ TypeError: if ops cannot be converted to a list of `tf.Operation`. | |
+ """ | |
+ ops = util.make_list_of_op(ops) | |
+ ts = _get_input_ts(ops) | |
+ util.concatenate_unique(ts, _get_output_ts(ops)) | |
+ if positive_filter is not True: | |
+ ts = [t for t in ts if positive_filter(t)] | |
+ return ts | |
+ | |
+ | |
+def filter_ts_from_regex(ops, regex): | |
+ r"""Get all the tensors linked to ops that match the given regex. | |
+ | |
+ Args: | |
+ ops: an object convertible to a list of tf.Operation. | |
+ regex: a regular expression matching the tensors' name. | |
+ For example, "^foo(/.*)?:\d+$" will match all the tensors in the "foo" | |
+ scope. | |
+ Returns: | |
+ A list of tf.Tensor. | |
+ Raises: | |
+ TypeError: if ops cannot be converted to a list of tf.Operation. | |
+ """ | |
+ ops = util.make_list_of_op(ops) | |
+ regex_obj = make_regex(regex) | |
+ return filter_ts(ops, positive_filter=lambda op: regex_obj.search(op.name)) | |
+ | |
+ | |
+def filter_ops(ops, positive_filter): | |
+ """Get the ops passing the given filter. | |
+ | |
+ Args: | |
+ ops: an object convertible to a list of tf.Operation. | |
+ positive_filter: a function deciding whether to keep an operation or not. | |
+ If True, all the operations are returned. | |
+ Returns: | |
+ A list of selected tf.Operation. | |
+ Raises: | |
+ TypeError: if ops cannot be converted to a list of tf.Operation. | |
+ """ | |
+ ops = util.make_list_of_op(ops) | |
+ if positive_filter is not True: # pylint: disable=g-explicit-bool-comparison | |
+ ops = [op for op in ops if positive_filter(op)] | |
+ return ops | |
+ | |
+ | |
+def filter_ops_from_regex(ops, regex): | |
+ """Get all the operations that match the given regex. | |
+ | |
+ Args: | |
+ ops: an object convertible to a list of `tf.Operation`. | |
+ regex: a regular expression matching the operation's name. | |
+ For example, `"^foo(/.*)?$"` will match all the operations in the "foo" | |
+ scope. | |
+ Returns: | |
+ A list of `tf.Operation`. | |
+ Raises: | |
+ TypeError: if ops cannot be converted to a list of `tf.Operation`. | |
+ """ | |
+ ops = util.make_list_of_op(ops) | |
+ regex_obj = make_regex(regex) | |
+ return filter_ops(ops, lambda op: regex_obj.search(op.name)) | |
+ | |
+ | |
+def get_name_scope_ops(ops, scope): | |
+ """Get all the operations under the given scope path. | |
+ | |
+ Args: | |
+ ops: an object convertible to a list of tf.Operation. | |
+ scope: a scope path. | |
+ Returns: | |
+ A list of tf.Operation. | |
+ Raises: | |
+ TypeError: if ops cannot be converted to a list of tf.Operation. | |
+ """ | |
+ if scope and scope[-1] == "/": | |
+ scope = scope[:-1] | |
+ return filter_ops_from_regex(ops, "^{}(/.*)?$".format(scope)) | |
+ | |
+ | |
+def check_cios(control_inputs=False, control_outputs=None, control_ios=None): | |
+ """Do various check on control_inputs and control_outputs. | |
+ | |
+ Args: | |
+ control_inputs: A boolean indicating whether control inputs are enabled. | |
+ control_outputs: An instance of util.ControlOutputs or None. If not None, | |
+ control outputs are enabled. | |
+ control_ios: An instance of util.ControlOutputs or None. If not None, both | |
+ control inputs and control outputs are enabled. This is equivalent to setting | |
+ control_inputs to True and control_outputs to the util.ControlOutputs | |
+ instance. | |
+ Returns: | |
+ A tuple `(control_inputs, control_outputs)` where: | |
+ `control_inputs` is a boolean indicating whether to use control inputs. | |
+ `control_outputs` is an instance of util.ControlOutputs or None | |
+ Raises: | |
+ ValueError: if control_inputs is an instance of util.ControlOutputs but | |
+ control_outputs is not None | |
+ TypeError: if control_outputs is not None and is not a util.ControlOutputs. | |
+ """ | |
+ if control_ios is not None: | |
+ if not isinstance(control_ios, util.ControlOutputs): | |
+ raise TypeError("Expected a util.ControlOutputs, got: {}".format( | |
+ type(control_ios))) | |
+ if control_outputs is not None: | |
+ raise ValueError("control_outputs should be None when using control_ios.") | |
+ control_inputs = True | |
+ control_outputs = control_ios | |
+ elif control_outputs is not None: | |
+ if not isinstance(control_outputs, util.ControlOutputs): | |
+ raise TypeError("Expected a util.ControlOutputs, got: {}".format( | |
+ type(control_outputs))) | |
+ | |
+ if control_outputs is not None: | |
+ control_outputs.update() | |
+ return control_inputs, control_outputs | |
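+# Illustrative usage, not part of the original patch; `graph` is | |
+# hypothetical. Passing a util.ControlOutputs instance as control_ios | |
+# enables both directions: | |
+# | |
+#   cios = util.ControlOutputs(graph) | |
+#   control_inputs, control_outputs = check_cios(control_ios=cios) | |
+#   # control_inputs is now True; control_outputs is cios. | |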
+ | |
+ | |
+def get_ops_ios(ops, control_inputs=False, control_outputs=None, | |
+ control_ios=None): | |
+ """Return all the `tf.Operation` which are connected to an op in ops. | |
+ | |
+ Args: | |
+ ops: an object convertible to a list of `tf.Operation`. | |
+ control_inputs: A boolean indicating whether control inputs are enabled. | |
+ control_outputs: An instance of `util.ControlOutputs` or `None`. If not | |
+ `None`, control outputs are enabled. | |
+ control_ios: An instance of `util.ControlOutputs` or `None`. If not `None`, | |
+ both control inputs and control outputs are enabled. This is equivalent to | |
+ setting `control_inputs` to `True` and `control_outputs` to the | |
+ `util.ControlOutputs` instance. | |
+ Returns: | |
+ All the `tf.Operation` surrounding the given ops. | |
+ Raises: | |
+ TypeError: if `ops` cannot be converted to a list of `tf.Operation`. | |
+ """ | |
+ control_inputs, control_outputs = check_cios(control_inputs, control_outputs, | |
+ control_ios) | |
+ ops = util.make_list_of_op(ops) | |
+ res = [] | |
+ for op in ops: | |
+ util.concatenate_unique(res, [t.op for t in op.inputs]) | |
+ for t in op.outputs: | |
+ util.concatenate_unique(res, t.consumers()) | |
+ if control_outputs is not None: | |
+ util.concatenate_unique(res, control_outputs.get(op)) | |
+ if control_inputs: | |
+ util.concatenate_unique(res, op.control_inputs) | |
+ return res | |
+ | |
+ | |
+def compute_boundary_ts(ops): | |
+ """Compute the tensors at the boundary of a set of ops. | |
+ | |
+ This function looks at all the tensors connected to the given ops (in/out) | |
+ and classify them into three categories: | |
+ 1) input tensors: tensors whose generating operation is not in ops. | |
+ 2) output tensors: tensors whose consumer operations are not in ops | |
+ 3) inside tensors: tensors which are neither input nor output tensors. | |
+ | |
+ Note that a tensor can be both an inside tensor and an output tensor if it is | |
+ consumed by operations both outside and inside of `ops`. | |
+ | |
+ Args: | |
+ ops: an object convertible to a list of tf.Operation. | |
+ Returns: | |
+ A tuple `(outside_input_ts, outside_output_ts, inside_ts)` where: | |
+ `outside_input_ts` is a Python list of input tensors; | |
+ `outside_output_ts` is a python list of output tensors; | |
+ `inside_ts` is a python list of inside tensors. | |
+ Since a tensor can be both an inside tensor and an output tensor, | |
+ `outside_output_ts` and `inside_ts` might intersect. | |
+ Raises: | |
+ TypeError: if ops cannot be converted to a list of tf.Operation. | |
+ """ | |
+ ops = util.make_list_of_op(ops) | |
+ input_ts = _get_input_ts(ops) | |
+ output_ts = _get_output_ts(ops) | |
+ output_ts_set = frozenset(output_ts) | |
+ ops_set = frozenset(ops) | |
+ | |
+ # Compute inside tensors. | |
+ inside_ts = [] | |
+ only_inside_ts = [] | |
+ for t in input_ts: | |
+ # Skip if the input tensor is not also an output tensor. | |
+ if t not in output_ts_set: | |
+ continue | |
+ # Mark as "inside". | |
+ inside_ts.append(t) | |
+ # Mark as "only inside" if the tensor is not both inside and output. | |
+ consumers = frozenset(t.consumers()) | |
+ if consumers - ops_set: | |
+ continue | |
+ only_inside_ts.append(t) | |
+ | |
+ inside_ts_set = frozenset(inside_ts) | |
+ only_inside_ts_set = frozenset(only_inside_ts) | |
+ outside_output_ts = [t for t in output_ts if t not in only_inside_ts_set] | |
+ outside_input_ts = [t for t in input_ts if t not in inside_ts_set] | |
+ return outside_input_ts, outside_output_ts, inside_ts | |
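+# Illustrative sketch, not part of the original patch. For a chain | |
+# a -> b -> c with ops=[b], the boundary is: a's output as the only input | |
+# tensor, b's output as the only output tensor, and no inside tensors | |
+# (inside tensors require a tensor produced and consumed within ops). | |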
+ | |
+ | |
+def get_within_boundary_ops(ops, | |
+ seed_ops, | |
+ boundary_ops=(), | |
+ inclusive=True, | |
+ control_inputs=False, | |
+ control_outputs=None, | |
+ control_ios=None): | |
+ """Return all the `tf.Operation` within the given boundary. | |
+ | |
+ Args: | |
+ ops: an object convertible to a list of `tf.Operation`. Those ops define the | |
+ set in which to perform the operation (if a `tf.Graph` is given, it | |
+ will be converted to the list of all its operations). | |
+ seed_ops: the operations from which to start expanding. | |
+ boundary_ops: the ops forming the boundary. | |
+ inclusive: if `True`, the result will also include the boundary ops. | |
+ control_inputs: A boolean indicating whether control inputs are enabled. | |
+ control_outputs: An instance of `util.ControlOutputs` or `None`. If not | |
+ `None`, control outputs are enabled. | |
+ control_ios: An instance of `util.ControlOutputs` or `None`. If not | |
+ `None`, both control inputs and control outputs are enabled. This is | |
+ equivalent to setting control_inputs to True and control_outputs to | |
+ the `util.ControlOutputs` instance. | |
+ Returns: | |
+ All the `tf.Operation` surrounding the given ops. | |
+ Raises: | |
+ TypeError: if `ops` or `seed_ops` cannot be converted to a list of | |
+ `tf.Operation`. | |
+ ValueError: if the boundary is intersecting with the seeds. | |
+ """ | |
+ control_inputs, control_outputs = check_cios(control_inputs, control_outputs, | |
+ control_ios) | |
+ ops = util.make_list_of_op(ops) | |
+ seed_ops = util.make_list_of_op(seed_ops, allow_graph=False) | |
+ boundary_ops = set(util.make_list_of_op(boundary_ops)) | |
+ res = set(seed_ops) | |
+ if boundary_ops & res: | |
+ raise ValueError("Boundary is intersecting with the seeds.") | |
+ wave = set(seed_ops) | |
+ while wave: | |
+ new_wave = set() | |
+ ops_io = get_ops_ios(wave, control_inputs, control_outputs) | |
+ for op in ops_io: | |
+ if op in res: | |
+ continue | |
+ if op in boundary_ops: | |
+ if inclusive: | |
+ res.add(op) | |
+ else: | |
+ new_wave.add(op) | |
+ res.update(new_wave) | |
+ wave = new_wave | |
+ return [op for op in ops if op in res] | |
+ | |
+ | |
+def get_forward_walk_ops(seed_ops, | |
+ inclusive=True, | |
+ within_ops=None, | |
+ within_ops_fn=None, | |
+ stop_at_ts=(), | |
+ control_outputs=None): | |
+ """Do a forward graph walk and return all the visited ops. | |
+ | |
+ Args: | |
+ seed_ops: an iterable of operations from which the forward graph | |
+ walk starts. If a list of tensors is given instead, the seed_ops are set | |
+ to be the consumers of those tensors. | |
+ inclusive: if True the given seed_ops are also part of the resulting set. | |
+ within_ops: an iterable of `tf.Operation` within which the search is | |
+ restricted. If `within_ops` is `None`, the search is performed within | |
+ the whole graph. | |
+ within_ops_fn: if provided, a function on ops that should return True iff | |
+ the op is within the graph traversal. This can be used along with within_ops, | |
+ in which case an op is within if it is also in within_ops. | |
+ stop_at_ts: an iterable of tensors at which the graph walk stops. | |
+ control_outputs: a `util.ControlOutputs` instance or None. | |
+ If not `None`, it will be used while walking the graph forward. | |
+ Returns: | |
+ A Python list of all the `tf.Operation` ahead of `seed_ops`. | |
+ Raises: | |
+ TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of | |
+ `tf.Operation`. | |
+ """ | |
+ _, control_outputs = check_cios(False, control_outputs) | |
+ if not util.is_iterable(seed_ops): | |
+ seed_ops = [seed_ops] | |
+ if not seed_ops: | |
+ return [] | |
+ if isinstance(seed_ops[0], tf_ops.Tensor): | |
+ ts = util.make_list_of_t(seed_ops, allow_graph=False) | |
+ seed_ops = util.get_consuming_ops(ts) | |
+ else: | |
+ seed_ops = util.make_list_of_op(seed_ops, allow_graph=False) | |
+ | |
+ seed_ops = frozenset(seed_ops) | |
+ stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts)) | |
+ if within_ops: | |
+ within_ops = util.make_list_of_op(within_ops, allow_graph=False) | |
+ within_ops = frozenset(within_ops) | |
+ seed_ops &= within_ops | |
+ | |
+ def is_within(op): | |
+ return (within_ops is None or op in within_ops) and ( | |
+ within_ops_fn is None or within_ops_fn(op)) | |
+ | |
+ result = list(seed_ops) | |
+ wave = set(seed_ops) | |
+ while wave: | |
+ new_wave = set() | |
+ for op in wave: | |
+ for new_t in op.outputs: | |
+ if new_t in stop_at_ts: | |
+ continue | |
+ for new_op in new_t.consumers(): | |
+ if new_op not in result and is_within(new_op): | |
+ new_wave.add(new_op) | |
+ if control_outputs is not None: | |
+ for new_op in control_outputs.get(op): | |
+ if new_op not in result and is_within(new_op): | |
+ new_wave.add(new_op) | |
+ util.concatenate_unique(result, new_wave) | |
+ wave = new_wave | |
+ if not inclusive: | |
+ result = [op for op in result if op not in seed_ops] | |
+ return result | |
+ | |
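+# Illustrative usage sketch (added commentary, not part of the original | |
+# TensorFlow code; `x` is a hypothetical tensor): | |
+# | |
+#   # Ops downstream of `x`; the seeds are the consumers of `x`, and | |
+#   # inclusive=False drops those seed ops from the result: | |
+#   downstream = get_forward_walk_ops([x], inclusive=False) | |
+#   # Same walk, but refusing to enter a hypothetical "save/" scope: | |
+#   downstream = get_forward_walk_ops( | |
+#       [x], within_ops_fn=lambda op: not op.name.startswith("save/")) | |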
+ | |
+@deprecation.deprecated( | |
+ "2019-06-06", | |
+ "Please use tensorflow.python.ops.op_selector.get_backward_walk_ops.", | |
+ warn_once=True) | |
+def get_backward_walk_ops(seed_ops, | |
+ inclusive=True, | |
+ within_ops=None, | |
+ within_ops_fn=None, | |
+ stop_at_ts=(), | |
+ control_inputs=False): | |
+ """Do a backward graph walk and return all the visited ops. | |
+ | |
+ Args: | |
+ seed_ops: an iterable of operations from which the backward graph | |
+ walk starts. If a list of tensors is given instead, the seed_ops are set | |
+ to be the generators of those tensors. | |
+ inclusive: if True the given seed_ops are also part of the resulting set. | |
+ within_ops: an iterable of `tf.Operation` within which the search is | |
+ restricted. If `within_ops` is `None`, the search is performed within | |
+ the whole graph. | |
+ within_ops_fn: if provided, a function on ops that should return True iff | |
+ the op is within the graph traversal. This can be used along with within_ops, | |
+ in which case an op is within if it is also in within_ops. | |
+ stop_at_ts: an iterable of tensors at which the graph walk stops. | |
+ control_inputs: if True, control inputs will be used while moving backward. | |
+ Returns: | |
+ A Python list of all the `tf.Operation` behind `seed_ops`. | |
+ Raises: | |
+ TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of | |
+ `tf.Operation`. | |
+ """ | |
+ return op_selector.get_backward_walk_ops( | |
+ seed_ops, | |
+ inclusive=inclusive, | |
+ within_ops=within_ops, | |
+ within_ops_fn=within_ops_fn, | |
+ stop_at_ts=stop_at_ts, | |
+ control_inputs=control_inputs) | |
+ | |
+ | |
+def get_walks_intersection_ops(forward_seed_ops, | |
+ backward_seed_ops, | |
+ forward_inclusive=True, | |
+ backward_inclusive=True, | |
+ within_ops=None, | |
+ within_ops_fn=None, | |
+ control_inputs=False, | |
+ control_outputs=None, | |
+ control_ios=None): | |
+ """Return the intersection of a forward and a backward walk. | |
+ | |
+ Args: | |
+ forward_seed_ops: an iterable of operations from which the forward graph | |
+ walk starts. If a list of tensors is given instead, the seed_ops are set | |
+ to be the consumers of those tensors. | |
+ backward_seed_ops: an iterable of operations from which the backward graph | |
+ walk starts. If a list of tensors is given instead, the seed_ops are set | |
+ to be the generators of those tensors. | |
+ forward_inclusive: if True the given forward_seed_ops are also part of the | |
+ resulting set. | |
+ backward_inclusive: if True the given backward_seed_ops are also part of the | |
+ resulting set. | |
+ within_ops: an iterable of tf.Operation within which the search is | |
+ restricted. If within_ops is None, the search is performed within | |
+ the whole graph. | |
+ within_ops_fn: if provided, a function on ops that should return True iff | |
+ the op is within the graph traversal. This can be used along with within_ops, | |
+ in which case an op is within if it is also in within_ops. | |
+ control_inputs: A boolean indicating whether control inputs are enabled. | |
+ control_outputs: An instance of util.ControlOutputs or None. If not None, | |
+ control outputs are enabled. | |
+ control_ios: An instance of util.ControlOutputs or None. If not None, both | |
+ control inputs and control outputs are enabled. This is equivalent to | |
+ setting control_inputs to True and control_outputs to the util.ControlOutputs | |
+ instance. | |
+ Returns: | |
+ A Python list of all the tf.Operation in the intersection of a forward and a | |
+ backward walk. | |
+ Raises: | |
+ TypeError: if `forward_seed_ops` or `backward_seed_ops` or `within_ops` | |
+ cannot be converted to a list of `tf.Operation`. | |
+ """ | |
+ control_inputs, control_outputs = check_cios(control_inputs, control_outputs, | |
+ control_ios) | |
+ forward_ops = get_forward_walk_ops( | |
+ forward_seed_ops, | |
+ inclusive=forward_inclusive, | |
+ within_ops=within_ops, | |
+ within_ops_fn=within_ops_fn, | |
+ control_outputs=control_outputs) | |
+ backward_ops = get_backward_walk_ops( | |
+ backward_seed_ops, | |
+ inclusive=backward_inclusive, | |
+ within_ops=within_ops, | |
+ within_ops_fn=within_ops_fn, | |
+ control_inputs=control_inputs) | |
+ return [op for op in forward_ops if op in backward_ops] | |
+ | |
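+# Illustrative usage sketch (added commentary, not part of the original | |
+# TensorFlow code; `inputs` and `loss` are hypothetical tensors): all the | |
+# ops lying on paths between the two, i.e. reachable forward from `inputs` | |
+# and backward from `loss`: | |
+# | |
+#   between = get_walks_intersection_ops( | |
+#       forward_seed_ops=[inputs], backward_seed_ops=[loss]) | |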
+ | |
+def get_walks_union_ops(forward_seed_ops, | |
+ backward_seed_ops, | |
+ forward_inclusive=True, | |
+ backward_inclusive=True, | |
+ within_ops=None, | |
+ within_ops_fn=None, | |
+ control_inputs=False, | |
+ control_outputs=None, | |
+ control_ios=None): | |
+ """Return the union of a forward and a backward walk. | |
+ | |
+ Args: | |
+ forward_seed_ops: an iterable of operations from which the forward graph | |
+ walk starts. If a list of tensors is given instead, the seed_ops are set | |
+ to be the consumers of those tensors. | |
+ backward_seed_ops: an iterable of operations from which the backward graph | |
+ walk starts. If a list of tensors is given instead, the seed_ops are set | |
+ to be the generators of those tensors. | |
+ forward_inclusive: if True the given forward_seed_ops are also part of the | |
+ resulting set. | |
+ backward_inclusive: if True the given backward_seed_ops are also part of the | |
+ resulting set. | |
+ within_ops: restrict the search within those operations. If within_ops is | |
+ None, the search is done within the whole graph. | |
+ within_ops_fn: if provided, a function on ops that should return True iff | |
+ the op is within the graph traversal. This can be used along with within_ops, | |
+ in which case an op is within if it is also in within_ops. | |
+ control_inputs: A boolean indicating whether control inputs are enabled. | |
+ control_outputs: An instance of util.ControlOutputs or None. If not None, | |
+ control outputs are enabled. | |
+ control_ios: An instance of util.ControlOutputs or None. If not None, both | |
+ control inputs and control outputs are enabled. This is equivalent to | |
+ setting control_inputs to True and control_outputs to the util.ControlOutputs | |
+ instance. | |
+ Returns: | |
+ A Python list of all the tf.Operation in the union of a forward and a | |
+ backward walk. | |
+ Raises: | |
+ TypeError: if forward_seed_ops or backward_seed_ops or within_ops cannot be | |
+ converted to a list of tf.Operation. | |
+ """ | |
+ control_inputs, control_outputs = check_cios(control_inputs, control_outputs, | |
+ control_ios) | |
+ forward_ops = get_forward_walk_ops( | |
+ forward_seed_ops, | |
+ inclusive=forward_inclusive, | |
+ within_ops=within_ops, | |
+ within_ops_fn=within_ops_fn, | |
+ control_outputs=control_outputs) | |
+ backward_ops = get_backward_walk_ops( | |
+ backward_seed_ops, | |
+ inclusive=backward_inclusive, | |
+ within_ops=within_ops, | |
+ within_ops_fn=within_ops_fn, | |
+ control_inputs=control_inputs) | |
+ return util.concatenate_unique(forward_ops, backward_ops) | |
+ | |
+ | |
+def select_ops(*args, **kwargs): | |
+ """Helper to select operations. | |
+ | |
+ Args: | |
+ *args: list of 1) regular expressions (compiled or not) or 2) (array of) | |
+ `tf.Operation`. `tf.Tensor` instances are silently ignored. | |
+ **kwargs: 'graph': `tf.Graph` in which to perform the regex query. This is | |
+ required when using regex. | |
+ 'positive_filter': an elem is selected only if `positive_filter(elem)` is | |
+ `True`. This is optional. | |
+ 'restrict_ops_regex': a regular expression is ignored if it doesn't start | |
+ with the substring "(?#ops)". | |
+ Returns: | |
+ A list of `tf.Operation`. | |
+ Raises: | |
+ TypeError: if the optional keyword argument graph is not a `tf.Graph` | |
+ or if an argument in args is not an (array of) `tf.Operation` | |
+ or an (array of) `tf.Tensor` (silently ignored) or a string | |
+ or a regular expression. | |
+ ValueError: if one of the keyword arguments is unexpected or if a regular | |
+ expression is used without passing a graph as a keyword argument. | |
+ """ | |
+ # get keyword arguments | |
+ graph = None | |
+ positive_filter = None | |
+ restrict_ops_regex = False | |
+ for k, v in iteritems(kwargs): | |
+ if k == "graph": | |
+ graph = v | |
+ if graph is not None and not isinstance(graph, tf_ops.Graph): | |
+ raise TypeError("Expected a tf.Graph, got: {}".format(type(graph))) | |
+ elif k == "positive_filter": | |
+ positive_filter = v | |
+ elif k == "restrict_ops_regex": | |
+ restrict_ops_regex = v | |
+ elif k == "restrict_ts_regex": | |
+ pass | |
+ else: | |
+ raise ValueError("Wrong keywords argument: {}.".format(k)) | |
+ | |
+ ops = [] | |
+ | |
+ for arg in args: | |
+ if can_be_regex(arg): | |
+ if graph is None: | |
+ raise ValueError("Use the keyword argument 'graph' to use regex.") | |
+ regex = make_regex(arg) | |
+ if regex.pattern.startswith("(?#ts)"): | |
+ continue | |
+ if restrict_ops_regex and not regex.pattern.startswith("(?#ops)"): | |
+ continue | |
+ ops_ = filter_ops_from_regex(graph, regex) | |
+ for op_ in ops_: | |
+ if op_ not in ops: | |
+ if positive_filter is None or positive_filter(op_): | |
+ ops.append(op_) | |
+ else: | |
+ ops_aux = util.make_list_of_op(arg, ignore_ts=True) | |
+ if positive_filter is not None: | |
+ ops_aux = [op for op in ops_aux if positive_filter(op)] | |
+ ops_aux = [op for op in ops_aux if op not in ops] | |
+ ops += ops_aux | |
+ | |
+ return ops | |
+ | |
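+# Illustrative usage sketch (added commentary, not part of the original | |
+# TensorFlow code; the graph `g`, the ops `op0`/`op1` and the scope name | |
+# "encoder" are hypothetical): | |
+# | |
+#   matmuls = select_ops("^encoder/.*MatMul$", graph=g) | |
+#   # Regexes and explicit ops can be mixed in a single call: | |
+#   chosen = select_ops("(?#ops)^encoder/.*", [op0, op1], graph=g) | |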
+ | |
+def select_ts(*args, **kwargs): | |
+ """Helper to select tensors. | |
+ | |
+ Args: | |
+ *args: list of 1) regular expressions (compiled or not) or 2) (array of) | |
+ `tf.Tensor`. `tf.Operation` instances are silently ignored. | |
+ **kwargs: 'graph': `tf.Graph` in which to perform the regex query. This is | |
+ required when using regex. | |
+ 'positive_filter': an elem is selected only if `positive_filter(elem)` is | |
+ `True`. This is optional. | |
+ 'restrict_ts_regex': a regular expression is ignored if it doesn't start | |
+ with the substring "(?#ts)". | |
+ Returns: | |
+ A list of `tf.Tensor`. | |
+ Raises: | |
+ TypeError: if the optional keyword argument graph is not a `tf.Graph` | |
+ or if an argument in args is not an (array of) `tf.Tensor` | |
+ or an (array of) `tf.Operation` (silently ignored) or a string | |
+ or a regular expression. | |
+ ValueError: if one of the keyword arguments is unexpected or if a regular | |
+ expression is used without passing a graph as a keyword argument. | |
+ """ | |
+ # get keyword arguments | |
+ graph = None | |
+ positive_filter = None | |
+ restrict_ts_regex = False | |
+ for k, v in iteritems(kwargs): | |
+ if k == "graph": | |
+ graph = v | |
+ if graph is not None and not isinstance(graph, tf_ops.Graph): | |
+ raise TypeError("Expected a tf.Graph, got {}".format(type(graph))) | |
+ elif k == "positive_filter": | |
+ positive_filter = v | |
+ elif k == "restrict_ts_regex": | |
+ restrict_ts_regex = v | |
+ elif k == "restrict_ops_regex": | |
+ pass | |
+ else: | |
+ raise ValueError("Wrong keywords argument: {}.".format(k)) | |
+ | |
+ ts = [] | |
+ | |
+ for arg in args: | |
+ if can_be_regex(arg): | |
+ if graph is None: | |
+ raise ValueError("Use the keyword argument 'graph' to use regex.") | |
+ regex = make_regex(arg) | |
+ if regex.pattern.startswith("(?#ops)"): | |
+ continue | |
+ if restrict_ts_regex and not regex.pattern.startswith("(?#ts)"): | |
+ continue | |
+ ts_ = filter_ts_from_regex(graph, regex) | |
+ for t_ in ts_: | |
+ if t_ not in ts: | |
+ if positive_filter is None or positive_filter(t_): | |
+ ts.append(t_) | |
+ else: | |
+ ts_aux = util.make_list_of_t(arg, ignore_ops=True) | |
+ if positive_filter is not None: | |
+ ts_aux = [t for t in ts_aux if positive_filter(t)] | |
+ ts_aux = [t for t in ts_aux if t not in ts] | |
+ ts += ts_aux | |
+ | |
+ return ts | |
+ | |
+ | |
+def select_ops_and_ts(*args, **kwargs): | |
+ """Helper to select operations and tensors. | |
+ | |
+ Args: | |
+ *args: list of 1) regular expressions (compiled or not) or 2) (array of) | |
+ `tf.Operation` 3) (array of) tf.Tensor. Regular expressions matching | |
+ tensors must start with the comment `"(?#ts)"`, for instance: | |
+ `"(?#ts)^foo/.*"`. | |
+ **kwargs: 'graph': `tf.Graph` in which to perform the regex query. This is | |
+ required when using regex. | |
+ 'positive_filter': an elem is selected only if `positive_filter(elem)` is | |
+ `True`. This is optional. | |
+ Returns: | |
+ A tuple `(ops, ts)` where: | |
+ `ops` is a list of `tf.Operation`, and | |
+ `ts` is a list of `tf.Tensor` | |
+ Raises: | |
+ TypeError: if the optional keyword argument graph is not a `tf.Graph` | |
+ or if an argument in args is not an (array of) `tf.Tensor` | |
+ or an (array of) `tf.Operation` or a string or a regular expression. | |
+ ValueError: if one of the keyword arguments is unexpected or if a regular | |
+ expression is used without passing a graph as a keyword argument. | |
+ """ | |
+ ops = select_ops(*args, restrict_ops_regex=False, **kwargs) | |
+ ts = select_ts(*args, restrict_ts_regex=True, **kwargs) | |
+ return ops, ts | |
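+ | |
+ | |
+# Illustrative usage sketch (added commentary, not part of the original | |
+# TensorFlow code; `g` is a hypothetical graph): regexes matching tensors | |
+# must carry the "(?#ts)" prefix, as documented above. | |
+# | |
+#   ops, ts = select_ops_and_ts("^layer1/.*", "(?#ts)^layer1/.*", graph=g) | |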
diff --git a/tflex_graph_editor/subgraph.py b/tflex_graph_editor/subgraph.py | |
new file mode 100644 | |
index 0000000..4bac8c5 | |
--- /dev/null | |
+++ b/tflex_graph_editor/subgraph.py | |
@@ -0,0 +1,668 @@ | |
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |
+# | |
+# Licensed under the Apache License, Version 2.0 (the "License"); | |
+# you may not use this file except in compliance with the License. | |
+# You may obtain a copy of the License at | |
+# | |
+# http://www.apache.org/licenses/LICENSE-2.0 | |
+# | |
+# Unless required by applicable law or agreed to in writing, software | |
+# distributed under the License is distributed on an "AS IS" BASIS, | |
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+# See the License for the specific language governing permissions and | |
+# limitations under the License. | |
+# ============================================================================== | |
+"""SubGraphView: a subgraph view on an existing tf.Graph. | |
+""" | |
+ | |
+from __future__ import absolute_import | |
+from __future__ import division | |
+from __future__ import print_function | |
+ | |
+import copy | |
+ | |
+import six | |
+from six import iteritems | |
+from six import StringIO | |
+ | |
+from tflex_graph_editor import select | |
+from tflex_graph_editor import util | |
+from tensorflow.python.framework import ops as tf_ops | |
+ | |
+__all__ = [ | |
+ "SubGraphView", | |
+ "make_view", | |
+ "make_view_from_scope", | |
+] | |
+ | |
+ | |
+def _finalize_index(index_or_t, ts): | |
+ """Returns index as is or return index of tensor in `ts`.""" | |
+ if isinstance(index_or_t, six.integer_types): | |
+ return index_or_t | |
+ else: | |
+ return ts.index(index_or_t) | |
+ | |
+ | |
+def _finalize_indices(list_of_index_or_t, ts): | |
+ """Returns index in `indices` as is or replace with tensor's index.""" | |
+ return [_finalize_index(index_or_t, ts) for index_or_t in list_of_index_or_t] | |
+ | |
+ | |
+def _check_within_range(mapping, n, repetition): | |
+ """Check is the mapping is valid. | |
+ | |
+ Args: | |
+ mapping: an iterable of integer. | |
+ n: defines the input domain as [0, n-1]. Note that the mapping can be | |
+ under-complete, that is, it can contain only a subset of the integers in | |
+ [0, n-1]. | |
+ repetition: if True, repetitions are allowed; otherwise they are not | |
+ (the mapping must be injective). | |
+ Raises: | |
+ ValueError: if the mapping is out of range or if repetition is False and | |
+ the mapping has some repetition. | |
+ """ | |
+ for i in mapping: | |
+ if not 0 <= i < n: | |
+ raise ValueError("Out of [0, {}[ range: {}".format(n, i)) | |
+ if not repetition and len(set(mapping)) != len(mapping): | |
+ raise ValueError("Found repetition in mapping: {}".format(mapping)) | |
+ | |
+ | |
+class SubGraphView(object): | |
+ """A subgraph view on an existing `tf.Graph`. | |
+ | |
+ An instance of this class is a subgraph view on an existing `tf.Graph`. | |
+ "subgraph" means that it can represent part of the whole `tf.Graph`. | |
+ "view" means that it only provides a passive observation and do not to act | |
+ on the `tf.Graph`. Note that in this documentation, the term "subgraph" is | |
+ often used as substitute to "subgraph view". | |
+ | |
+ A subgraph contains: | |
+ | |
+ * a list of input tensors, accessible via the `inputs` property. | |
+ * a list of output tensors, accessible via the `outputs` property. | |
+ * and the operations in between, accessible via the `ops` property. | |
+ | |
+ A subgraph can be seen as a function F(i0, i1, ...) -> o0, o1, ... It is a | |
+ function which takes as input some input tensors and returns as output some | |
+ output tensors. The computation that the function performs is encoded in the | |
+ operations of the subgraph. | |
+ | |
+ The tensors (input or output) can be of two kinds: | |
+ | |
+ - connected: a connected tensor connects to at least one operation contained | |
+ in the subgraph. One example is a subgraph representing a single operation | |
+ and its inputs and outputs: all the input and output tensors of the op | |
+ are "connected". | |
+ - passthrough: a passthrough tensor does not connect to any operation | |
+ contained in the subgraph. One example is a subgraph representing a | |
+ single tensor: this tensor is passthrough. By default a passthrough tensor | |
+ is present both in the input and output tensors of the subgraph. It can | |
+ however be remapped to appear as an input (or an output) only. | |
+ | |
+ The input and output tensors can be remapped. For instance, some input | |
+ tensor can be omitted: a subgraph representing an operation with two | |
+ inputs can be remapped to only take one input. Note that this does not | |
+ change the underlying `tf.Graph` at all (remember, it is a view). It means | |
+ that the other input is being ignored, or is being treated as "given". | |
+ The analogy with functions can be extended like this: F(x,y) is the original | |
+ function. Remapping the inputs from [x, y] to just [x] means that the | |
+ subgraph now represents the function F_y(x) (y is "given"). | |
+ | |
+ The output tensors can also be remapped. For instance, some output tensor | |
+ can be omitted, and other output tensors can be duplicated. As mentioned | |
+ before, this does not change the underlying `tf.Graph` at all. | |
+ The analogy with functions can be extended like this: F(...)->x,y is the | |
+ original function. Remapping the outputs from [x, y] to just [y,y] means | |
+ that the subgraph now represents the function M(F(...)) where M is the | |
+ function M(a,b)->b,b. | |
+ | |
+ It is useful to describe five other kinds of tensors: | |
+ | |
+ * internal: an internal tensor is a tensor connecting operations contained | |
+ in the subgraph. One example is the subgraph representing the two | |
+ operations A and B connected sequentially: -> A -> B ->. The middle arrow | |
+ is an internal tensor. | |
+ * actual input: an input tensor of the subgraph, regardless of whether it is | |
+ listed in "inputs" or not (masked-out). | |
+ * actual output: an output tensor of the subgraph, regardless of whether it is | |
+ listed in "outputs" or not (masked-out). | |
+ * hidden input: an actual input which has been masked-out using an | |
+ input remapping. In other words, a hidden input is a non-internal tensor | |
+ not listed as an input tensor and one of whose consumers belongs to | |
+ the subgraph. | |
+ * hidden output: an actual output which has been masked-out using an output | |
+ remapping. In other words, a hidden output is a non-internal tensor | |
+ not listed as an output and one of whose generating operations belongs to | |
+ the subgraph. | |
+ | |
+ Here are some useful guarantees about an instance of a SubGraphView: | |
+ | |
+ * the input (or output) tensors are not internal. | |
+ * the input (or output) tensors are either "connected" or "passthrough". | |
+ * the passthrough tensors are not connected to any of the operation of | |
+ the subgraph. | |
+ | |
+ Note that there is no guarantee that an operation in a subgraph contributes | |
+ at all to its inputs or outputs. For instance, remapping both the inputs and | |
+ outputs to empty lists will produce a subgraph which still contains all the | |
+ original operations. However, the remove_unused_ops function can be used to | |
+ make a new subgraph view whose operations are connected to at least one of | |
+ the input or output tensors. | |
+ | |
+ An instance of this class is meant to be a lightweight object which is not | |
+ modified in-place by the user. Rather, the user can create new modified | |
+ instances of a given subgraph. In that sense, the class SubGraphView is meant | |
+ to be used like an immutable Python object. | |
+ | |
+ A common problem when using views is that they can get out-of-sync with the | |
+ data they observe (in this case, a `tf.Graph`). It is up to the user to | |
+ ensure that this doesn't happen. To stay on the safe side, it is recommended | |
+ that the lifetime of subgraph views be kept very short. One way to achieve | |
+ this is to use subgraphs within a "with make_view(...) as sgv:" Python context. | |
+ | |
+ To alleviate the out-of-sync problem, some functions are granted the right | |
+ to modify subgraphs in place. This is typically the case of graph | |
+ manipulation functions which, given some subgraphs as arguments, can modify | |
+ the underlying `tf.Graph`. Since this modification is likely to render the | |
+ subgraph view invalid, those functions can modify the argument in place to | |
+ reflect the change. For instance, calling the function | |
+ swap_inputs(sgv0, sgv1) will modify sgv0 and sgv1 in place to reflect the | |
+ fact that their inputs have now been swapped. | |
+ """ | |
+ | |
+ def __init__(self, inside_ops=(), passthrough_ts=()): | |
+ """Create a subgraph containing the given ops and the "passthrough" tensors. | |
+ | |
+ Args: | |
+ inside_ops: an object convertible to a list of `tf.Operation`. This list | |
+ defines all the operations in the subgraph. | |
+ passthrough_ts: an object convertible to a list of `tf.Tensor`. This list | |
+ define all the "passthrough" tensors. A passthrough tensor is a tensor | |
+ which goes directly from the input of the subgraph to it output, without | |
+ any intermediate operations. All the non passthrough tensors are | |
+ silently ignored. | |
+ Raises: | |
+ TypeError: if inside_ops cannot be converted to a list of `tf.Operation` | |
+ or if `passthrough_ts` cannot be converted to a list of `tf.Tensor`. | |
+ """ | |
+ | |
+ inside_ops = util.make_list_of_op(inside_ops) | |
+ passthrough_ts = util.make_list_of_t(passthrough_ts) | |
+ ops_and_ts = inside_ops + passthrough_ts | |
+ if ops_and_ts: | |
+ self._graph = util.get_unique_graph(ops_and_ts) | |
+ self._ops = inside_ops | |
+ | |
+ # Compute inside and outside tensor | |
+ inputs, outputs, insides = select.compute_boundary_ts(inside_ops) | |
+ | |
+ # Compute passthrough tensors, silently ignoring the non-passthrough ones. | |
+ all_tensors = frozenset(inputs + outputs + list(insides)) | |
+ self._passthrough_ts = [t for t in passthrough_ts if t not in all_tensors] | |
+ | |
+ # Set inputs and outputs. | |
+ self._input_ts = inputs + self._passthrough_ts | |
+ self._output_ts = outputs + self._passthrough_ts | |
+ else: | |
+ self._graph = None | |
+ self._passthrough_ts = [] | |
+ self._input_ts = [] | |
+ self._output_ts = [] | |
+ self._ops = [] | |
+ | |
+ def __copy__(self): | |
+ """Create a copy of this subgraph. | |
+ | |
+ Note that this class is a "view", copying it only create another view and | |
+ does not copy the underlying part of the `tf.Graph`. | |
+ | |
+ Returns: | |
+ A new identical instance of the original subgraph view. | |
+ """ | |
+ cls = self.__class__ | |
+ result = cls.__new__(cls) | |
+ for k, v in iteritems(self.__dict__): | |
+ if k == "_graph": | |
+ setattr(result, k, v) | |
+ else: | |
+ setattr(result, k, list(v)) # copy the list | |
+ return result | |
+ | |
+ def _assign_from(self, other): | |
+ """Assign other to itself. | |
+ | |
+ Args: | |
+ other: another subgraph-view. | |
+ Returns: | |
+ A new instance identical to the original one. | |
+ Raises: | |
+ TypeError: if other is not an SubGraphView. | |
+ """ | |
+ if not isinstance(other, SubGraphView): | |
+ raise TypeError("Expected SubGraphView, got: {}".format(type(other))) | |
+ # pylint: disable=protected-access | |
+ self._graph = other._graph | |
+ self._ops = list(other._ops) | |
+ self._passthrough_ts = list(other._passthrough_ts) | |
+ self._input_ts = list(other._input_ts) | |
+ self._output_ts = list(other._output_ts) | |
+ # pylint: enable=protected-access | |
+ | |
+ def copy(self): | |
+ """Return a copy of itself. | |
+ | |
+ Note that this class is a "view", copying it only create another view and | |
+ does not copy the underlying part of the tf.Graph. | |
+ | |
+ Returns: | |
+ A new instance identical to the original one. | |
+ """ | |
+ return copy.copy(self) | |
+ | |
+ def _remap_default(self, remove_input_map=True, remove_output_map=True): | |
+ """Remap in the place the inputs and/or outputs to the default mapping. | |
+ | |
+ Args: | |
+ remove_input_map: if True the input map is reset to the default one. | |
+ remove_output_map: if True the output map is reset to the default one. | |
+ """ | |
+ if not remove_input_map and not remove_output_map: | |
+ return | |
+ | |
+ # Compute inside and outside tensor | |
+ inputs, outputs, _ = select.compute_boundary_ts(self._ops) | |
+ if remove_input_map: | |
+ self._input_ts = list(inputs) + self._passthrough_ts | |
+ if remove_output_map: | |
+ self._output_ts = list(outputs) + self._passthrough_ts | |
+ | |
+ def remap_default(self, remove_input_map=True, remove_output_map=True): | |
+ """Remap the inputs and/or outputs to the default mapping. | |
+ | |
+ Args: | |
+ remove_input_map: if True the input map is reset to the default one. | |
+ remove_output_map: if True the output map is reset to the default one. | |
+ Returns: | |
+ A new modified instance of the original subgraph view with its | |
+ input and/or output mapping reset to the default one. | |
+ """ | |
+ res = self.copy() | |
+ res._remap_default(remove_input_map, remove_output_map) # pylint: disable=protected-access | |
+ return res | |
+ | |
+ def _remap_inputs(self, new_input_indices): | |
+ """Remap the inputs of the subgraph in-place.""" | |
+ new_input_indices = _finalize_indices(new_input_indices, self._input_ts) | |
+ _check_within_range( | |
+ new_input_indices, len(self._input_ts), repetition=False) | |
+ self._input_ts = [self._input_ts[i] for i in new_input_indices] | |
+ | |
+ def _remap_outputs(self, new_output_indices): | |
+ """Remap the outputs of the subgraph in-place.""" | |
+ new_output_indices = _finalize_indices(new_output_indices, self._output_ts) | |
+ _check_within_range( | |
+ new_output_indices, len(self._output_ts), repetition=True) | |
+ self._output_ts = [self._output_ts[i] for i in new_output_indices] | |
+ | |
+ def _remap_outputs_make_unique(self): | |
+ """Remap the outputs in place so that all the tensors appears only once.""" | |
+ output_ts = list(self._output_ts) | |
+ self._output_ts = [] | |
+ util.concatenate_unique(self._output_ts, output_ts) | |
+ | |
+ def _remap_outputs_to_consumers(self): | |
+ """Remap the outputs in place to match the number of consumers.""" | |
+ self._remap_outputs_make_unique() | |
+ output_ts = list(self._output_ts) | |
+ self._output_ts = [] | |
+ for t in output_ts: | |
+ self._output_ts += [t] * len(t.consumers()) | |
+ | |
+ def remap_outputs_make_unique(self): | |
+ """Remap the outputs so that all the tensors appears only once.""" | |
+ res = copy.copy(self) | |
+ res._remap_outputs_make_unique() # pylint: disable=protected-access | |
+ return res | |
+ | |
+ def remap_outputs_to_consumers(self): | |
+ """Remap the outputs to match the number of consumers.""" | |
+ res = copy.copy(self) | |
+ res._remap_outputs_to_consumers() # pylint: disable=protected-access | |
+ return res | |
+ | |
+ def _remove_unused_ops(self, control_inputs=True): | |
+ """Remove unused ops in place. | |
+ | |
+ Args: | |
+ control_inputs: if True, control inputs are used to detect used ops. | |
+ Returns: | |
+ A new subgraph view which only contains used operations. | |
+ """ | |
+ ops = select.get_walks_union_ops( | |
+ self.connected_inputs, | |
+ self.connected_outputs, | |
+ within_ops=self._ops, | |
+ control_inputs=control_inputs) | |
+ self._ops = [op for op in self._ops if op in ops] | |
+ | |
+ def remove_unused_ops(self, control_inputs=True): | |
+ """Remove unused ops. | |
+ | |
+ Args: | |
+ control_inputs: if True, control inputs are used to detect used ops. | |
+ Returns: | |
+ A new subgraph view which only contains used operations. | |
+ """ | |
+ res = copy.copy(self) | |
+ res._remove_unused_ops(control_inputs) # pylint: disable=protected-access | |
+ return res | |
+ | |
+ def remap_inputs(self, new_input_indices): | |
+ """Remap the inputs of the subgraph. | |
+ | |
+ If the inputs of the original subgraph are [t0, t1, t2], remapping to [2,0] | |
+ will create a new instance whose inputs are [t2, t0]. | |
+ | |
+ Note that this is only modifying the view: the underlying `tf.Graph` is not | |
+ affected. | |
+ | |
+ Args: | |
+ new_input_indices: an iterable of integers or tf.Tensors | |
+ representing a mapping between the old inputs and the new ones. | |
+ Integers must be non-negative and smaller than the number of old inputs. | |
+ tf.Tensors must belong to the old list of inputs. | |
+ This mapping can be under-complete and must be without repetitions. | |
+ Returns: | |
+ A new modified instance of the original subgraph view with remapped | |
+ inputs. | |
+ """ | |
+ res = self.copy() | |
+ res._remap_inputs(new_input_indices) # pylint: disable=protected-access | |
+ return res | |
+ | |
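+ # Illustrative usage sketch (added commentary, not part of the original | |
+ # TensorFlow code): if `sgv.inputs` is [t0, t1, t2], then | |
+ # | |
+ #   sgv2 = sgv.remap_inputs([2, 0]) | |
+ # | |
+ # yields a new view with `sgv2.inputs` equal to [t2, t0]; the underlying | |
+ # graph is untouched. | |
+ | |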
+ def remap_outputs(self, new_output_indices): | |
+ """Remap the output of the subgraph. | |
+ | |
+ If the outputs of the original subgraph are [t0, t1, t2], remapping to | |
+ [1,1,0] will create a new instance whose outputs are [t1, t1, t0]. | |
+ | |
+ Note that this is only modifying the view: the underlying tf.Graph is not | |
+ affected. | |
+ | |
+ Args: | |
+ new_output_indices: an iterable of integers or tf.Tensors | |
+ representing a mapping between the old outputs and the new ones. | |
+ Integers must be non-negative and smaller than the number of old outputs. | |
+ tf.Tensors must belong to the old list of outputs. | |
+ This mapping can be under-complete and can have repetitions. | |
+ Returns: | |
+ A new modified instance of the original subgraph view with remapped | |
+ outputs. | |
+ """ | |
+ res = copy.copy(self) | |
+ res._remap_outputs(new_output_indices) # pylint: disable=protected-access | |
+ return res | |
+ | |
+ def remap(self, new_input_indices=None, new_output_indices=None): | |
+ """Remap the inputs and outputs of the subgraph. | |
+ | |
+ Note that this is only modifying the view: the underlying tf.Graph is not | |
+ affected. | |
+ | |
+ Args: | |
+ new_input_indices: an iterable of integers or tf.Tensors | |
+ representing a mapping between the old inputs and the new ones. | |
+ Integers must be non-negative and smaller than the number of old inputs. | |
+ tf.Tensors must belong to the old list of inputs. | |
+ This mapping can be under-complete and must be without repetitions. | |
+ new_output_indices: an iterable of integers or tf.Tensors | |
+ representing a mapping between the old outputs and the new ones. | |
+ Integers must be non-negative and smaller than the number of old outputs. | |
+ tf.Tensors must belong to the old list of outputs. | |
+ This mapping can be under-complete and can have repetitions. | |
+ Returns: | |
+ A new modified instance of the original subgraph view with remapped | |
+ inputs and outputs. | |
+ """ | |
+ res = copy.copy(self) | |
+ if new_input_indices is not None: | |
+ res._remap_inputs(new_input_indices) # pylint: disable=protected-access | |
+ if new_output_indices is not None: | |
+ res._remap_outputs(new_output_indices) # pylint: disable=protected-access | |
+ return res | |
+ | |
+ def find_op_by_name(self, op_name): | |
+ """Return the op named op_name. | |
+ | |
+ Args: | |
+ op_name: the name to search for | |
+ Returns: | |
+ The op named op_name. | |
+ Raises: | |
+ ValueError: if the op_name could not be found. | |
+ AssertionError: if the name was found multiple times. | |
+ """ | |
+ res = [op for op in self._ops if op.name == op_name] | |
+ if not res: | |
+ raise ValueError("{} not in subgraph.".format(op_name)) | |
+ if len(res) > 1: | |
+ raise AssertionError("More than 1 op named: {}!".format(op_name)) | |
+ return res[0] | |
+ | |
+ def __str__(self): | |
+ if not self: | |
+ return "SubGraphView: empty" | |
+ | |
+ def op_name(op): | |
+ return op.name | |
+ | |
+ def tensor_name(t): | |
+ if t in self._passthrough_ts: | |
+ return "{} *".format(t.name) | |
+ else: | |
+ return t.name | |
+ | |
+ def print_list(name, iterable, get_name): | |
+ if iterable: | |
+ print("** {}[{}]:".format(name, len(iterable)), file=res) | |
+ print("\n".join([" {}".format(get_name(elem)) for elem in iterable]), | |
+ file=res) | |
+ else: | |
+ print("** {}: empty".format(name), file=res) | |
+ | |
+ res = StringIO() | |
+ print("SubGraphView (graphid={}):".format(id(self.graph)), file=res) | |
+ print_list("ops", self._ops, op_name) | |
+ print_list("inputs", self._input_ts, tensor_name) | |
+ print_list("outputs", self._output_ts, tensor_name) | |
+ return res.getvalue() | |
+ | |
+ @property | |
+ def graph(self): | |
+ """The underlying `tf.Graph`.""" | |
+ return self._graph | |
+ | |
+ @property | |
+ def ops(self): | |
+ """The operations in this subgraph view.""" | |
+ return self._ops | |
+ | |
+ @property | |
+ def inputs(self): | |
+ """The input tensors of this subgraph view.""" | |
+ return util.ListView(self._input_ts) | |
+ | |
+ @property | |
+ def connected_inputs(self): | |
+ """The connected input tensors of this subgraph view.""" | |
+ return [t for t in self._input_ts if t not in self._passthrough_ts] | |
+ | |
+ @property | |
+ def outputs(self): | |
+ """The output tensors of this subgraph view.""" | |
+ return util.ListView(self._output_ts) | |
+ | |
+ @property | |
+ def connected_outputs(self): | |
+ """The connected output tensors of this subgraph view.""" | |
+ return [t for t in self._output_ts if t not in self._passthrough_ts] | |
+ | |
+ @property | |
+ def passthroughs(self): | |
+ """The passthrough tensors, going straight from input to output.""" | |
+ return util.ListView(self._passthrough_ts) | |
+ | |
+ def __bool__(self): | |
+ """Allows for implicit boolean conversion.""" | |
+ return self._graph is not None | |
+ | |
+ # Python 3 wants __bool__, Python 2.7 wants __nonzero__ | |
+ __nonzero__ = __bool__ | |
+ | |
+ def op(self, op_id): | |
+ """Get an op by its index.""" | |
+ return self._ops[op_id] | |
+ | |
+ def is_passthrough(self, t): | |
+ """Check whether a tensor is passthrough.""" | |
+ return t in self._passthrough_ts | |
+ | |
+ def __enter__(self): | |
+ """Allow Python context to minimize the life time of a subgraph view. | |
+ | |
+ A subgraph view is meant to be a lightweight and transient object. A short | |
+ lifetime will alleviate the "out-of-sync" issue mentioned earlier. For that | |
+ reason, a SubGraphView instance can be used within a Python context. For | |
+ example: | |
+ | |
+ from tflex_graph_editor import subgraph | |
+ with subgraph.make_view(...) as sgv: | |
+ print(sgv) | |
+ | |
+ Returns: | |
+ Itself. | |
+ """ | |
+ return self | |
+ | |
+ def __exit__(self, exc_type, exc_value, traceback): | |
+ pass | |
+ | |
+ def input_index(self, t): | |
+ """Find the input index corresponding to the given input tensor t. | |
+ | |
+ Args: | |
+ t: the input tensor of this subgraph view. | |
+ Returns: | |
+ The index in the self.inputs list. | |
+ Raises: | |
+ ValueError: if `t` is not an input tensor. | |
+ """ | |
+ try: | |
+ subgraph_id = self._input_ts.index(t) | |
+ except ValueError:  # list.index raises ValueError when t is absent | |
+ raise ValueError("Can't find {} in inputs of subgraph.".format(t.name)) | |
+ return subgraph_id | |
+ | |
+ def output_index(self, t): | |
+ """Find the output index corresponding to given output tensor t. | |
+ | |
+ Args: | |
+ t: the output tensor of this subgraph view. | |
+ Returns: | |
+ The index in the self.outputs list. | |
+ Raises: | |
+ ValueError: if `t` is not an output tensor. | |
+ """ | |
+ try: | |
+ subgraph_id = self._output_ts.index(t) | |
+ except ValueError:  # list.index raises ValueError when t is absent | |
+ raise ValueError("Can't find {} in outputs of subgraph.".format(t.name)) | |
+ return subgraph_id | |
+ | |
+ def consumers(self): | |
+ """Return a Python set of all the consumers of this subgraph view. | |
+ | |
+ A consumer of a subgraph view is a tf.Operation which is a consumer | |
+ of one of the output tensors and is not in the subgraph. | |
+ | |
+ Returns: | |
+ A list of `tf.Operation` which are the consumers of this subgraph view. | |
+ """ | |
+ ops_set = frozenset(self._ops) | |
+ res = [] | |
+ for output in self._output_ts: | |
+ consumers = [op for op in output.consumers() if op not in ops_set] | |
+ util.concatenate_unique(res, consumers) | |
+ return res | |
+ | |
+ | |
+def _check_graph(sgv, graph): | |
+ """Check if sgv belongs to the given graph. | |
+ | |
+ Args: | |
+ sgv: a SubGraphView. | |
+ graph: a graph or None. | |
+ Returns: | |
+ The SubGraphView sgv. | |
+ Raises: | |
+ TypeError: if sgv is not a SubGraphView or if graph is not None and not | |
+ a tf.Graph. | |
+ ValueError: if the graph of sgv and the given graph are not None and | |
+ different. | |
+ """ | |
+ if not isinstance(sgv, SubGraphView): | |
+ raise TypeError("Expected a SubGraphView, got: {}".format(type(graph))) | |
+ if graph is None or not sgv.graph: | |
+ return sgv | |
+ if not isinstance(graph, tf_ops.Graph): | |
+ raise TypeError("Expected a tf.Graph, got: {}".format(type(graph))) | |
+ if sgv.graph is not graph: | |
+ raise ValueError("Graph mismatch.") | |
+ return sgv | |
+ | |
+ | |
+def make_view(*args, **kwargs): | |
+ """Create a SubGraphView from selected operations and passthrough tensors. | |
+ | |
+ Args: | |
+ *args: list of 1) regular expressions (compiled or not) or 2) (array of) | |
+ `tf.Operation` 3) (array of) `tf.Tensor`. Those objects will be converted | |
+ into a list of operations and a list of candidates for passthrough tensors. | |
+ **kwargs: keyword graph is used 1) to check that the ops and ts are from | |
+ the correct graph 2) for regular expression queries. | |
+ Returns: | |
+ A subgraph view. | |
+ Raises: | |
+ TypeError: if the optional keyword argument graph is not a `tf.Graph` | |
+ or if an argument in args is not an (array of) `tf.Tensor` | |
+ or an (array of) `tf.Operation` or a string or a regular expression. | |
+ ValueError: if one of the keyword arguments is unexpected. | |
+ """ | |
+ # get keyword arguments | |
+ graph = kwargs["graph"] if "graph" in kwargs else None | |
+ | |
+ # already a view? | |
+ if len(args) == 1 and isinstance(args[0], SubGraphView): | |
+ return _check_graph(args[0], graph) | |
+ | |
+ ops, ts = select.select_ops_and_ts(*args, **kwargs) | |
+ sgv = SubGraphView(ops, ts) | |
+ return _check_graph(sgv, graph) | |
+ | |
+ | |
+def make_view_from_scope(scope, graph): | |
+ """Make a subgraph from a name scope. | |
+ | |
+ Args: | |
+ scope: the name of the scope. | |
+ graph: the `tf.Graph`. | |
+ Returns: | |
+ A subgraph view representing the given scope. | |
+ """ | |
+ ops = select.get_name_scope_ops(graph, scope) | |
+ return SubGraphView(ops) | |
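+ | |
+ | |
+# Illustrative usage sketch (added commentary, not part of the original | |
+# TensorFlow code; the graph `g` and the scope name "tower_0" are | |
+# hypothetical): two ways to build a view over a name scope. | |
+# | |
+#   sgv = make_view("^tower_0/.*", graph=g)    # by regex over op names | |
+#   sgv = make_view_from_scope("tower_0", g)   # by name scope | |
+#   print(sgv)  # lists the view's ops, inputs and outputs | |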
diff --git a/tflex_graph_editor/transform.py b/tflex_graph_editor/transform.py | |
new file mode 100644 | |
index 0000000..e33223a | |
--- /dev/null | |
+++ b/tflex_graph_editor/transform.py | |
@@ -0,0 +1,752 @@ | |
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |
+# | |
+# Licensed under the Apache License, Version 2.0 (the "License"); | |
+# you may not use this file except in compliance with the License. | |
+# You may obtain a copy of the License at | |
+# | |
+# http://www.apache.org/licenses/LICENSE-2.0 | |
+# | |
+# Unless required by applicable law or agreed to in writing, software | |
+# distributed under the License is distributed on an "AS IS" BASIS, | |
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+# See the License for the specific language governing permissions and | |
+# limitations under the License. | |
+# ============================================================================== | |
+"""Class to transform an subgraph into another. | |
+""" | |
+ | |
+from __future__ import absolute_import | |
+from __future__ import division | |
+from __future__ import print_function | |
+ | |
+from copy import deepcopy | |
+from functools import partial | |
+from six import iteritems | |
+from six import string_types | |
+from six import StringIO | |
+from tflex_graph_editor import reroute | |
+from tflex_graph_editor import select | |
+from tflex_graph_editor import subgraph | |
+from tflex_graph_editor import util | |
+from tensorflow.python.framework import ops as tf_ops | |
+from tensorflow.python.platform import tf_logging as logging | |
+ | |
+ | |
+__all__ = [ | |
+ "replace_t_with_placeholder_handler", | |
+ "keep_t_if_possible_handler", | |
+ "assign_renamed_collections_handler", | |
+ "transform_op_if_inside_handler", | |
+ "copy_op_handler", | |
+ "Transformer", | |
+ "TransformerInfo", | |
+ "copy", | |
+ "copy_with_input_replacements", | |
+ "graph_replace", | |
+] | |
+ | |
+ | |
+def replace_t_with_placeholder_handler(info, t): | |
+ """Transform a tensor into a placeholder tensor. | |
+ | |
+ This handler is typically used to transform a subgraph input tensor into a | |
+ placeholder. | |
+ | |
+ Args: | |
+ info: Transform._TmpInfo instance. | |
+ t: tensor to be transformed into a placeholder. | |
+ Returns: | |
+ The tensor generated by the newly created placeholder. | |
+ """ | |
+ with info.graph_.as_default(): | |
+ t_ = util.make_placeholder_from_tensor(t, scope=info.scope_) | |
+ return t_ | |
+ | |
+ | |
+def keep_t_if_possible_handler(info, t): | |
+ """Transform a tensor into itself (identity) if possible. | |
+ | |
+ This handler transforms a tensor into itself if the source and destination | |
+ graphs are the same. Otherwise it will create a placeholder. | |
+ This handler is typically used to transform hidden input tensors. | |
+ | |
+ Args: | |
+ info: Transform._TmpInfo instance. | |
+ t: tensor to be transformed (or kept as is). | |
+ Returns: | |
+ `t` itself, or the tensor generated by a newly created placeholder. | |
+ """ | |
+ if info.graph is info.graph_: | |
+ return t | |
+ else: | |
+ return replace_t_with_placeholder_handler(info, t) | |
+ | |
+ | |
+def assign_renamed_collections_handler(info, elem, elem_): | |
+ """Add the transformed elem to the (renamed) collections of elem. | |
+ | |
+ A collection is renamed only if it is not a known key, as described in | |
+ `tf.compat.v1.GraphKeys`. | |
+ | |
+ Args: | |
+ info: Transform._TmpInfo instance. | |
+ elem: the original element (`tf.Tensor` or `tf.Operation`) | |
+ elem_: the transformed element | |
+ """ | |
+ known_collection_names = util.get_predefined_collection_names() | |
+ for name, collection in iteritems(info.collections): | |
+ if elem not in collection: | |
+ continue | |
+ | |
+ if name in known_collection_names: | |
+ transformed_name = name | |
+ else: | |
+ transformed_name = info.new_name(name) | |
+ info.graph_.add_to_collection(transformed_name, elem_) | |
+ | |
+ | |
+def transform_op_if_inside_handler(info, op, keep_if_possible=True): | |
+ """Transform an optional op only if it is inside the subgraph. | |
+ | |
+ This handler is typically used to handle original ops: it is fine to keep | |
+ them if they are inside the subgraph, otherwise they are just ignored. | |
+ | |
+ Args: | |
+ info: Transform._TmpInfo instance. | |
+ op: the optional op to transform (or ignore). | |
+ keep_if_possible: re-attach to the original op if possible, that is, | |
+ if the source graph and the destination graph are the same. | |
+ Returns: | |
+ The transformed op or None. | |
+ """ | |
+ if op in info.sgv.ops: | |
+ return info.transformed_ops[op] | |
+ else: | |
+ if keep_if_possible and info.graph is info.graph_: | |
+ return op | |
+ else: | |
+ return None | |
+ | |
+ | |
+def copy_op_handler(info, op, new_inputs, copy_shape=False, nodedef_fn=None): | |
+ """Copy a `tf.Operation`. | |
+ | |
+ Args: | |
+ info: Transform._TmpInfo instance. | |
+ op: the `tf.Operation` to be copied. | |
+ new_inputs: The new inputs for this op. | |
+ copy_shape: also copy the shape of the tensor | |
+ nodedef_fn: If provided, a function that will be run on the NodeDef | |
+ and should return a mutated NodeDef before a new Operation is created. | |
+ This is useful as certain features cannot be set on the Operation and | |
+ must be modified in NodeDef. | |
+ | |
+ Returns: | |
+ A `(op, op_outputs)` tuple containing the transformed op and its outputs. | |
+ """ | |
+ # The `new_inputs` argument was added to this function. For compatibility | |
+ # reasons, let's raise an error if `new_inputs` is a boolean. | |
+ if isinstance(new_inputs, bool): | |
+ raise TypeError("the `new_inputs` argument must be an iterable.") | |
+ | |
+ # pylint: disable=protected-access | |
+ | |
+ # Clone the node def: | |
+ node_def_ = deepcopy(op.node_def) | |
+ | |
+ # Transform name: | |
+ name_ = info.new_name(op.name) | |
+ name_ = info.graph_.unique_name(name_) | |
+ node_def_.name = name_ | |
+ | |
+ # Mutate NodeDef if requested: | |
+ if nodedef_fn is not None: | |
+ node_def_ = nodedef_fn(node_def_) | |
+ | |
+ # Copy the other inputs needed for initialization | |
+ output_types_ = op._output_types[:] | |
+ input_types_ = op._input_types[:] | |
+ | |
+ # Make a copy of the op_def too. | |
+ # It's unique to every _type_ of Operation. | |
+ op_def_ = deepcopy(op.op_def) | |
+ | |
+ # Initialize a new Operation instance | |
+ op_ = tf_ops.Operation(node_def_, info.graph_, new_inputs, output_types_, | |
+ [], input_types_, None, op_def_) | |
+ | |
+ # copy the shape over | |
+ if copy_shape: | |
+ for t, t_ in zip(op.outputs, op_.outputs): | |
+ t_.set_shape(t.get_shape()) | |
+ | |
+ # Original op cannot be finalised here yet. Because some ops require this | |
+ # attribute to exist, we will create a dummy original_op first and then | |
+ # later finalise it with the actual original_op when all the ops have | |
+ # been copied. | |
+ # TODO(fkp): Stop worrying about _original_op and remove this code? | |
+ if op._original_op: | |
+ op_._original_op = op._original_op | |
+ | |
+ return op_, op_.outputs | |
+ | |
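+# Illustrative usage sketch (added commentary, not part of the original | |
+# TensorFlow code; `_pin_to_cpu` is a hypothetical helper): a Transformer | |
+# (defined below) can be given a customized copy handler by binding | |
+# `nodedef_fn`, e.g. to pin every copied op to a device. | |
+# | |
+#   def _pin_to_cpu(node_def): | |
+#     node_def.device = "/device:CPU:0"  # NodeDef protos carry a device field | |
+#     return node_def | |
+# | |
+#   transformer = Transformer() | |
+#   transformer.transform_op_handler = partial( | |
+#       copy_op_handler, nodedef_fn=_pin_to_cpu) | |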
+ | |
+class TransformerInfo(object): | |
+ """"Contains information about the result of a transform operation.""" | |
+ | |
+ def __init__(self, info): | |
+ """Constructor. | |
+ | |
+ Args: | |
+ info: an instance of Transformer._TmpInfo containing various internal | |
+ information about the transform operation. | |
+ """ | |
+ self._graph = info.graph | |
+ self._scope = info.scope | |
+ self._graph_ = info.graph_ | |
+ self._scope_ = info.scope_ | |
+ self._transformed_ops = info.transformed_ops | |
+ self._transformed_ts = info.transformed_ts | |
+ | |
+ def _get_transformed_map(self, top): | |
+ """Return the correct container depending on the type of `top`.""" | |
+ if isinstance(top, tf_ops.Operation): | |
+ return self._transformed_ops | |
+ elif isinstance(top, tf_ops.Tensor): | |
+ return self._transformed_ts | |
+ else: | |
+ raise TypeError( | |
+ "Expected a tf.Tensor or a tf.Operation, got a {}".format( | |
+ type(top))) | |
+ | |
+ def _transformed_elem(self, original_top, missing_fn=None): | |
+ """Return the transformed op/tensor corresponding to the original one. | |
+ | |
+ Args: | |
+ original_top: the original tensor/operation. | |
+ missing_fn: function handling the case where the counterpart | |
+ cannot be found. By default, None is returned. | |
+ Returns: | |
+ the transformed tensor/operation (or None if no match is found). | |
+ """ | |
+ transformed_map = self._get_transformed_map(original_top) | |
+ if isinstance(original_top, string_types): | |
+ for original, transformed in iteritems(transformed_map): | |
+ if original.name == original_top: | |
+ return transformed | |
+ return None if missing_fn is None else missing_fn(original_top) | |
+ else: | |
+ if original_top not in transformed_map: | |
+ return None if missing_fn is None else missing_fn(original_top) | |
+ return transformed_map[original_top] | |
+ | |
+ def _original_elem(self, transformed_top, missing_fn=None): | |
+ """Return the original op/tensor corresponding to the transformed one. | |
+ | |
+ Args: | |
+ transformed_top: the transformed tensor/operation. | |
+ missing_fn: function handling the case where the counterpart | |
+ cannot be found. By default, None is returned. | |
+ Returns: | |
+ the original tensor/operation (or None if no match is found). | |
+ """ | |
+ transformed_map = self._get_transformed_map(transformed_top) | |
+ if isinstance(transformed_top, string_types): | |
+ finder = lambda transformed: transformed.name == transformed_top | |
+ else: | |
+ finder = lambda transformed: transformed == transformed_top | |
+ for original, transformed in iteritems(transformed_map): | |
+ if finder(transformed): | |
+ return original | |
+ return None if missing_fn is None else missing_fn(transformed_top) | |
+ | |
+ def transformed(self, original, missing_fn=None): | |
+ """Return the transformed op/tensor corresponding to the original one. | |
+ | |
+ Note that the output of this function mimics the hierarchy | |
+ of its input argument `original`. | |
+ Given an iterable, it returns a list. Given an operation or a tensor, | |
+ it will return an operation or a tensor. | |
+ | |
+ Args: | |
+ original: the original tensor/operation. | |
+ missing_fn: function handling the case where the counterpart | |
+ cannot be found. By default, None is returned. | |
+ Returns: | |
+ the transformed tensor/operation (or None if no match is found). | |
+ """ | |
+ transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn) | |
+ return util.transform_tree(original, transformed_elem) | |
+ | |
+ def original(self, transformed, missing_fn=None): | |
+ """Return the original op/tensor corresponding to the transformed one. | |
+ | |
+ Note that the output of this function mimics the hierarchy | |
+ of its input argument `transformed`. | |
+ Given an iterable, it returns a list. Given an operation or a tensor, | |
+ it will return an operation or a tensor. | |
+ | |
+ Args: | |
+ transformed: the transformed tensor/operation. | |
+ missing_fn: function handling the case where the counterpart | |
+ cannot be found. By default, None is returned. | |
+ Returns: | |
+ the original tensor/operation (or None if no match is found). | |
+ """ | |
+ original_elem = partial(self._original_elem, missing_fn=missing_fn) | |
+ return util.transform_tree(transformed, original_elem) | |
+ | |
+ def __str__(self): | |
+ res = StringIO() | |
+ print("Transform result info:", file=res) | |
+ if self._graph == self._graph_: | |
+ in_place_str = "" if self._scope_ else " IN-PLACE" | |
+ print(" Within graph[{}]{}".format( | |
+ id(self._graph), in_place_str), file=res) | |
+ else: | |
+ print(" graph[{}] => graph[{}]".format( | |
+ id(self._graph), id(self._graph_)), file=res) | |
+ if self._scope: | |
+ print(" Relative to source scope: {}".format(self._scope), file=res) | |
+ if self._scope_: | |
+ print(" Scope destination: {}".format(self._scope_), file=res) | |
+ print("Operations mapping:", file=res) | |
+ for op, op_ in iteritems(self._transformed_ops): | |
+ print(" {} => {}".format(op.name, op_.name), file=res) | |
+ return res.getvalue() | |
+ | |
+ | |
+class _TmpInfo(object): | |
+ """Transformer temporary data. | |
+ | |
+ An instance of this class holds all the information relevant to a call | |
+ to a transformer instance (that is, a call to __call__). An instance | |
+ is created for the lifetime of the __call__ function and is passed as | |
+ an argument to the handlers. | |
+ """ | |
+ | |
+ def __init__(self, sgv, dst_graph, dst_scope, src_scope): | |
+ self.sgv = sgv | |
+ self.sgv_inputs_set = frozenset(sgv.inputs) | |
+ self.ops = frozenset(sgv.ops) | |
+ self.control_outputs = util.ControlOutputs(sgv.graph) | |
+ self.graph = sgv.graph | |
+ self.scope = src_scope | |
+ self.graph_ = dst_graph | |
+ self.scope_ = dst_scope | |
+ self.transformed_ops = {} | |
+ self.transformed_ts = {} | |
+ self.collections = dict((key, self.graph.get_collection(key)) | |
+ for key in self.graph.get_all_collection_keys()) | |
+ self.cyclic_ops = [] | |
+ self.transform_original_op_handler = transform_op_if_inside_handler | |
+ # The graph is transformed op by op, in the same order the original ops | |
+ # were created. However, this is sometimes not possible due to cycles | |
+ # (i.e. while loops). So when the transformer creates a new op whose | |
+ # inputs do not exist yet, temporary placeholders are created and stored | |
+ # in this `tmp_cyclic_ts` container. During a second pass, | |
+ # those temporary tensors are replaced by the proper transformed tensors | |
+ # (see the function `_finalize_cycles`). | |
+ self.tmp_cyclic_ts = [] | |
+ | |
+ def new_name(self, name): | |
+ """Compute a destination name from a source name. | |
+ | |
+ Args: | |
+ name: the name to be "transformed". | |
+ Returns: | |
+ The transformed name. | |
+ Raises: | |
+ ValueError: if the source scope is used (that is, not an empty string) | |
+ and the source name does not belong to the source scope. | |
+ """ | |
+ scope = self.scope | |
+ if not name.startswith(scope): | |
+ raise ValueError("{} does not belong to source scope: {}.".format( | |
+ name, scope)) | |
+ rel_name = name[len(scope):] | |
+ name_ = self.scope_ + rel_name | |
+ return name_ | |
+ | |
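+ # Worked example (added commentary, not part of the original TensorFlow | |
+ # code): with scope="a/" and scope_="b/", new_name("a/x/y") returns | |
+ # "b/x/y", while new_name("c/z") raises ValueError because the name lies | |
+ # outside the source scope. | |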
+ | |
+class Transformer(object): | |
+ """Transform a subgraph into another one. | |
+ | |
+ By default, the constructor creates a transformer which copies a subgraph | |
+ and replaces inputs with placeholders. This behavior can be modified by | |
+ changing the handlers. | |
+ """ | |
+ | |
+ def __init__(self): | |
+ """Transformer constructor. | |
+ | |
+ The following members can be modified: | |
+ transform_op_handler: handle the transformation of a `tf.Operation`. | |
+ This handler defaults to a simple copy. | |
+ assign_collections_handler: handle the assignment of collections. | |
+ This handler defaults to assigning new collections created under the | |
+ given name-scope. | |
+ transform_external_input_handler: handle the transform of the inputs to | |
+ the given subgraph. This handler defaults to creating placeholders | |
+ instead of the ops just before the input tensors of the subgraph. | |
+ transform_external_hidden_input_handler: handle the transform of the | |
+ hidden inputs of the subgraph, that is, the inputs which are not listed | |
+ in sgv.inputs. This handler defaults to a transform which keeps the same | |
+ input if the source and destination graphs are the same, and otherwise | |
+ uses placeholders. | |
+ transform_original_op_handler: handle the transform of original_op. This | |
+ handler defaults to transforming original_op only if it is in the | |
+ subgraph; otherwise it is ignored. | |
+ """ | |
+ | |
+ # handlers | |
+ self.transform_op_handler = copy_op_handler | |
+ self.transform_control_input_handler = transform_op_if_inside_handler | |
+ self.assign_collections_handler = assign_renamed_collections_handler | |
+ self.transform_external_input_handler = replace_t_with_placeholder_handler | |
+ self.transform_external_hidden_input_handler = keep_t_if_possible_handler | |
+ self.transform_original_op_handler = transform_op_if_inside_handler | |
+ | |
+ def __call__(self, | |
+ sgv, | |
+ dst_graph, | |
+ dst_scope, | |
+ src_scope="", | |
+ reuse_dst_scope=False): | |
+ """Execute the transformation. | |
+ | |
+ Args: | |
+ sgv: the source subgraph-view. | |
+ dst_graph: the destination graph. | |
+ dst_scope: the destination scope. | |
+ src_scope: the source scope, which specifies the path from which the | |
+ relative paths of the transformed nodes are computed. For instance, if | |
+ src_scope is a/ and dst_scope is b/, then the node a/x/y will have a | |
+ relative path of x/y and will be transformed into b/x/y. | |
+ reuse_dst_scope: if True the dst_scope is re-used if it already exists. | |
+ Otherwise, the scope is made unique by appending an underscore | |
+ followed by a digit to the given name (default). | |
+ Returns: | |
+ A tuple `(sgv, info)` where: | |
+ `sgv` is the transformed subgraph view; | |
+ `info` is an instance of TransformerInfo containing | |
+ information about the transform, including mapping between | |
+ original and transformed tensors and operations. | |
+ Raises: | |
+ ValueError: if the arguments are invalid. | |
+ """ | |
+ sgv = subgraph.make_view(sgv) | |
+ if not isinstance(dst_graph, tf_ops.Graph): | |
+ raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph))) | |
+ | |
+ src_scope = util.scope_finalize(src_scope) | |
+ dst_scope = util.scope_finalize(dst_scope) | |
+ | |
+ # Potentially create new scope if reuse_dst_scope is False | |
+ if dst_scope and not reuse_dst_scope: | |
+ dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1])) | |
+ | |
+ # Create temporary info used during this transform call | |
+ info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope) | |
+ | |
+ self._copy_ops(info) | |
+ self._finalize_cycles(info) | |
+ self._connect_control_inputs(info) | |
+ | |
+ # Compute information about the transformation | |
+ res_info = TransformerInfo(info) | |
+ sgv_ = self._transform_sgv(info, sgv) | |
+ return sgv_, res_info | |
+ | |
+ def _copy_ops(self, info): | |
+ """Copy ops without connecting them.""" | |
+ sorted_ops = sorted(info.sgv.ops, key=lambda op: op._id) # pylint: disable=protected-access | |
+ for op in sorted_ops: | |
+ new_inputs = [self._transformed_t(info, t, op) for t in op.inputs] | |
+ op_, op_outputs_ = self.transform_op_handler(info, op, new_inputs) | |
+ if op is op_: | |
+ raise ValueError("In-place transformation not allowed.") | |
+ | |
+ # Process op. | |
+ info.transformed_ops[op] = op_ | |
+ self.assign_collections_handler(info, op, op_) | |
+ | |
+ # Process output tensors. | |
+ for op_output, op_output_ in zip(op.outputs, op_outputs_): | |
+ info.transformed_ts[op_output] = op_output_ | |
+ self.assign_collections_handler(info, op_output, op_output_) | |
+ | |
+ def _finalize_cycles(self, info): | |
+ """Reconnects the cyclic tensors.""" | |
+ for t, tmp_t_, consumer_op in info.tmp_cyclic_ts: | |
+ if t not in info.transformed_ts: | |
+ raise ValueError("The tensor {} should be transformed by now.".format( | |
+ t.name)) | |
+ if consumer_op not in info.transformed_ops: | |
+ raise ValueError("The op {} should be transformed by now.".format( | |
+ consumer_op.name)) | |
+ t_ = info.transformed_ts[t] | |
+ consumer_op_ = info.transformed_ops[consumer_op] | |
+ t_index_ = list(consumer_op_.inputs).index(tmp_t_) | |
+ consumer_op_._update_input(t_index_, t_) # pylint: disable=protected-access | |
+ | |
+ def _connect_control_inputs(self, info): | |
+ """Connect the previously copied ops.""" | |
+ for op in info.sgv.ops: | |
+ logging.debug("Connecting control inputs of op: %s", op.name) | |
+ op_ = info.transformed_ops[op] | |
+ | |
+ # Finalize original op. | |
+ # TODO(fkp): Stop worrying about _original_op and remove this code? | |
+ # pylint: disable=protected-access | |
+ if op._original_op: | |
+ original_op = self.transform_original_op_handler(info, op._original_op) | |
+ if original_op is None: | |
+ logging.debug("Could not find original op for: %s", op_.name) | |
+ else: | |
+ op_._original_op = original_op | |
+ # pylint: enable=protected-access | |
+ | |
+ # Finalize control inputs: | |
+ control_inputs_ = [self.transform_control_input_handler(info, ci) | |
+ for ci in op.control_inputs] | |
+ control_inputs_ = [ci for ci in control_inputs_ if ci is not None] | |
+ reroute.add_control_inputs(op_, control_inputs_) | |
+ | |
+ def _transform_sgv(self, info, sgv): | |
+ """Transform a subgraph view. | |
+ | |
+ For convenience, a transform operation returns a subgraph view of the | |
+ transformed graph. | |
+ | |
+ Args: | |
+ info: Temporary information for this transform call. | |
+ sgv: the subgraph to be transformed. | |
+ Returns: | |
+ The transformed subgraph. | |
+ """ | |
+ ops_ = [op_ for _, op_ in iteritems(info.transformed_ops)] | |
+ sgv_ = subgraph.SubGraphView(ops_) | |
+ sgv_inputs_ = sgv_.inputs | |
+ sgv_outputs_ = sgv_.outputs | |
+ | |
+ # re-order inputs | |
+ input_map_ = [] | |
+ for input_t in sgv.inputs: | |
+ if input_t not in info.transformed_ts: | |
+ continue | |
+ input_t_ = info.transformed_ts[input_t] | |
+ if input_t_ not in sgv_inputs_: | |
+ continue | |
+ input_t_index_ = sgv_.input_index(input_t_) | |
+ input_map_.append(input_t_index_) | |
+ | |
+ # re-order outputs | |
+ output_map_ = [] | |
+ for output_t in sgv.outputs: | |
+ if output_t not in info.transformed_ts: | |
+ continue | |
+ output_t_ = info.transformed_ts[output_t] | |
+ if output_t_ not in sgv_outputs_: | |
+ continue | |
+ output_t_index_ = sgv_.output_index(output_t_) | |
+ output_map_.append(output_t_index_) | |
+ | |
+ return sgv_.remap(input_map_, output_map_) | |
+ | |
+ def _transformed_t(self, info, t, consumer_op): | |
+ """Return tre transformed tensor of `t`.""" | |
+ if t in info.transformed_ts: | |
+ # If op is in the subgraph, just return its transformed counterpart. | |
+ return info.transformed_ts[t] | |
+ | |
+ if t in info.sgv_inputs_set: | |
+ # `t` is an input of the subgraph. | |
+ return self.transform_external_input_handler(info, t) | |
+ elif t.op in info.ops: | |
+ # `t` is an internal tensor but is not transformed yet because it | |
+ # belongs to a graph cycle. | |
+ logging.debug("Cyclic tensor: t.name = %s", t.name) | |
+ # Try to find an existing tensor we can use for now, | |
+ # otherwise create one. We'll rewire this later. | |
+ if consumer_op.type == "Merge": | |
+ first_input = consumer_op.inputs[0] | |
+ tmp_t_ = self._transformed_t(info, first_input, consumer_op) | |
+ elif t.op.type == "Enter": | |
+ enter_input = t.op.inputs[0] | |
+ tmp_t_ = self._transformed_t(info, enter_input, consumer_op) | |
+ else: | |
+ with info.graph_.as_default(): | |
+ tmp_t_ = util.make_placeholder_from_tensor(t, scope=info.scope_, | |
+ prefix="geph_tmp") | |
+ logging.debug("Created temporary placeholder: %s.", tmp_t_.name) | |
+ # Register as temporary and return. | |
+ info.tmp_cyclic_ts.append((t, tmp_t_, consumer_op)) | |
+ return tmp_t_ | |
+ else: | |
+ # `t` is a hidden input of the subgraph. | |
+ return self.transform_external_hidden_input_handler(info, t) | |
+ | |
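+# Minimal usage sketch with hypothetical `my_sgv` and `dst_graph`: swap in a | |
+# different input handler so that a same-graph copy keeps its original | |
+# inputs instead of creating placeholders. | |
+# | |
+#   transformer = Transformer() | |
+#   transformer.transform_external_input_handler = keep_t_if_possible_handler | |
+#   sgv_, info = transformer(my_sgv, dst_graph, dst_scope="dup") | |
+ | |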
+ | |
+def copy(sgv, dst_graph=None, dst_scope="", src_scope="", | |
+ reuse_dst_scope=False): | |
+ """Copy a subgraph. | |
+ | |
+ Args: | |
+ sgv: the source subgraph-view. This argument is converted to a subgraph | |
+ using the same rules as the function subgraph.make_view. | |
+ dst_graph: the destination graph. | |
+ dst_scope: the destination scope. | |
+ src_scope: the source scope. | |
+ reuse_dst_scope: if True the dst_scope is re-used if it already exists. | |
+ Otherwise, the scope is made unique by appending an underscore | |
+ followed by a digit to the given name (default). | |
+ Returns: | |
+ A tuple `(sgv, info)` where: | |
+ `sgv` is the transformed subgraph view; | |
+ `info` is an instance of TransformerInfo containing | |
+ information about the transform, including mapping between | |
+ original and transformed tensors and operations. | |
+ Raises: | |
+ TypeError: if `dst_graph` is not a `tf.Graph`. | |
+ StandardError: if sgv cannot be converted to a SubGraphView using | |
+ the same rules as the function subgraph.make_view. | |
+ """ | |
+ sgv = subgraph.make_view(sgv) | |
+ if dst_graph is None: | |
+ dst_graph = sgv.graph | |
+ if not isinstance(dst_graph, tf_ops.Graph): | |
+ raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph))) | |
+ | |
+ copier = Transformer() | |
+ return copier( | |
+ sgv, dst_graph, dst_scope, src_scope, reuse_dst_scope=reuse_dst_scope) | |
+ | |
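+# Illustrative sketch (assumes hypothetical ops `op_a` and `op_b` in the | |
+# default graph): copy them under a fresh scope, then look up a transformed | |
+# op via the returned TransformerInfo. | |
+# | |
+#   sgv_, info = copy([op_a, op_b], dst_scope="copied") | |
+#   op_b_copy = info.transformed(op_b) | |
+ | |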
+ | |
+def copy_with_input_replacements(sgv, replacement_ts, | |
+ dst_graph=None, dst_scope="", src_scope="", | |
+ reuse_dst_scope=False): | |
+ """Copy a subgraph, replacing some of its inputs. | |
+ | |
+ Note that a replacement only happens if the tensor to be replaced | |
+ is an input of the given subgraph. The inputs of a subgraph can | |
+ be queried using sgv.inputs. | |
+ | |
+ Args: | |
+ sgv: the source subgraph-view. This argument is converted to a subgraph | |
+ using the same rules as the function subgraph.make_view. | |
+ replacement_ts: dictionary mapping from original tensors to their | |
+ replacements. | |
+ dst_graph: the destination graph. | |
+ dst_scope: the destination scope. | |
+ src_scope: the source scope. | |
+ reuse_dst_scope: if True the dst_scope is re-used if it already exists. | |
+ Otherwise, the scope is made unique by appending an underscore | |
+ followed by a digit to the given name (default). | |
+ Returns: | |
+ A tuple `(sgv, info)` where: | |
+ `sgv` is the transformed subgraph view; | |
+ `info` is an instance of TransformerInfo containing | |
+ information about the transform, including mapping between | |
+ original and transformed tensors and operations. | |
+ Raises: | |
+ TypeError: if dst_graph is not a tf.Graph. | |
+ StandardError: if sgv cannot be converted to a SubGraphView using | |
+ the same rules as the function subgraph.make_view. | |
+ """ | |
+ sgv = subgraph.make_view(sgv) | |
+ if dst_graph is None: | |
+ dst_graph = sgv.graph | |
+ if not isinstance(dst_graph, tf_ops.Graph): | |
+ raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph))) | |
+ | |
+ copier = Transformer() | |
+ # Replace tensor if possible. | |
+ def replace_t_with_replacement_handler(info, t): | |
+ if t in replacement_ts: | |
+ return replacement_ts[t] | |
+ else: | |
+ return keep_t_if_possible_handler(info, t) | |
+ copier.transform_external_input_handler = replace_t_with_replacement_handler | |
+ return copier( | |
+ sgv, dst_graph, dst_scope, src_scope, reuse_dst_scope=reuse_dst_scope) | |
+ | |
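+# Hedged sketch with hypothetical tensors `y`, `x` and `x_new`: duplicate | |
+# the op computing `y` while rerouting its input `x` to `x_new`. Only | |
+# tensors listed in the subgraph's inputs are replaced, as documented above. | |
+# | |
+#   _, info = copy_with_input_replacements([y.op], {x: x_new}) | |
+#   y_new = info.transformed(y) | |
+ | |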
+ | |
+def _add_control_flow_ops(ops, control_ios): | |
+ """Complete `ops` so that the transformed graph is valid. | |
+ | |
+ Partially copying a graph can lead to a malformed graph. For instance, | |
+ copying half of a while construct is likely to result in an invalid graph. | |
+ This function attempts to add missing ops so that the transformation results | |
+ in a valid graph. | |
+ | |
+ Args: | |
+ ops: list of ops (modified in-place). | |
+ control_ios: object created by a call to `util.ControlOutputs`. | |
+ """ | |
+ # Find while contexts. | |
+ control_flow_contexts = set() | |
+ for op in ops: | |
+ cfc = op._control_flow_context # pylint: disable=protected-access | |
+ if cfc: | |
+ control_flow_contexts.add(cfc) | |
+ # Find new ops. | |
+ new_ops = [] | |
+ for cfc in control_flow_contexts: | |
+ if cfc.IsWhileContext(): | |
+ new_ops += select.get_walks_intersection_ops( | |
+ [enter_t.op for enter_t in cfc.loop_enters], | |
+ [exit_t.op for exit_t in cfc.loop_exits], | |
+ control_ios=control_ios) | |
+ # Add new ops. | |
+ new_ops_set = set(new_ops) | |
+ ops_set = frozenset(ops) | |
+ for op in new_ops_set: | |
+ if op not in ops_set: | |
+ ops.append(op) | |
+ | |
+ | |
+def graph_replace(target_ts, replacement_ts, dst_scope="", | |
+ src_scope="", reuse_dst_scope=False): | |
+ """Create a new graph which compute the targets from the replaced Tensors. | |
+ | |
+ Args: | |
+ target_ts: a single tf.Tensor or an iterable of tf.Tensor. | |
+ replacement_ts: dictionary mapping from original tensors to replacement tensors. | |
+ dst_scope: the destination scope. | |
+ src_scope: the source scope. | |
+ reuse_dst_scope: if True the dst_scope is re-used if it already exists. | |
+ Otherwise, the scope is made unique by appending an underscore | |
+ followed by a digit to the given name (default). | |
+ Returns: | |
+ A single tf.Tensor or a list of target tf.Tensor, depending on | |
+ the type of the input argument `target_ts`. | |
+ The returned tensors are recomputed using the tensors from replacement_ts. | |
+ Raises: | |
+ ValueError: if the targets are not connected to replacement_ts. | |
+ """ | |
+ # Identify operations in the graph that will change. | |
+ # Start forward walk at Tensors that will be replaced, and | |
+ # backward walk at the target output Tensors. | |
+ flatten_target_ts = util.flatten_tree(target_ts) | |
+ # Construct the forward control dependencies edges so that | |
+ # the get_walks_intersection_ops can also traverse the | |
+ # control dependencies. | |
+ graph = util.get_unique_graph(flatten_target_ts, check_types=(tf_ops.Tensor)) | |
+ control_ios = util.ControlOutputs(graph) | |
+ ops = select.get_walks_intersection_ops( | |
+ list(replacement_ts), flatten_target_ts, control_ios=control_ios) | |
+ if not ops: | |
+ raise ValueError("Targets and replacements are not connected!") | |
+ | |
+ # Complete ops to avoid malformed control flow. | |
+ # TODO(fkp): Consider moving this function deeper (in the transformer?). | |
+ _add_control_flow_ops(ops, control_ios) | |
+ | |
+ # Create a copy of the relevant subgraph | |
+ unused_sgv_, info = copy_with_input_replacements( | |
+ ops, replacement_ts, None, dst_scope, src_scope, reuse_dst_scope) | |
+ # Return the transformed targets but keep the original if the transformed | |
+ # counterpart cannot be found | |
+ missing_fn = lambda original_t: original_t | |
+ return info.transformed(target_ts, missing_fn) | |
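+# Usage sketch with hypothetical tensors, mirroring the docstring: recompute | |
+# `target` as if `b` had been `b_new`, leaving the original ops untouched. | |
+# | |
+#   a = tf.constant(1.0, name="a") | |
+#   b = tf.constant(2.0, name="b") | |
+#   target = tf.identity(a + b, name="target") | |
+#   b_new = tf.constant(5.0, name="b_new") | |
+#   target_new = graph_replace(target, {b: b_new})  # computes 1.0 + 5.0 | |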
diff --git a/tflex_graph_editor/util.py b/tflex_graph_editor/util.py | |
new file mode 100644 | |
index 0000000..543c1da | |
--- /dev/null | |
+++ b/tflex_graph_editor/util.py | |
@@ -0,0 +1,566 @@ | |
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |
+# | |
+# Licensed under the Apache License, Version 2.0 (the "License"); | |
+# you may not use this file except in compliance with the License. | |
+# You may obtain a copy of the License at | |
+# | |
+# http://www.apache.org/licenses/LICENSE-2.0 | |
+# | |
+# Unless required by applicable law or agreed to in writing, software | |
+# distributed under the License is distributed on an "AS IS" BASIS, | |
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+# See the License for the specific language governing permissions and | |
+# limitations under the License. | |
+# ============================================================================== | |
+"""Utility functions for the graph_editor. | |
+""" | |
+ | |
+from __future__ import absolute_import | |
+from __future__ import division | |
+from __future__ import print_function | |
+ | |
+import re | |
+from six import iteritems | |
+from tensorflow.python.framework import ops as tf_ops | |
+from tensorflow.python.ops import array_ops as tf_array_ops | |
+from tensorflow.python.util.compat import collections_abc | |
+ | |
+__all__ = [ | |
+ "make_list_of_op", | |
+ "get_tensors", | |
+ "make_list_of_t", | |
+ "get_generating_ops", | |
+ "get_consuming_ops", | |
+ "ControlOutputs", | |
+ "placeholder_name", | |
+ "make_placeholder_from_tensor", | |
+ "make_placeholder_from_dtype_and_shape", | |
+] | |
+ | |
+ | |
+# The graph editor sometimes needs to create placeholders; they are named | |
+# "geph_*". "geph" stands for Graph-Editor PlaceHolder. | |
+_DEFAULT_PLACEHOLDER_PREFIX = "geph" | |
+ | |
+ | |
+def concatenate_unique(la, lb): | |
+ """Add all the elements of `lb` to `la` if they are not there already. | |
+ | |
+ The elements added to `la` maintain ordering with respect to `lb`. | |
+ | |
+ Args: | |
+ la: List of Python objects. | |
+ lb: List of Python objects. | |
+ Returns: | |
+ `la`: The list `la` with missing elements from `lb`. | |
+ """ | |
+ la_set = set(la) | |
+ for l in lb: | |
+ if l not in la_set: | |
+ la.append(l) | |
+ la_set.add(l) | |
+ return la | |
+ | |
+ | |
+# TODO(fkp): very generic code; it should be moved to a more generic place. | |
+class ListView(object): | |
+ """Immutable list wrapper. | |
+ | |
+ This class is strongly inspired by the one in tf.Operation. | |
+ """ | |
+ | |
+ def __init__(self, list_): | |
+ if not isinstance(list_, list): | |
+ raise TypeError("Expected a list, got: {}.".format(type(list_))) | |
+ self._list = list_ | |
+ | |
+ def __iter__(self): | |
+ return iter(self._list) | |
+ | |
+ def __len__(self): | |
+ return len(self._list) | |
+ | |
+ def __bool__(self): | |
+ return bool(self._list) | |
+ | |
+ # Python 3 wants __bool__, Python 2.7 wants __nonzero__ | |
+ __nonzero__ = __bool__ | |
+ | |
+ def __getitem__(self, i): | |
+ return self._list[i] | |
+ | |
+ def __add__(self, other): | |
+ if not isinstance(other, list): | |
+ other = list(other) | |
+ return list(self) + other | |
+ | |
+ | |
+# TODO(fkp): very generic code; it should be moved to a more generic place. | |
+def is_iterable(obj): | |
+ """Return true if the object is iterable.""" | |
+ if isinstance(obj, tf_ops.Tensor): | |
+ return False | |
+ try: | |
+ _ = iter(obj) | |
+ except Exception: # pylint: disable=broad-except | |
+ return False | |
+ return True | |
+ | |
+ | |
+def flatten_tree(tree, leaves=None): | |
+ """Flatten a tree into a list. | |
+ | |
+ Args: | |
+ tree: iterable or not. If iterable, its elements (children) can also be | |
+ iterable or not. | |
+ leaves: list to which the tree leaves are appended (None by default). | |
+ Returns: | |
+ A list of all the leaves in the tree. | |
+ """ | |
+ if leaves is None: | |
+ leaves = [] | |
+ if isinstance(tree, dict): | |
+ for _, child in iteritems(tree): | |
+ flatten_tree(child, leaves) | |
+ elif is_iterable(tree): | |
+ for child in tree: | |
+ flatten_tree(child, leaves) | |
+ else: | |
+ leaves.append(tree) | |
+ return leaves | |
+ | |
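+# Behavior sketch: nested containers are flattened depth-first into a list | |
+# of leaves (dict values are visited in dict iteration order). | |
+# | |
+#   flatten_tree([1, (2, [3]), {"k": 4}])  # -> [1, 2, 3, 4] | |
+ | |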
+ | |
+def transform_tree(tree, fn, iterable_type=tuple): | |
+ """Transform all the nodes of a tree. | |
+ | |
+ Args: | |
+ tree: iterable or not. If iterable, its elements (children) can also be | |
+ iterable or not. | |
+ fn: function to apply to each leaf. | |
+ iterable_type: type used to construct the resulting tree for unknown | |
+ iterable, typically `list` or `tuple`. | |
+ Returns: | |
+ A tree whose leaves have been transformed by `fn`. | |
+ The hierarchy of the output tree mimics the one of the input tree. | |
+ """ | |
+ if is_iterable(tree): | |
+ if isinstance(tree, dict): | |
+ res = tree.__new__(type(tree)) | |
+ res.__init__( | |
+ (k, transform_tree(child, fn)) for k, child in iteritems(tree)) | |
+ return res | |
+ elif isinstance(tree, tuple): | |
+ # NamedTuple? | |
+ if hasattr(tree, "_asdict"): | |
+ res = tree.__new__(type(tree), **transform_tree(tree._asdict(), fn)) | |
+ else: | |
+ res = tree.__new__(type(tree), | |
+ (transform_tree(child, fn) for child in tree)) | |
+ return res | |
+ elif isinstance(tree, collections_abc.Sequence): | |
+ res = tree.__new__(type(tree)) | |
+ res.__init__(transform_tree(child, fn) for child in tree) | |
+ return res | |
+ else: | |
+ return iterable_type(transform_tree(child, fn) for child in tree) | |
+ else: | |
+ return fn(tree) | |
+ | |
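+# Behavior sketch: the container hierarchy is preserved while `fn` maps each | |
+# leaf, e.g. doubling every leaf of a nested structure. | |
+# | |
+#   transform_tree((1, [2, 3]), lambda x: 2 * x)  # -> (2, [4, 6]) | |
+ | |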
+ | |
+def check_graphs(*args): | |
+ """Check that all the element in args belong to the same graph. | |
+ | |
+ Args: | |
+ *args: a list of objects with an obj.graph property. | |
+ Raises: | |
+ ValueError: if all the elements do not belong to the same graph. | |
+ """ | |
+ graph = None | |
+ for i, sgv in enumerate(args): | |
+ if graph is None and sgv.graph is not None: | |
+ graph = sgv.graph | |
+ elif sgv.graph is not None and sgv.graph is not graph: | |
+ raise ValueError("Argument[{}]: Wrong graph!".format(i)) | |
+ | |
+ | |
+def get_unique_graph(tops, check_types=None, none_if_empty=False): | |
+ """Return the unique graph used by the all the elements in tops. | |
+ | |
+ Args: | |
+ tops: list of elements to check (usually a list of tf.Operation and/or | |
+ tf.Tensor). Or a tf.Graph. | |
+ check_types: check that the element in tops are of given type(s). If None, | |
+ the types (tf.Operation, tf.Tensor) are used. | |
+ none_if_empty: don't raise an error if tops is an empty list, just return | |
+ None. | |
+ Returns: | |
+ The unique graph used by all the tops. | |
+ Raises: | |
+ TypeError: if tops is not an iterable of tf.Operation. | |
+ ValueError: if the graph is not unique. | |
+ """ | |
+ if isinstance(tops, tf_ops.Graph): | |
+ return tops | |
+ if not is_iterable(tops): | |
+ raise TypeError("{} is not iterable".format(type(tops))) | |
+ if check_types is None: | |
+ check_types = (tf_ops.Operation, tf_ops.Tensor) | |
+ elif not is_iterable(check_types): | |
+ check_types = (check_types,) | |
+ g = None | |
+ for op in tops: | |
+ if not isinstance(op, check_types): | |
+ raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str( | |
+ t) for t in check_types]), type(op))) | |
+ if g is None: | |
+ g = op.graph | |
+ elif g is not op.graph: | |
+ raise ValueError("Operation {} does not belong to given graph".format(op)) | |
+ if g is None and not none_if_empty: | |
+ raise ValueError("Can't find the unique graph of an empty list") | |
+ return g | |
+ | |
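+# Behavior sketch (assumes a hypothetical tensor `a` and op `b_op` from the | |
+# same graph): mixed Tensor/Operation lists are accepted by default. | |
+# | |
+#   g = get_unique_graph([a, b_op]) | |
+#   get_unique_graph([], none_if_empty=True)  # -> None instead of raising | |
+ | |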
+ | |
+def make_list_of_op(ops, check_graph=True, allow_graph=True, ignore_ts=False): | |
+ """Convert ops to a list of `tf.Operation`. | |
+ | |
+ Args: | |
+ ops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single | |
+ operation. | |
+ check_graph: if `True` check if all the operations belong to the same graph. | |
+ allow_graph: if `False` a `tf.Graph` cannot be converted. | |
+ ignore_ts: if True, silently ignore `tf.Tensor`. | |
+ Returns: | |
+ A newly created list of `tf.Operation`. | |
+ Raises: | |
+ TypeError: if ops cannot be converted to a list of `tf.Operation` or, | |
+ if `check_graph` is `True`, if all the ops do not belong to the | |
+ same graph. | |
+ """ | |
+ if isinstance(ops, tf_ops.Graph): | |
+ if allow_graph: | |
+ return ops.get_operations() | |
+ else: | |
+ raise TypeError("allow_graph is False: cannot convert a tf.Graph.") | |
+ else: | |
+ if not is_iterable(ops): | |
+ ops = [ops] | |
+ if not ops: | |
+ return [] | |
+ if check_graph: | |
+ check_types = None if ignore_ts else tf_ops.Operation | |
+ get_unique_graph(ops, check_types=check_types) | |
+ return [op for op in ops if isinstance(op, tf_ops.Operation)] | |
+ | |
+ | |
+# TODO(fkp): move this function in tf.Graph? | |
+def get_tensors(graph): | |
+ """get all the tensors which are input or output of an op in the graph. | |
+ | |
+ Args: | |
+ graph: a `tf.Graph`. | |
+ Returns: | |
+ A list of `tf.Tensor`. | |
+ Raises: | |
+ TypeError: if graph is not a `tf.Graph`. | |
+ """ | |
+ if not isinstance(graph, tf_ops.Graph): | |
+ raise TypeError("Expected a graph, got: {}".format(type(graph))) | |
+ ts = [] | |
+ for op in graph.get_operations(): | |
+ ts += op.outputs | |
+ return ts | |
+ | |
+ | |
+def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False): | |
+ """Convert ts to a list of `tf.Tensor`. | |
+ | |
+ Args: | |
+ ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor. | |
+ check_graph: if `True` check if all the tensors belong to the same graph. | |
+ allow_graph: if `False` a `tf.Graph` cannot be converted. | |
+ ignore_ops: if `True`, silently ignore `tf.Operation`. | |
+ Returns: | |
+ A newly created list of `tf.Tensor`. | |
+ Raises: | |
+ TypeError: if `ts` cannot be converted to a list of `tf.Tensor` or, | |
+ if `check_graph` is `True`, if all the ops do not belong to the same graph. | |
+ """ | |
+ if isinstance(ts, tf_ops.Graph): | |
+ if allow_graph: | |
+ return get_tensors(ts) | |
+ else: | |
+ raise TypeError("allow_graph is False: cannot convert a tf.Graph.") | |
+ else: | |
+ if not is_iterable(ts): | |
+ ts = [ts] | |
+ if not ts: | |
+ return [] | |
+ if check_graph: | |
+ check_types = None if ignore_ops else tf_ops.Tensor | |
+ get_unique_graph(ts, check_types=check_types) | |
+ return [t for t in ts if isinstance(t, tf_ops.Tensor)] | |
+ | |
+ | |
+def get_generating_ops(ts): | |
+ """Return all the generating ops of the tensors in `ts`. | |
+ | |
+ Args: | |
+ ts: a list of `tf.Tensor` | |
+ Returns: | |
+ A list of all the generating `tf.Operation` of the tensors in `ts`. | |
+ Raises: | |
+ TypeError: if `ts` cannot be converted to a list of `tf.Tensor`. | |
+ """ | |
+ ts = make_list_of_t(ts, allow_graph=False) | |
+ return [t.op for t in ts] | |
+ | |
+ | |
+def get_consuming_ops(ts): | |
+ """Return all the consuming ops of the tensors in ts. | |
+ | |
+ Args: | |
+ ts: a list of `tf.Tensor` | |
+ Returns: | |
+ A list of all the consuming `tf.Operation` of the tensors in `ts`. | |
+ Raises: | |
+ TypeError: if ts cannot be converted to a list of `tf.Tensor`. | |
+ """ | |
+ ts = make_list_of_t(ts, allow_graph=False) | |
+ ops = [] | |
+ for t in ts: | |
+ for op in t.consumers(): | |
+ if op not in ops: | |
+ ops.append(op) | |
+ return ops | |
+ | |
+ | |
+class ControlOutputs(object): | |
+ """The control outputs topology.""" | |
+ | |
+ def __init__(self, graph): | |
+ """Create a dictionary of control-output dependencies. | |
+ | |
+ Args: | |
+ graph: a `tf.Graph`. | |
+ Returns: | |
+ A dictionary where a key is a `tf.Operation` instance and the | |
+ corresponding value is a list of all the ops which have the key | |
+ as one of their control-input dependencies. | |
+ Raises: | |
+ TypeError: if graph is not a `tf.Graph`. | |
+ """ | |
+ if not isinstance(graph, tf_ops.Graph): | |
+ raise TypeError("Expected a tf.Graph, got: {}".format(type(graph))) | |
+ self._control_outputs = {} | |
+ self._graph = graph | |
+ self._version = None | |
+ self._build() | |
+ | |
+ def update(self): | |
+ """Update the control outputs if the graph has changed.""" | |
+ if self._version != self._graph.version: | |
+ self._build() | |
+ return self | |
+ | |
+ def _build(self): | |
+ """Build the control outputs dictionary.""" | |
+ self._control_outputs.clear() | |
+ ops = self._graph.get_operations() | |
+ for op in ops: | |
+ for control_input in op.control_inputs: | |
+ if control_input not in self._control_outputs: | |
+ self._control_outputs[control_input] = [] | |
+ if op not in self._control_outputs[control_input]: | |
+ self._control_outputs[control_input].append(op) | |
+ self._version = self._graph.version | |
+ | |
+ def get_all(self): | |
+ return self._control_outputs | |
+ | |
+ def get(self, op): | |
+ """return the control outputs of op.""" | |
+ if op in self._control_outputs: | |
+ return self._control_outputs[op] | |
+ else: | |
+ return () | |
+ | |
+ @property | |
+ def graph(self): | |
+ return self._graph | |
+ | |
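+# Usage sketch (assumes a hypothetical `tf.Graph` `g` and one of its ops | |
+# `op`): query the reverse control-dependency map, rebuilding it only when | |
+# the graph version has changed. | |
+# | |
+#   co = ControlOutputs(g) | |
+#   dependents = co.update().get(op)  # ops that list `op` in control_inputs | |
+ | |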
+ | |
+def scope_finalize(scope): | |
+ if scope and scope[-1] != "/": | |
+ scope += "/" | |
+ return scope | |
+ | |
+ | |
+def scope_dirname(scope): | |
+ slash = scope.rfind("/") | |
+ if slash == -1: | |
+ return "" | |
+ return scope[:slash + 1] | |
+ | |
+ | |
+def scope_basename(scope): | |
+ slash = scope.rfind("/") | |
+ if slash == -1: | |
+ return scope | |
+ return scope[slash + 1:] | |
+ | |
+ | |
+def placeholder_name(t=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX): | |
+ """Create placeholder name for the graph editor. | |
+ | |
+ Args: | |
+ t: optional tensor on which the placeholder operation's name will be | |
+ based. | |
+ scope: absolute scope with which to prefix the placeholder's name. None | |
+ means that the scope of t is preserved. "" means the root scope. | |
+ prefix: placeholder name prefix. | |
+ Returns: | |
+ A new placeholder name prefixed by "geph". Note that "geph" stands for | |
+ Graph Editor PlaceHolder. This convention makes it easy to identify | |
+ placeholders generated by the Graph Editor. | |
+ Raises: | |
+ TypeError: if t is neither None nor a tf.Tensor. | |
+ """ | |
+ if scope is not None: | |
+ scope = scope_finalize(scope) | |
+ if t is not None: | |
+ if not isinstance(t, tf_ops.Tensor): | |
+ raise TypeError("Expected a tf.Tenfor, got: {}".format(type(t))) | |
+ op_dirname = scope_dirname(t.op.name) | |
+ op_basename = scope_basename(t.op.name) | |
+ if scope is None: | |
+ scope = op_dirname | |
+ | |
+ if op_basename.startswith("{}__".format(prefix)): | |
+ ph_name = op_basename | |
+ else: | |
+ ph_name = "{}__{}_{}".format(prefix, op_basename, t.value_index) | |
+ | |
+ return scope + ph_name | |
+ else: | |
+ if scope is None: | |
+ scope = "" | |
+ return "{}{}".format(scope, prefix) | |
+ | |
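+# Naming sketch for a hypothetical tensor `t` that is output 0 of an op | |
+# named "a/b/x": | |
+# | |
+#   placeholder_name(t)             # -> "a/b/geph__x_0" | |
+#   placeholder_name(t, scope="")   # -> "geph__x_0" (root scope) | |
+#   placeholder_name()              # -> "geph" | |
+ | |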
+ | |
+def make_placeholder_from_tensor(t, scope=None, | |
+ prefix=_DEFAULT_PLACEHOLDER_PREFIX): | |
+ """Create a `tf.compat.v1.placeholder` for the Graph Editor. | |
+ | |
+ Note that the correct graph scope must be set by the calling function. | |
+ | |
+ Args: | |
+ t: a `tf.Tensor` whose name will be used to create the placeholder (see | |
+ function placeholder_name). | |
+ scope: absolute scope within which to create the placeholder. None means | |
+ that the scope of `t` is preserved. `""` means the root scope. | |
+ prefix: placeholder name prefix. | |
+ | |
+ Returns: | |
+ A newly created `tf.compat.v1.placeholder`. | |
+ Raises: | |
+ TypeError: if `t` is neither `None` nor a `tf.Tensor`. | |
+ """ | |
+ return tf_array_ops.placeholder( | |
+ dtype=t.dtype, shape=t.get_shape(), | |
+ name=placeholder_name(t, scope=scope, prefix=prefix)) | |
+ | |
+ | |
+def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None, | |
+ prefix=_DEFAULT_PLACEHOLDER_PREFIX): | |
+ """Create a tf.compat.v1.placeholder for the Graph Editor. | |
+ | |
+ Note that the correct graph scope must be set by the calling function. | |
+ The placeholder is named using the function placeholder_name (with no | |
+ tensor argument). | |
+ | |
+ Args: | |
+ dtype: the tensor type. | |
+ shape: the tensor shape (optional). | |
+ scope: absolute scope within which to create the placeholder. None means | |
+ that the scope of t is preserved. "" means the root scope. | |
+ prefix: placeholder name prefix. | |
+ | |
+ Returns: | |
+ A newly created tf.placeholder. | |
+ """ | |
+ return tf_array_ops.placeholder( | |
+ dtype=dtype, shape=shape, | |
+ name=placeholder_name(scope=scope, prefix=prefix)) | |
+ | |
+ | |
+_INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$") | |
+ | |
+ | |
+def get_predefined_collection_names(): | |
+ """Return all the predefined collection names.""" | |
+ return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys) | |
+ if not _INTERNAL_VARIABLE_RE.match(key)] | |
+ | |
+ | |
+def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""): | |
+ """Find corresponding op/tensor in a different graph. | |
+ | |
+ Args: | |
+ target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph. | |
+ dst_graph: The graph in which the corresponding graph element must be found. | |
+ dst_scope: A scope which is prepended to the name to look for. | |
+ src_scope: A scope which is removed from the original name of `target`. | |
+ | |
+ Returns: | |
+ The corresponding `tf.Tensor` or `tf.Operation`. | |
+ | |
+ Raises: | |
+ ValueError: if `src_name` does not start with `src_scope`. | |
+ TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation`. | |
+ KeyError: If the corresponding graph element cannot be found. | |
+ """ | |
+ src_name = target.name | |
+ if src_scope: | |
+ src_scope = scope_finalize(src_scope) | |
+ if not src_name.startswith(src_scope): | |
+ raise ValueError("{} does not start with {}".format(src_name, src_scope)) | |
+ src_name = src_name[len(src_scope):] | |
+ | |
+ dst_name = src_name | |
+ if dst_scope: | |
+ dst_scope = scope_finalize(dst_scope) | |
+ dst_name = dst_scope + dst_name | |
+ | |
+ if isinstance(target, tf_ops.Tensor): | |
+ return dst_graph.get_tensor_by_name(dst_name) | |
+ if isinstance(target, tf_ops.Operation): | |
+ return dst_graph.get_operation_by_name(dst_name) | |
+ raise TypeError("Expected tf.Tensor or tf.Operation, got: {}", type(target)) | |
+ | |
+ | |
+def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""): | |
+ """Find corresponding ops/tensors in a different graph. | |
+ | |
+ `targets` is a Python tree, that is, a nested structure of iterables | |
+ (list, tuple, dictionary) whose leaves are instances of | |
+ `tf.Tensor` or `tf.Operation`. | |
+ | |
+ Args: | |
+ targets: A Python tree containing `tf.Tensor` or `tf.Operation` | |
+ belonging to the original graph. | |
+ dst_graph: The graph in which the corresponding graph element must be found. | |
+ dst_scope: A scope which is prepended to the name to look for. | |
+ src_scope: A scope which is removed from the original name of each target. | |
+ | |
+ Returns: | |
+ A Python tree containing the corresponding `tf.Tensor` or `tf.Operation` objects. | |
+ | |
+ Raises: | |
+ ValueError: if `src_name` does not start with `src_scope`. | |
+ TypeError: if a leaf of `targets` is not a `tf.Tensor` or a `tf.Operation`. | |
+ KeyError: If the corresponding graph element cannot be found. | |
+ """ | |
+ def func(top): | |
+ return find_corresponding_elem(top, dst_graph, dst_scope, src_scope) | |
+ return transform_tree(targets, func) | |
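+# Usage sketch (assumes hypothetical tensors `x` and `y` were copied into | |
+# `dst_graph` under the scope "dup"): look up a whole tuple of counterparts | |
+# at once. | |
+# | |
+#   x_dup, y_dup = find_corresponding((x, y), dst_graph, dst_scope="dup") | |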
diff --git a/train.py b/train.py | |
index d497440..bfba032 100755 | |
--- a/train.py | |
+++ b/train.py | |
@@ -12,8 +12,10 @@ import tensorflow as tf | |
import time | |
import tqdm | |
from tensorflow.core.protobuf import rewriter_config_pb2 | |
-from tensorflow.contrib import tpu | |
-from tensorflow.contrib.cluster_resolver import TPUClusterResolver | |
+from tensorflow.compat.v1 import tpu | |
+#import tensorflow.distribute | |
+#TPUClusterResolver = tf.distribute.TPUClusterResolver | |
+from tensorflow.compat.v1.distribute.cluster_resolver import TPUClusterResolver | |
from tensorflow.python import pywrap_tensorflow | |
import model, sample, encoder | |
@@ -189,18 +191,18 @@ def main(tpu_cluster=None): | |
if args.optimizer == 'adam': | |
args.only_train_transformer_layers = True | |
- config = tf.ConfigProto() | |
+ config = tf.compat.v1.ConfigProto() | |
if args.allow_growth: | |
config.gpu_options.allow_growth = True | |
if args.disable_layout_optimizer: | |
config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF | |
- with tf.Session(tpu_cluster, config=config) as sess: | |
+ with tf.compat.v1.Session(tpu_cluster, config=config) as sess: | |
if tpu_cluster and args.init_tpu: | |
print("initializing TPU system...") | |
sess.run(tpu.initialize_system()) | |
if tpu_cluster: | |
print("Using TPU %s" % tpu_cluster) | |
- context = tf.placeholder(tf.int32, [args.batch_size, None]) | |
+ context = tf.compat.v1.placeholder(tf.int32, [args.batch_size, None]) | |
context_in = randomize(context, hparams, args.noise) | |
output = model.model(hparams=hparams, X=context_in) | |
loss = tf.reduce_mean( | |
@@ -208,12 +210,12 @@ def main(tpu_cluster=None): | |
labels=context[:, 1:], logits=output['logits'][:, :-1])) | |
if args.val_every > 0: | |
- val_context = tf.placeholder(tf.int32, [args.val_batch_size, None]) | |
+ val_context = tf.compat.v1.placeholder(tf.int32, [args.val_batch_size, None]) | |
val_output = model.model(hparams=hparams, X=val_context) | |
val_loss = tf.reduce_mean( | |
tf.nn.sparse_softmax_cross_entropy_with_logits( | |
labels=val_context[:, 1:], logits=val_output['logits'][:, :-1])) | |
- val_loss_summary = tf.summary.scalar('val_loss', val_loss) | |
+ val_loss_summary = tf.compat.v1.summary.scalar('val_loss', val_loss) | |
tf_sample = sample.sample_sequence( | |
@@ -226,14 +228,14 @@ def main(tpu_cluster=None): | |
top_p=args.top_p, | |
epsilon=epsilon) | |
- all_vars = [v for v in tf.trainable_variables() if 'model' in v.name] | |
+ all_vars = [v for v in tf.compat.v1.trainable_variables() if 'model' in v.name] | |
train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars | |
parameter_count = sum([np.prod(v.shape.as_list()) for v in train_vars]) | |
print("This model is using %d parameters (%.2fM)" % (parameter_count, parameter_count/(1024.0*1024.0))) | |
- with tf.variable_scope(tf.get_variable_scope().name, reuse=tf.AUTO_REUSE): | |
- global_step = tflex.get_variable('global_step') or tf.get_variable('global_step', shape=(), dtype=tf.int32, trainable=False) | |
+ with tflex.variable_scope(reuse=tf.compat.v1.AUTO_REUSE): | |
+ global_step = model.init_variable('global_step', shape=(), dtype=tf.int32, trainable=False) | |
current_step = args.learning_rate_initial_step | |
global_step.load(current_step, session=sess) | |
if args.learning_rate_cos: | |
@@ -243,9 +245,9 @@ def main(tpu_cluster=None): | |
lr = tf.constant(args.learning_rate) | |
if args.optimizer == 'adam': | |
- opt = tf.train.AdamOptimizer(learning_rate=lr) | |
+ opt = tf.compat.v1.train.AdamOptimizer(learning_rate=lr) | |
elif args.optimizer == 'sgd': | |
- opt = tf.train.GradientDescentOptimizer(learning_rate=lr) | |
+ opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=lr) | |
elif args.optimizer == 'ada': | |
import tensor2tensor.utils.optimize | |
from tensor2tensor.utils import hparam | |
@@ -259,8 +261,9 @@ def main(tpu_cluster=None): | |
exit('Bad optimizer:', args.optimizer) | |
if tpu_cluster: | |
+ pass | |
# https://pulsejet.github.io/blog/posts/tpu-without-estimator/ | |
- from tensorflow.contrib.tpu.python.tpu import tpu_function | |
+ #from tensorflow.contrib.tpu.python.tpu import tpu_function | |
#tpu_function.get_tpu_context().set_number_of_shards(8) | |
#opt = tf.contrib.tpu.CrossShardOptimizer(opt) | |
@@ -273,7 +276,7 @@ def main(tpu_cluster=None): | |
opt_reset = opt.reset() | |
opt_compute = opt.compute_gradients(loss) | |
opt_apply = opt.apply_gradients() | |
- summary_loss = tf.summary.scalar('loss', opt_apply) | |
+ summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply) | |
else: | |
if args.memory_saving_gradients: | |
opt_grads = memory_saving_gradients.gradients(loss, train_vars) | |
@@ -281,12 +284,12 @@ def main(tpu_cluster=None): | |
opt_grads = tf.gradients(loss, train_vars) | |
opt_grads = list(zip(opt_grads, train_vars)) | |
opt_apply = opt.apply_gradients(opt_grads) | |
- summary_loss = tf.summary.scalar('loss', loss) | |
+ summary_loss = tf.compat.v1.summary.scalar('loss', loss) | |
- summary_lr = tf.summary.scalar('learning_rate', lr) | |
- summaries = tf.summary.merge([summary_lr, summary_loss]) | |
+ summary_lr = tf.compat.v1.summary.scalar('learning_rate', lr) | |
+ summaries = tf.compat.v1.summary.merge([summary_lr, summary_loss]) | |
- summary_log = tf.summary.FileWriter( | |
+ summary_log = tf.compat.v1.summary.FileWriter( | |
os.path.join(CHECKPOINT_DIR, args.run_name)) | |
if args.save_graph: | |
@@ -297,7 +300,7 @@ def main(tpu_cluster=None): | |
max_to_keep=args.max_to_keep, | |
keep_checkpoint_every_n_hours=2, | |
reshape=args.truncate_weights) | |
- sess.run(tf.global_variables_initializer()) | |
+ sess.run(tf.compat.v1.global_variables_initializer()) | |
if args.restore_from == 'latest': | |
ckpt = tflex.latest_checkpoint( | |
@@ -528,7 +531,7 @@ def main(tpu_cluster=None): | |
print('trainable variables:') | |
print('name/shape/parameter_count') | |
param_count = 0 | |
- for x in tf.trainable_variables(): | |
+ for x in tf.compat.v1.trainable_variables(): | |
shape = x.shape.as_list() | |
count = np.prod(shape) | |
print(x.name, shape, count) |