Skip to content

Instantly share code, notes, and snippets.

@shawwn
Created November 15, 2019 11:36
Show Gist options
  • Save shawwn/17ede3e698b1031e63b5912c6682d4f0 to your computer and use it in GitHub Desktop.
Save shawwn/17ede3e698b1031e63b5912c6682d4f0 to your computer and use it in GitHub Desktop.
diff --git a/requirements.txt b/requirements.txt
index 0d2556c..a573910 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ tqdm==4.31.1
toposort==1.5
tensor2tensor>=1.14.1
h5py
+tensorflow_addons>=0.6.0
diff --git a/src/generate_unconditional_samples.py b/src/generate_unconditional_samples.py
index 87e5132..176cccd 100755
--- a/src/generate_unconditional_samples.py
+++ b/src/generate_unconditional_samples.py
@@ -54,9 +54,9 @@ def sample_model(
elif length > hparams.n_ctx:
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
- with tf.Session(graph=tf.Graph()) as sess:
+ with tf.compat.v1.Session(graph=tf.Graph()) as sess:
np.random.seed(seed)
- tf.set_random_seed(seed)
+ tf.compat.v1.set_random_seed(seed)
output = sample.sample_sequence(
hparams=hparams, length=length,
@@ -71,6 +71,7 @@ def sample_model(
generated = 0
while nsamples == 0 or generated < nsamples:
+ print('Generating...')
out = sess.run(output)
for i in range(batch_size):
generated += 1
diff --git a/src/hparam.py b/src/hparam.py
new file mode 100644
index 0000000..6495838
--- /dev/null
+++ b/src/hparam.py
@@ -0,0 +1,747 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Hyperparameter values."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import numbers
+import re
+
+import six
+
+import hparam_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.util import compat
+from tensorflow.python.util import deprecation
+
+# Define the regular expression for parsing a single clause of the input
+# (delimited by commas). A legal clause looks like:
+# <variable name>[<index>]? = <rhs>
+# where <rhs> is either a single token or [] enclosed list of tokens.
+# For example: "var[1] = a" or "x = [1,2,3]"
+PARAM_RE = re.compile(r"""
+ (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
+ (\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None
+ \s*=\s*
+ ((?P<val>[^,\[]*) # single value: "a" or None
+ |
+ \[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3"
+ ($|,\s*)""", re.VERBOSE)
+
+
+def _parse_fail(name, var_type, value, values):
+ """Helper function for raising a value error for bad assignment."""
+ raise ValueError(
+ 'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' %
+ (name, var_type.__name__, value, values))
+
+
+def _reuse_fail(name, values):
+ """Helper function for raising a value error for reuse of name."""
+ raise ValueError('Multiple assignments to variable \'%s\' in %s' % (name,
+ values))
+
+
+def _process_scalar_value(name, parse_fn, var_type, m_dict, values,
+ results_dictionary):
+ """Update results_dictionary with a scalar value.
+
+ Used to update the results_dictionary to be returned by parse_values when
+ encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".)
+
+ Mutates results_dictionary.
+
+ Args:
+ name: Name of variable in assignment ("s" or "arr").
+ parse_fn: Function for parsing the actual value.
+ var_type: Type of named variable.
+ m_dict: Dictionary constructed from regex parsing.
+ m_dict['val']: RHS value (scalar)
+ m_dict['index']: List index value (or None)
+ values: Full expression being parsed
+ results_dictionary: The dictionary being updated for return by the parsing
+ function.
+
+ Raises:
+ ValueError: If the name has already been used.
+ """
+ try:
+ parsed_value = parse_fn(m_dict['val'])
+ except ValueError:
+ _parse_fail(name, var_type, m_dict['val'], values)
+
+ # If no index is provided
+ if not m_dict['index']:
+ if name in results_dictionary:
+ _reuse_fail(name, values)
+ results_dictionary[name] = parsed_value
+ else:
+ if name in results_dictionary:
+ # If the name has already been used as a scalar, it will be
+ # in this dictionary and map to a non-dictionary value.
+ if not isinstance(results_dictionary.get(name), dict):
+ _reuse_fail(name, values)
+ else:
+ results_dictionary[name] = {}
+
+ index = int(m_dict['index'])
+ # Make sure the index position hasn't already been assigned a value.
+ if index in results_dictionary[name]:
+ _reuse_fail('{}[{}]'.format(name, index), values)
+ results_dictionary[name][index] = parsed_value
+
+
+def _process_list_value(name, parse_fn, var_type, m_dict, values,
+ results_dictionary):
+ """Update results_dictionary from a list of values.
+
+ Used to update results_dictionary to be returned by parse_values when
+ encountering a clause with a list RHS (e.g. "arr=[1,2,3]".)
+
+ Mutates results_dictionary.
+
+ Args:
+ name: Name of variable in assignment ("arr").
+ parse_fn: Function for parsing individual values.
+ var_type: Type of named variable.
+ m_dict: Dictionary constructed from regex parsing.
+ m_dict['val']: RHS value (scalar)
+ values: Full expression being parsed
+ results_dictionary: The dictionary being updated for return by the parsing
+ function.
+
+ Raises:
+ ValueError: If the name has an index or the values cannot be parsed.
+ """
+ if m_dict['index'] is not None:
+ raise ValueError('Assignment of a list to a list index.')
+ elements = filter(None, re.split('[ ,]', m_dict['vals']))
+ # Make sure the name hasn't already been assigned a value
+ if name in results_dictionary:
+ raise _reuse_fail(name, values)
+ try:
+ results_dictionary[name] = [parse_fn(e) for e in elements]
+ except ValueError:
+ _parse_fail(name, var_type, m_dict['vals'], values)
+
+
+def _cast_to_type_if_compatible(name, param_type, value):
+ """Cast hparam to the provided type, if compatible.
+
+ Args:
+ name: Name of the hparam to be cast.
+ param_type: The type of the hparam.
+ value: The value to be cast, if compatible.
+
+ Returns:
+ The result of casting `value` to `param_type`.
+
+ Raises:
+ ValueError: If the type of `value` is not compatible with param_type.
+ * If `param_type` is a string type, but `value` is not.
+ * If `param_type` is a boolean, but `value` is not, or vice versa.
+ * If `param_type` is an integer type, but `value` is not.
+ * If `param_type` is a float type, but `value` is not a numeric type.
+ """
+ fail_msg = (
+ "Could not cast hparam '%s' of type '%s' from value %r" %
+ (name, param_type, value))
+
+ # If `value` is already of type `param_type`, return it directly.
+ # `isinstance` is too weak (e.g. isinstance(True, int) == True).
+ if type(value) == param_type: # pylint: disable=unidiomatic-typecheck
+ return value
+
+ # Some callers use None, for which we can't do any casting/checking. :(
+ if issubclass(param_type, type(None)):
+ return value
+
+ # Avoid converting a non-string type to a string.
+ if (issubclass(param_type, (six.string_types, six.binary_type)) and
+ not isinstance(value, (six.string_types, six.binary_type))):
+ raise ValueError(fail_msg)
+
+ # Avoid converting a number or string type to a boolean or vice versa.
+ if issubclass(param_type, bool) != isinstance(value, bool):
+ raise ValueError(fail_msg)
+
+ # Avoid converting float to an integer (the reverse is fine).
+ if (issubclass(param_type, numbers.Integral) and
+ not isinstance(value, numbers.Integral)):
+ raise ValueError(fail_msg)
+
+ # Avoid converting a non-numeric type to a numeric type.
+ if (issubclass(param_type, numbers.Number) and
+ not isinstance(value, numbers.Number)):
+ raise ValueError(fail_msg)
+
+ return param_type(value)
+
+
+def parse_values(values, type_map, ignore_unknown=False):
+ """Parses hyperparameter values from a string into a python map.
+
+ `values` is a string containing comma-separated `name=value` pairs.
+ For each pair, the value of the hyperparameter named `name` is set to
+ `value`.
+
+ If a hyperparameter name appears multiple times in `values`, a ValueError
+ is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').
+
+ If a hyperparameter name appears in both an index assignment and a scalar
+ assignment, a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1').
+
+ The hyperparameter name may contain '.' symbols, which will result in an
+ attribute name that is only accessible through the getattr and setattr
+ functions. (And must first be explicitly added through add_hparam.)
+
+ WARNING: Use of '.' in your variable names is allowed, but is not well
+ supported and not recommended.
+
+ The `value` in `name=value` must follow the syntax according to the
+ type of the parameter:
+
+ * Scalar integer: A Python-parsable integer value. E.g.: 1,
+ 100, -12.
+ * Scalar float: A Python-parsable floating point value. E.g.: 1.0,
+ -.54e89.
+ * Boolean: Either true or false.
+ * Scalar string: A non-empty sequence of characters, excluding comma,
+ spaces, and square brackets. E.g.: foo, bar_1.
+ * List: A comma separated list of scalar values of the parameter type
+ enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].
+
+ When index assignment is used, the corresponding type_map key should be the
+ list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not
+ "arr[1]").
+
+ Args:
+ values: String. Comma separated list of `name=value` pairs where
+ 'value' must follow the syntax described above.
+ type_map: A dictionary mapping hyperparameter names to types. Note every
+ parameter name in values must be a key in type_map. The values must
+ conform to the types indicated, where a value V is said to conform to a
+ type T if either V has type T, or V is a list of elements of type T.
+ Hence, for a multidimensional parameter 'x' taking float values,
+ 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
+ ignore_unknown: Bool. Whether values that are missing a type in type_map
+ should be ignored. If set to True, a ValueError will not be raised for
+ unknown hyperparameter type.
+
+ Returns:
+ A python map mapping each name to either:
+ * A scalar value.
+ * A list of scalar values.
+ * A dictionary mapping index numbers to scalar values.
+ (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}})
+
+ Raises:
+ ValueError: If there is a problem with input.
+ * If `values` cannot be parsed.
+ * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
+ * If the same rvalue is assigned two different values (e.g. 'a=1,a=2',
+ 'a[1]=1,a[1]=2', or 'a=1,a=[1]')
+ """
+ results_dictionary = {}
+ pos = 0
+ while pos < len(values):
+ m = PARAM_RE.match(values, pos)
+ if not m:
+ raise ValueError('Malformed hyperparameter value: %s' % values[pos:])
+ # Check that there is a comma between parameters and move past it.
+ pos = m.end()
+ # Parse the values.
+ m_dict = m.groupdict()
+ name = m_dict['name']
+ if name not in type_map:
+ if ignore_unknown:
+ continue
+ raise ValueError('Unknown hyperparameter type for %s' % name)
+ type_ = type_map[name]
+
+ # Set up correct parsing function (depending on whether type_ is a bool)
+ if type_ == bool:
+
+ def parse_bool(value):
+ if value in ['true', 'True']:
+ return True
+ elif value in ['false', 'False']:
+ return False
+ else:
+ try:
+ return bool(int(value))
+ except ValueError:
+ _parse_fail(name, type_, value, values)
+
+ parse = parse_bool
+ else:
+ parse = type_
+
+ # If a single value is provided
+ if m_dict['val'] is not None:
+ _process_scalar_value(name, parse, type_, m_dict, values,
+ results_dictionary)
+
+ # If the assigned value is a list:
+ elif m_dict['vals'] is not None:
+ _process_list_value(name, parse, type_, m_dict, values,
+ results_dictionary)
+
+ else: # Not assigned a list or value
+ _parse_fail(name, type_, '', values)
+
+ return results_dictionary
+
+
+class HParams(object):
+ """Class to hold a set of hyperparameters as name-value pairs.
+
+ A `HParams` object holds hyperparameters used to build and train a model,
+ such as the number of hidden units in a neural net layer or the learning rate
+ to use when training.
+
+ You first create a `HParams` object by specifying the names and values of the
+ hyperparameters.
+
+ To make them easily accessible the parameter names are added as direct
+ attributes of the class. A typical usage is as follows:
+
+ ```python
+ # Create a HParams object specifying names and values of the model
+ # hyperparameters:
+ hparams = HParams(learning_rate=0.1, num_hidden_units=100)
+
+ # The hyperparameters are available as attributes of the HParams object:
+ hparams.learning_rate ==> 0.1
+ hparams.num_hidden_units ==> 100
+ ```
+
+ Hyperparameters have a type, which is inferred from the type of their value
+ passed at construction time. The currently supported types are: integer,
+ float, boolean, string, and list of integer, float, boolean, or string.
+
+ You can override hyperparameter values by calling the
+ [`parse()`](#HParams.parse) method, passing a string of comma separated
+ `name=value` pairs. This is intended to make it possible to override
+ any hyperparameter values from a single command-line flag to which
+ the user passes 'hyper-param=value' pairs. It avoids having to define
+ one flag for each hyperparameter.
+
+ The syntax expected for each value depends on the type of the parameter.
+ See `parse()` for a description of the syntax.
+
+ Example:
+
+ ```python
+ # Define a command line flag to pass name=value pairs.
+ # For example using argparse:
+ import argparse
+ parser = argparse.ArgumentParser(description='Train my model.')
+ parser.add_argument('--hparams', type=str,
+ help='Comma separated list of "name=value" pairs.')
+ args = parser.parse_args()
+ ...
+ def my_program():
+ # Create a HParams object specifying the names and values of the
+ # model hyperparameters:
+ hparams = tf.contrib.training.HParams(
+ learning_rate=0.1,
+ num_hidden_units=100,
+ activations=['relu', 'tanh'])
+
+ # Override hyperparameters values by parsing the command line
+ hparams.parse(args.hparams)
+
+ # If the user passed `--hparams=learning_rate=0.3` on the command line
+ # then 'hparams' has the following attributes:
+ hparams.learning_rate ==> 0.3
+ hparams.num_hidden_units ==> 100
+ hparams.activations ==> ['relu', 'tanh']
+
+ # If the hyperparameters are in json format use parse_json:
+ hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
+ ```
+ """
+
+ _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks.
+
+ def __init__(self, hparam_def=None, model_structure=None, **kwargs):
+ """Create an instance of `HParams` from keyword arguments.
+
+ The keyword arguments specify name-values pairs for the hyperparameters.
+ The parameter types are inferred from the type of the values passed.
+
+ The parameter names are added as attributes of `HParams` object, so they
+ can be accessed directly with the dot notation `hparams._name_`.
+
+ Example:
+
+ ```python
+ # Define 3 hyperparameters: 'learning_rate' is a float parameter,
+ # 'num_hidden_units' an integer parameter, and 'activation' a string
+ # parameter.
+ hparams = tf.contrib.training.HParams(
+ learning_rate=0.1, num_hidden_units=100, activation='relu')
+
+ hparams.activation ==> 'relu'
+ ```
+
+ Note that a few names are reserved and cannot be used as hyperparameter
+ names. If you use one of the reserved names, the constructor raises a
+ `ValueError`.
+
+ Args:
+ hparam_def: Serialized hyperparameters, encoded as a hparam_pb2.HParamDef
+ protocol buffer. If provided, this object is initialized by
+ deserializing hparam_def. Otherwise **kwargs is used.
+ model_structure: An instance of ModelStructure, defining the feature
+ crosses to be used in the Trial.
+ **kwargs: Key-value pairs where the key is the hyperparameter name and
+ the value is the value for the parameter.
+
+ Raises:
+ ValueError: If both `hparam_def` and initialization values are provided,
+ or if one of the arguments is invalid.
+
+ """
+ # Register the hyperparameters and their type in _hparam_types.
+ # This simplifies the implementation of parse().
+ # _hparam_types maps the parameter name to a tuple (type, bool).
+ # The type value is the type of the parameter for scalar hyperparameters,
+ # or the type of the list elements for multidimensional hyperparameters.
+ # The bool value is True if the value is a list, False otherwise.
+ self._hparam_types = {}
+ self._model_structure = model_structure
+ if hparam_def:
+ self._init_from_proto(hparam_def)
+ if kwargs:
+ raise ValueError('hparam_def and initialization values are '
+ 'mutually exclusive')
+ else:
+ for name, value in six.iteritems(kwargs):
+ self.add_hparam(name, value)
+
+ def _init_from_proto(self, hparam_def):
+ """Creates a new HParams from `HParamDef` protocol buffer.
+
+ Args:
+ hparam_def: `HParamDef` protocol buffer.
+ """
+ assert isinstance(hparam_def, hparam_pb2.HParamDef)
+ for name, value in hparam_def.hparam.items():
+ kind = value.WhichOneof('kind')
+ if kind.endswith('_value'):
+ # Single value.
+ if kind.startswith('int64'):
+ # Setting attribute value to be 'int' to ensure the type is compatible
+ # with both Python2 and Python3.
+ self.add_hparam(name, int(getattr(value, kind)))
+ elif kind.startswith('bytes'):
+ # Setting attribute value to be 'str' to ensure the type is compatible
+ # with both Python2 and Python3. UTF-8 encoding is assumed.
+ self.add_hparam(name, compat.as_str(getattr(value, kind)))
+ else:
+ self.add_hparam(name, getattr(value, kind))
+ else:
+ # List of values.
+ if kind.startswith('int64'):
+ # Setting attribute value to be 'int' to ensure the type is compatible
+ # with both Python2 and Python3.
+ self.add_hparam(name, [int(v) for v in getattr(value, kind).value])
+ elif kind.startswith('bytes'):
+ # Setting attribute value to be 'str' to ensure the type is compatible
+ # with both Python2 and Python3. UTF-8 encoding is assumed.
+ self.add_hparam(
+ name, [compat.as_str(v) for v in getattr(value, kind).value])
+ else:
+ self.add_hparam(name, [v for v in getattr(value, kind).value])
+
+ def add_hparam(self, name, value):
+ """Adds {name, value} pair to hyperparameters.
+
+ Args:
+ name: Name of the hyperparameter.
+ value: Value of the hyperparameter. Can be one of the following types:
+ int, float, string, int list, float list, or string list.
+
+ Raises:
+ ValueError: if one of the arguments is invalid.
+ """
+ # Keys in kwargs are unique, but 'name' could be the name of a pre-existing
+ # attribute of this object. In that case we refuse to use it as a
+ # hyperparameter name.
+ if getattr(self, name, None) is not None:
+ raise ValueError('Hyperparameter name is reserved: %s' % name)
+ if isinstance(value, (list, tuple)):
+ if not value:
+ raise ValueError(
+ 'Multi-valued hyperparameters cannot be empty: %s' % name)
+ self._hparam_types[name] = (type(value[0]), True)
+ else:
+ self._hparam_types[name] = (type(value), False)
+ setattr(self, name, value)
+
+ def set_hparam(self, name, value):
+ """Set the value of an existing hyperparameter.
+
+ This function verifies that the type of the value matches the type of the
+ existing hyperparameter.
+
+ Args:
+ name: Name of the hyperparameter.
+ value: New value of the hyperparameter.
+
+ Raises:
+ KeyError: If the hyperparameter doesn't exist.
+ ValueError: If there is a type mismatch.
+ """
+ param_type, is_list = self._hparam_types[name]
+ if isinstance(value, list):
+ if not is_list:
+ raise ValueError(
+ 'Must not pass a list for single-valued parameter: %s' % name)
+ setattr(self, name, [
+ _cast_to_type_if_compatible(name, param_type, v) for v in value])
+ else:
+ if is_list:
+ raise ValueError(
+ 'Must pass a list for multi-valued parameter: %s.' % name)
+ setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
+
+ def del_hparam(self, name):
+ """Removes the hyperparameter with key 'name'.
+
+ Does nothing if it isn't present.
+
+ Args:
+ name: Name of the hyperparameter.
+ """
+ if hasattr(self, name):
+ delattr(self, name)
+ del self._hparam_types[name]
+
+ def parse(self, values):
+ """Override existing hyperparameter values, parsing new values from a string.
+
+ See parse_values for more detail on the allowed format for values.
+
+ Args:
+ values: String. Comma separated list of `name=value` pairs where 'value'
+ must follow the syntax described above.
+
+ Returns:
+ The `HParams` instance.
+
+ Raises:
+ ValueError: If `values` cannot be parsed or a hyperparameter in `values`
+ doesn't exist.
+ """
+ type_map = {}
+ for name, t in self._hparam_types.items():
+ param_type, _ = t
+ type_map[name] = param_type
+
+ values_map = parse_values(values, type_map)
+ return self.override_from_dict(values_map)
+
+ def override_from_dict(self, values_dict):
+ """Override existing hyperparameter values, parsing new values from a dictionary.
+
+ Args:
+ values_dict: Dictionary of name:value pairs.
+
+ Returns:
+ The `HParams` instance.
+
+ Raises:
+ KeyError: If a hyperparameter in `values_dict` doesn't exist.
+ ValueError: If `values_dict` cannot be parsed.
+ """
+ for name, value in values_dict.items():
+ self.set_hparam(name, value)
+ return self
+
+ @deprecation.deprecated(None, 'Use `override_from_dict`.')
+ def set_from_map(self, values_map):
+ """DEPRECATED. Use override_from_dict."""
+ return self.override_from_dict(values_dict=values_map)
+
+ def set_model_structure(self, model_structure):
+ self._model_structure = model_structure
+
+ def get_model_structure(self):
+ return self._model_structure
+
+ def to_json(self, indent=None, separators=None, sort_keys=False):
+ """Serializes the hyperparameters into JSON.
+
+ Args:
+ indent: If a non-negative integer, JSON array elements and object members
+ will be pretty-printed with that indent level. An indent level of 0, or
+ negative, will only insert newlines. `None` (the default) selects the
+ most compact representation.
+ separators: Optional `(item_separator, key_separator)` tuple. Default is
+ `(', ', ': ')`.
+ sort_keys: If `True`, the output dictionaries will be sorted by key.
+
+ Returns:
+ A JSON string.
+ """
+ return json.dumps(
+ self.values(),
+ indent=indent,
+ separators=separators,
+ sort_keys=sort_keys)
+
+ def parse_json(self, values_json):
+ """Override existing hyperparameter values, parsing new values from a json object.
+
+ Args:
+ values_json: String containing a json object of name:value pairs.
+
+ Returns:
+ The `HParams` instance.
+
+ Raises:
+ KeyError: If a hyperparameter in `values_json` doesn't exist.
+ ValueError: If `values_json` cannot be parsed.
+ """
+ values_map = json.loads(values_json)
+ return self.override_from_dict(values_map)
+
+ def values(self):
+ """Return the hyperparameter values as a Python dictionary.
+
+ Returns:
+ A dictionary with hyperparameter names as keys. The values are the
+ hyperparameter values.
+ """
+ return {n: getattr(self, n) for n in self._hparam_types.keys()}
+
+ def get(self, key, default=None):
+ """Returns the value of `key` if it exists, else `default`."""
+ if key in self._hparam_types:
+ # Ensure that default is compatible with the parameter type.
+ if default is not None:
+ param_type, is_param_list = self._hparam_types[key]
+ type_str = 'list<%s>' % param_type if is_param_list else str(param_type)
+ fail_msg = ("Hparam '%s' of type '%s' is incompatible with "
+ 'default=%s' % (key, type_str, default))
+
+ is_default_list = isinstance(default, list)
+ if is_param_list != is_default_list:
+ raise ValueError(fail_msg)
+
+ try:
+ if is_default_list:
+ for value in default:
+ _cast_to_type_if_compatible(key, param_type, value)
+ else:
+ _cast_to_type_if_compatible(key, param_type, default)
+ except ValueError as e:
+ raise ValueError('%s. %s' % (fail_msg, e))
+
+ return getattr(self, key)
+
+ return default
+
+ def __contains__(self, key):
+ return key in self._hparam_types
+
+ def __str__(self):
+ hpdict = self.values()
+ output_list = ['{}={}'.format(key, hpdict[key]) for key in hpdict]
+ return ','.join(output_list)
+
+ def __repr__(self):
+ strval = str(sorted(self.values().items()))
+ return '%s(%s)' % (type(self).__name__, strval)
+
+ @staticmethod
+ def _get_kind_name(param_type, is_list):
+ """Returns the field name given parameter type and is_list.
+
+ Args:
+ param_type: Data type of the hparam.
+ is_list: Whether this is a list.
+
+ Returns:
+ A string representation of the field name.
+
+ Raises:
+ ValueError: If parameter type is not recognized.
+ """
+ if issubclass(param_type, bool):
+ # This check must happen before issubclass(param_type, six.integer_types),
+ # since Python considers bool to be a subclass of int.
+ typename = 'bool'
+ elif issubclass(param_type, six.integer_types):
+ # Setting 'int' and 'long' types to be 'int64' to ensure the type is
+ # compatible with both Python2 and Python3.
+ typename = 'int64'
+ elif issubclass(param_type, (six.string_types, six.binary_type)):
+ # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
+ # compatible with both Python2 and Python3.
+ typename = 'bytes'
+ elif issubclass(param_type, float):
+ typename = 'float'
+ else:
+ raise ValueError('Unsupported parameter type: %s' % str(param_type))
+
+ suffix = 'list' if is_list else 'value'
+ return '_'.join([typename, suffix])
+
+ def to_proto(self, export_scope=None): # pylint: disable=unused-argument
+ """Converts a `HParams` object to a `HParamDef` protocol buffer.
+
+ Args:
+ export_scope: Optional `string`. Name scope to remove.
+
+ Returns:
+ A `HParamDef` protocol buffer.
+ """
+ hparam_proto = hparam_pb2.HParamDef()
+ for name in self._hparam_types:
+ # Parse the values.
+ param_type, is_list = self._hparam_types.get(name, (None, None))
+ kind = HParams._get_kind_name(param_type, is_list)
+
+ if is_list:
+ if kind.startswith('bytes'):
+ v_list = [compat.as_bytes(v) for v in getattr(self, name)]
+ else:
+ v_list = [v for v in getattr(self, name)]
+ getattr(hparam_proto.hparam[name], kind).value.extend(v_list)
+ else:
+ v = getattr(self, name)
+ if kind.startswith('bytes'):
+ v = compat.as_bytes(getattr(self, name))
+ setattr(hparam_proto.hparam[name], kind, v)
+
+ return hparam_proto
+
+ @staticmethod
+ def from_proto(hparam_def, import_scope=None): # pylint: disable=unused-argument
+ return HParams(hparam_def=hparam_def)
+
+
+ops.register_proto_function(
+ 'hparams',
+ proto_type=hparam_pb2.HParamDef,
+ to_proto=HParams.to_proto,
+ from_proto=HParams.from_proto)
+
diff --git a/src/hparam_pb2.py b/src/hparam_pb2.py
new file mode 100644
index 0000000..ba10c30
--- /dev/null
+++ b/src/hparam_pb2.py
@@ -0,0 +1,399 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: tensorflow/contrib/training/python/training/hparam.proto
+
+import sys
+_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor.FileDescriptor(
+ name='tensorflow/contrib/training/python/training/hparam.proto',
+ package='tensorflow',
+ syntax='proto3',
+ serialized_options=_b('\370\001\001'),
+ serialized_pb=_b('\n8tensorflow/contrib/training/python/training/hparam.proto\x12\ntensorflow\"\xd6\x04\n\tHParamDef\x12\x31\n\x06hparam\x18\x01 \x03(\x0b\x32!.tensorflow.HParamDef.HparamEntry\x1a\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c\x1a\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01\x1a\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01\x1a\x1d\n\x08\x42oolList\x12\x11\n\x05value\x18\x01 \x03(\x08\x42\x02\x10\x01\x1a\xc9\x02\n\nHParamType\x12\x15\n\x0bint64_value\x18\x01 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x15\n\x0b\x62ytes_value\x18\x03 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x07 \x01(\x08H\x00\x12\x35\n\nint64_list\x18\x04 \x01(\x0b\x32\x1f.tensorflow.HParamDef.Int64ListH\x00\x12\x35\n\nfloat_list\x18\x05 \x01(\x0b\x32\x1f.tensorflow.HParamDef.FloatListH\x00\x12\x35\n\nbytes_list\x18\x06 \x01(\x0b\x32\x1f.tensorflow.HParamDef.BytesListH\x00\x12\x33\n\tbool_list\x18\x08 \x01(\x0b\x32\x1e.tensorflow.HParamDef.BoolListH\x00\x42\x06\n\x04kind\x1aO\n\x0bHparamEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .tensorflow.HParamDef.HParamType:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3')
+)
+
+
+
+
+_HPARAMDEF_BYTESLIST = _descriptor.Descriptor(
+ name='BytesList',
+ full_name='tensorflow.HParamDef.BytesList',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='value', full_name='tensorflow.HParamDef.BytesList.value', index=0,
+ number=1, type=12, cpp_type=9, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=137,
+ serialized_end=163,
+)
+
+_HPARAMDEF_FLOATLIST = _descriptor.Descriptor(
+ name='FloatList',
+ full_name='tensorflow.HParamDef.FloatList',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='value', full_name='tensorflow.HParamDef.FloatList.value', index=0,
+ number=1, type=2, cpp_type=6, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=_b('\020\001'), file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=165,
+ serialized_end=195,
+)
+
+_HPARAMDEF_INT64LIST = _descriptor.Descriptor(
+ name='Int64List',
+ full_name='tensorflow.HParamDef.Int64List',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='value', full_name='tensorflow.HParamDef.Int64List.value', index=0,
+ number=1, type=3, cpp_type=2, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=_b('\020\001'), file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=197,
+ serialized_end=227,
+)
+
+_HPARAMDEF_BOOLLIST = _descriptor.Descriptor(
+ name='BoolList',
+ full_name='tensorflow.HParamDef.BoolList',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='value', full_name='tensorflow.HParamDef.BoolList.value', index=0,
+ number=1, type=8, cpp_type=7, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=_b('\020\001'), file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=229,
+ serialized_end=258,
+)
+
+_HPARAMDEF_HPARAMTYPE = _descriptor.Descriptor(
+ name='HParamType',
+ full_name='tensorflow.HParamDef.HParamType',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='int64_value', full_name='tensorflow.HParamDef.HParamType.int64_value', index=0,
+ number=1, type=3, cpp_type=2, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='float_value', full_name='tensorflow.HParamDef.HParamType.float_value', index=1,
+ number=2, type=2, cpp_type=6, label=1,
+ has_default_value=False, default_value=float(0),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='bytes_value', full_name='tensorflow.HParamDef.HParamType.bytes_value', index=2,
+ number=3, type=12, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b(""),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='bool_value', full_name='tensorflow.HParamDef.HParamType.bool_value', index=3,
+ number=7, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='int64_list', full_name='tensorflow.HParamDef.HParamType.int64_list', index=4,
+ number=4, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='float_list', full_name='tensorflow.HParamDef.HParamType.float_list', index=5,
+ number=5, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='bytes_list', full_name='tensorflow.HParamDef.HParamType.bytes_list', index=6,
+ number=6, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='bool_list', full_name='tensorflow.HParamDef.HParamType.bool_list', index=7,
+ number=8, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ _descriptor.OneofDescriptor(
+ name='kind', full_name='tensorflow.HParamDef.HParamType.kind',
+ index=0, containing_type=None, fields=[]),
+ ],
+ serialized_start=261,
+ serialized_end=590,
+)
+
+_HPARAMDEF_HPARAMENTRY = _descriptor.Descriptor(
+ name='HparamEntry',
+ full_name='tensorflow.HParamDef.HparamEntry',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='key', full_name='tensorflow.HParamDef.HparamEntry.key', index=0,
+ number=1, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b("").decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='value', full_name='tensorflow.HParamDef.HparamEntry.value', index=1,
+ number=2, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=_b('8\001'),
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=592,
+ serialized_end=671,
+)
+
+_HPARAMDEF = _descriptor.Descriptor(
+ name='HParamDef',
+ full_name='tensorflow.HParamDef',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='hparam', full_name='tensorflow.HParamDef.hparam', index=0,
+ number=1, type=11, cpp_type=10, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[_HPARAMDEF_BYTESLIST, _HPARAMDEF_FLOATLIST, _HPARAMDEF_INT64LIST, _HPARAMDEF_BOOLLIST, _HPARAMDEF_HPARAMTYPE, _HPARAMDEF_HPARAMENTRY, ],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=73,
+ serialized_end=671,
+)
+
+_HPARAMDEF_BYTESLIST.containing_type = _HPARAMDEF
+_HPARAMDEF_FLOATLIST.containing_type = _HPARAMDEF
+_HPARAMDEF_INT64LIST.containing_type = _HPARAMDEF
+_HPARAMDEF_BOOLLIST.containing_type = _HPARAMDEF
+_HPARAMDEF_HPARAMTYPE.fields_by_name['int64_list'].message_type = _HPARAMDEF_INT64LIST
+_HPARAMDEF_HPARAMTYPE.fields_by_name['float_list'].message_type = _HPARAMDEF_FLOATLIST
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_list'].message_type = _HPARAMDEF_BYTESLIST
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bool_list'].message_type = _HPARAMDEF_BOOLLIST
+_HPARAMDEF_HPARAMTYPE.containing_type = _HPARAMDEF
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['int64_value'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['int64_value'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['float_value'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['float_value'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_value'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_value'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['bool_value'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bool_value'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['int64_list'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['int64_list'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['float_list'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['float_list'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_list'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bytes_list'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind'].fields.append(
+ _HPARAMDEF_HPARAMTYPE.fields_by_name['bool_list'])
+_HPARAMDEF_HPARAMTYPE.fields_by_name['bool_list'].containing_oneof = _HPARAMDEF_HPARAMTYPE.oneofs_by_name['kind']
+_HPARAMDEF_HPARAMENTRY.fields_by_name['value'].message_type = _HPARAMDEF_HPARAMTYPE
+_HPARAMDEF_HPARAMENTRY.containing_type = _HPARAMDEF
+_HPARAMDEF.fields_by_name['hparam'].message_type = _HPARAMDEF_HPARAMENTRY
+DESCRIPTOR.message_types_by_name['HParamDef'] = _HPARAMDEF
+_sym_db.RegisterFileDescriptor(DESCRIPTOR)
+
+HParamDef = _reflection.GeneratedProtocolMessageType('HParamDef', (_message.Message,), dict(
+
+ BytesList = _reflection.GeneratedProtocolMessageType('BytesList', (_message.Message,), dict(
+ DESCRIPTOR = _HPARAMDEF_BYTESLIST,
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2'
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.BytesList)
+ ))
+ ,
+
+ FloatList = _reflection.GeneratedProtocolMessageType('FloatList', (_message.Message,), dict(
+ DESCRIPTOR = _HPARAMDEF_FLOATLIST,
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2'
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.FloatList)
+ ))
+ ,
+
+ Int64List = _reflection.GeneratedProtocolMessageType('Int64List', (_message.Message,), dict(
+ DESCRIPTOR = _HPARAMDEF_INT64LIST,
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2'
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.Int64List)
+ ))
+ ,
+
+ BoolList = _reflection.GeneratedProtocolMessageType('BoolList', (_message.Message,), dict(
+ DESCRIPTOR = _HPARAMDEF_BOOLLIST,
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2'
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.BoolList)
+ ))
+ ,
+
+ HParamType = _reflection.GeneratedProtocolMessageType('HParamType', (_message.Message,), dict(
+ DESCRIPTOR = _HPARAMDEF_HPARAMTYPE,
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2'
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.HParamType)
+ ))
+ ,
+
+ HparamEntry = _reflection.GeneratedProtocolMessageType('HparamEntry', (_message.Message,), dict(
+ DESCRIPTOR = _HPARAMDEF_HPARAMENTRY,
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2'
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef.HparamEntry)
+ ))
+ ,
+ DESCRIPTOR = _HPARAMDEF,
+ __module__ = 'tensorflow.contrib.training.python.training.hparam_pb2'
+ # @@protoc_insertion_point(class_scope:tensorflow.HParamDef)
+ ))
+_sym_db.RegisterMessage(HParamDef)
+_sym_db.RegisterMessage(HParamDef.BytesList)
+_sym_db.RegisterMessage(HParamDef.FloatList)
+_sym_db.RegisterMessage(HParamDef.Int64List)
+_sym_db.RegisterMessage(HParamDef.BoolList)
+_sym_db.RegisterMessage(HParamDef.HParamType)
+_sym_db.RegisterMessage(HParamDef.HparamEntry)
+
+
+DESCRIPTOR._options = None
+_HPARAMDEF_FLOATLIST.fields_by_name['value']._options = None
+_HPARAMDEF_INT64LIST.fields_by_name['value']._options = None
+_HPARAMDEF_BOOLLIST.fields_by_name['value']._options = None
+_HPARAMDEF_HPARAMENTRY._options = None
+# @@protoc_insertion_point(module_scope)
+
diff --git a/src/interactive_conditional_samples.py b/src/interactive_conditional_samples.py
index 26390e6..0bde72e 100755
--- a/src/interactive_conditional_samples.py
+++ b/src/interactive_conditional_samples.py
@@ -14,6 +14,23 @@ import tflex
import model, sample, encoder
+from load_dataset import load_dataset, Sampler
+
+def generate_samples(context, enc, sampler, data_sampler, session, batch_size, count=1):
+ print('Generating samples...')
+ context_tokens = data_sampler.sample(1)
+ index = 0
+ while index < count:
+ out = session.run(
+ sampler,
+ feed_dict={context: batch_size * [context_tokens]})
+ for i in range(len(out)):
+ text = enc.decode(out[i])
+ text = '======== SAMPLE {} ========\n{}\n'.format(
+ index + 1, text)
+ print(text)
+ index += 1
+
def interact_model(
model_name='117M',
seed=None,
@@ -22,7 +39,9 @@ def interact_model(
length=None,
temperature=1,
top_k=0,
- top_p=0.0
+ top_p=0.0,
+ dataset=None,
+ combine=50000
):
"""
Interactively run the model
@@ -58,10 +77,10 @@ def interact_model(
elif length > hparams.n_ctx:
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
- with tf.Session(graph=tf.Graph()) as sess:
- context = tf.placeholder(tf.int32, [batch_size, None])
+ with tf.compat.v1.Session(graph=tf.Graph()) as sess:
+ context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
np.random.seed(seed)
- tf.set_random_seed(seed)
+ tf.compat.v1.set_random_seed(seed)
output = sample.sample_sequence(
hparams=hparams, length=length,
context=context,
@@ -73,11 +92,21 @@ def interact_model(
ckpt = tflex.latest_checkpoint(os.path.join('models', model_name))
saver.restore(sess, ckpt)
+ print('Loading dataset...')
+ chunks = load_dataset(enc, dataset, combine=combine)
+ data_sampler = Sampler(chunks, seed=seed)
+ print('dataset has', data_sampler.total_size, 'tokens', len(chunks), 'chunks')
+
+ import pdb
+ pdb.set_trace()
+
while True:
raw_text = input("Model prompt >>> ")
while not raw_text:
print('Prompt should not be empty!')
raw_text = input("Model prompt >>> ")
+ raw_text = raw_text.rstrip()
+ print(repr(raw_text))
context_tokens = enc.encode(raw_text)
generated = 0
for _ in range(nsamples // batch_size):
diff --git a/src/model.py b/src/model.py
index d0cde2a..9cf95ac 100644
--- a/src/model.py
+++ b/src/model.py
@@ -1,9 +1,13 @@
import numpy as np
import tensorflow as tf
-from tensorflow.contrib.training import HParams
+import tensorflow_addons as tfa
+#from tensorflow.contrib.training import HParams
+import hparam
+import collections
+
def default_hparams():
- return HParams(
+ return hparam.HParams(
n_vocab=50257,
n_ctx=1024,
n_embd=768,
@@ -17,11 +21,19 @@ def default_hparams():
import os
def get_variable(name):
- name = os.path.join(tf.get_variable_scope().name, name)
- vs = tf.trainable_variables()
+ name = os.path.join(tf.compat.v1.get_variable_scope().name, name)
+ vs = tf.compat.v1.trainable_variables()
+ for x in vs:
+ if x.name.startswith(name + ':'):
+ return x
+
+def init_variable(name, *args, **kws):
+ name = os.path.join(tf.compat.v1.get_variable_scope().name, name)
+ vs = tf.compat.v1.trainable_variables()
for x in vs:
if x.name.startswith(name + ':'):
return x
+ return tf.compat.v1.get_variable(name, *args, **kws)
def shape_list(x):
"""Deal with dynamic shape in tensorflow cleanly."""
@@ -29,24 +41,26 @@ def shape_list(x):
dynamic = tf.shape(x)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
-def softmax(x, axis=-1):
- x = x - tf.reduce_max(x, axis=axis, keepdims=True)
- ex = tf.exp(x)
- return ex / tf.reduce_sum(ex, axis=axis, keepdims=True)
+def softmax(x, axis=-1, name=None):
+ #x = x - tf.reduce_max(x, axis=axis, keepdims=True)
+ #ex = tf.exp(x)
+ #return ex / tf.reduce_sum(ex, axis=axis, keepdims=True)
+ return tf.nn.softmax(x, axis=axis, name=name)
def gelu(x):
- return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3))))
+ #return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3))))
+ return tfa.activations.gelu(x)
def norm(x, scope, *, axis=-1, epsilon=1e-5, hparams=None):
"""Normalize to mean = 0, std = 1, then do a diagonal affine transform."""
dtype = hparams.dtype if hparams else tf.float32
- with tf.variable_scope(scope, dtype=dtype):
- n_state = x.shape[-1].value
- g = get_variable('g') or tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1, dtype=dtype))
- b = get_variable('b') or tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0, dtype=dtype))
+ with tf.compat.v1.variable_scope(scope, dtype=dtype):
+ n_state = x.shape[-1]
+ g = init_variable('g', [n_state], initializer=tf.compat.v1.constant_initializer(1, dtype=dtype))
+ b = init_variable('b', [n_state], initializer=tf.compat.v1.constant_initializer(0, dtype=dtype))
u = tf.reduce_mean(x, axis=axis, keepdims=True)
s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True)
- x = (x - u) * tf.rsqrt(s + epsilon)
+ x = (x - u) * tf.compat.v1.rsqrt(s + epsilon)
x = x*g + b
return x
@@ -62,10 +76,10 @@ def merge_states(x):
def conv1d(x, scope, nf, *, w_init_stdev=0.02, hparams=None):
dtype = hparams.dtype if hparams else tf.float32
- with tf.variable_scope(scope, dtype=dtype):
+ with tf.compat.v1.variable_scope(scope, dtype=dtype):
*start, nx = shape_list(x)
- w = get_variable('w') or tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev, dtype=dtype))
- b = get_variable('b') or tf.get_variable('b', [nf], initializer=tf.constant_initializer(0, dtype=dtype))
+ w = init_variable('w', [1, nx, nf], initializer=tf.compat.v1.random_normal_initializer(stddev=w_init_stdev, dtype=dtype))
+ b = init_variable('b', [nf], initializer=tf.compat.v1.constant_initializer(0, dtype=dtype))
c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf])
return c
@@ -105,7 +119,7 @@ def attn(x, scope, n_state, *, past, hparams):
def multihead_attn(q, k, v):
# q, k, v have shape [batch, heads, sequence, features]
w = tf.matmul(q, k, transpose_b=True)
- w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))
+ w = w * tf.compat.v1.rsqrt(tf.cast(v.shape[-1], w.dtype))
w = mask_attn_weights(w)
w = softmax(w)
@@ -114,7 +128,7 @@ def attn(x, scope, n_state, *, past, hparams):
return a
dtype = hparams.dtype if hparams else tf.float32
- with tf.variable_scope(scope, dtype=dtype):
+ with tf.compat.v1.variable_scope(scope, dtype=dtype):
c = conv1d(x, 'c_attn', n_state*3, hparams=hparams)
q, k, v = map(split_heads, tf.split(c, 3, axis=2))
present = tf.stack([k, v], axis=1)
@@ -131,9 +145,10 @@ def attn(x, scope, n_state, *, past, hparams):
def mlp(x, scope, n_state, *, hparams):
dtype = hparams.dtype if hparams else tf.float32
- with tf.variable_scope(scope, dtype=dtype):
- nx = x.shape[-1].value
- h = gelu(conv1d(x, 'c_fc', n_state, hparams=hparams))
+ with tf.compat.v1.variable_scope(scope, dtype=dtype):
+ nx = x.shape[-1]
+ h0 = conv1d(x, 'c_fc', n_state, hparams=hparams)
+ h = gelu(h0)
h2 = conv1d(h, 'c_proj', nx, hparams=hparams)
h2 = dropout(h2, hparams.res_dropout)
return h2
@@ -145,8 +160,8 @@ def dropout(x, pdrop=0.1, train=True):
def block(x, scope, *, past, hparams):
dtype = hparams.dtype if hparams else tf.float32
- with tf.variable_scope(scope, dtype=dtype):
- nx = x.shape[-1].value
+ with tf.compat.v1.variable_scope(scope, dtype=dtype):
+ nx = x.shape[-1]
a, present = attn(norm(x, 'ln_1', hparams=hparams), 'attn', nx, past=past, hparams=hparams)
x = x + a
m = mlp(norm(x, 'ln_2', hparams=hparams), 'mlp', nx*4, hparams=hparams)
@@ -168,16 +183,16 @@ def positions_for(tokens, past_length):
return expand_tile(past_length + tf.range(nsteps), batch_size)
-def model(hparams, X, past=None, scope='model', reuse=tf.AUTO_REUSE):
+def model(hparams, X, past=None, scope='model', reuse=tf.compat.v1.AUTO_REUSE):
dtype = hparams.dtype if hparams else tf.float32
- with tf.variable_scope(scope, reuse=reuse, dtype=dtype):
+ with tf.compat.v1.variable_scope(scope, reuse=reuse, dtype=dtype):
results = {}
batch, sequence = shape_list(X)
- wpe = get_variable('wpe') or tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
- initializer=tf.random_normal_initializer(stddev=0.01, dtype=dtype))
- wte = get_variable('wte') or tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
- initializer=tf.random_normal_initializer(stddev=0.02, dtype=dtype))
+ wpe = init_variable('wpe', [hparams.n_ctx, hparams.n_embd],
+ initializer=tf.compat.v1.random_normal_initializer(stddev=0.01, dtype=dtype))
+ wte = init_variable('wte', [hparams.n_vocab, hparams.n_embd],
+ initializer=tf.compat.v1.random_normal_initializer(stddev=0.02, dtype=dtype))
past_length = 0 if past is None else tf.shape(past)[-2]
h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))
@@ -188,7 +203,7 @@ def model(hparams, X, past=None, scope='model', reuse=tf.AUTO_REUSE):
for layer, past in enumerate(pasts):
h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
if layer == 10:
- tf.add_to_collection('checkpoints', h)
+ tf.compat.v1.add_to_collection('checkpoints', h)
presents.append(present)
results['present'] = tf.stack(presents, axis=1)
h = norm(h, 'ln_f', hparams=hparams)
diff --git a/src/sample.py b/src/sample.py
index ff517be..fb4ad80 100644
--- a/src/sample.py
+++ b/src/sample.py
@@ -44,7 +44,7 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte
context = tf.fill([batch_size, 1], start_token)
def step(hparams, tokens, past=None):
- lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)
+ lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.compat.v1.AUTO_REUSE)
if hparams.dtype != tf.float32:
lm_output["logits"] = tf.cast(lm_output["logits"], tf.float32)
@@ -64,12 +64,12 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte
def body(past, prev, output):
next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
- logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
+ logits = next_outputs['logits'][:, -1, :] / tf.compat.v1.to_float(temperature)
if top_p > 0.0:
logits = top_p_logits(logits, p=top_p, epsilon=epsilon)
else:
logits = top_k_logits(logits, k=top_k, epsilon=epsilon)
- samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
+ samples = tf.compat.v1.multinomial(logits, num_samples=1, output_dtype=tf.int32)
return [
tf.concat([past, next_outputs['presents']], axis=-2),
tf.squeeze(samples, axis=[1]),
diff --git a/tflex.py b/tflex.py
index 7c76c44..a9fba8e 100644
--- a/tflex.py
+++ b/tflex.py
@@ -8,6 +8,7 @@ import tqdm
import h5py
import shutil
import tempfile
+import traceback
def split_by_params(vs, n=200e6, f=None):
if f is None:
@@ -65,7 +66,7 @@ def assign_values(variables, values, session=None):
def load_snapshot(ckpt, session=None, var_list=None, reshape=False):
session = session or tf.get_default_session()
reader = pywrap_tensorflow.NewCheckpointReader(ckpt)
- vs = var_list or tf.trainable_variables()
+ vs = var_list or tf.compat.v1.trainable_variables()
for variables in tqdm.tqdm(list(split_by_params(vs))):
values = [value for variable, value in grab_values(variables, reader, reshape=reshape)]
assign_values(variables, values, session=session)
@@ -74,14 +75,14 @@ def get_variable(name, var_list=None):
name, num = name.split(':') if ':' in name else (name, '0')
num = int(num)
name = os.path.join(tf.get_variable_scope().name, name)
- vs = var_list or tf.trainable_variables()
+ vs = var_list or tf.compat.v1.trainable_variables()
for x in vs:
if x.name.startswith(name + ':%d' % num):
return x
def load_weights(ckpt, session=None, var_list=None, reshape=False):
session = session or tf.get_default_session()
- vs = var_list or tf.trainable_variables()
+ vs = var_list or tf.compat.v1.trainable_variables()
files = list(sorted(glob(ckpt + '-*.npy')))
for out in tqdm.tqdm(files):
for name, value in np.load(out, allow_pickle=True):
@@ -93,11 +94,13 @@ def load_weights(ckpt, session=None, var_list=None, reshape=False):
variable.load(value, session)
def load_variables(ckpt, session=None, var_list=None, reshape=False):
+ import pdb
+ pdb.set_trace()
session = session or tf.get_default_session()
- vs = var_list or tf.trainable_variables()
+ vs = var_list or tf.compat.v1.trainable_variables()
with h5py.File(ckpt) as f:
for variables in tqdm.tqdm(list(split_by_params(vs))):
- values = [truncate_value(x, f[x.name], reshape=reshape) for x in variables]
+ values = [truncate_value(x, f[x.name], reshape=reshape) for x in variables if x.name in f]
assign_values(variables, values, session=session)
def maketree(path):
@@ -107,21 +110,31 @@ def maketree(path):
pass
def save_variables(ckpt, session=None, var_list=None):
- session = session or tf.get_default_session()
- vs = var_list or tf.trainable_variables()
- maketree(os.path.dirname(ckpt))
- fname = ckpt+'.tmp'
- with h5py.File(fname, "w") as f:
- for variables in tqdm.tqdm(list(split_by_params(vs))):
- values = session.run(variables)
- for value, variable in zip(values, variables):
- name = variable.name
- shape = variable.shape.as_list()
- dtype = variable.dtype
- dset = f.create_dataset(name, shape, dtype=np.float32)
- dset[:] = value
- print('Writing snapshot %s' % ckpt)
- os.rename(ckpt+'.tmp', ckpt)
+ while True:
+ try:
+ session = session or tf.get_default_session()
+ vs = var_list or tf.compat.v1.trainable_variables()
+ maketree(os.path.dirname(ckpt))
+ fname = ckpt+'.tmp'
+ with h5py.File(fname, "w") as f:
+ for variables in tqdm.tqdm(list(split_by_params(vs))):
+ values = session.run(variables)
+ for value, variable in zip(values, variables):
+ name = variable.name
+ shape = variable.shape.as_list()
+ dtype = variable.dtype
+ dset = f.create_dataset(name, shape, dtype=np.float32)
+ dset[:] = value
+ print('Writing snapshot %s' % ckpt)
+ os.rename(ckpt+'.tmp', ckpt)
+ break
+ except:
+ traceback.print_exc(file=sys.stderr)
+ print('Exception while saving checkpoint. To try again, press "c" then enter.')
+ print('(Is your disk full?)')
+ print('Attaching debugger...')
+ import pdb
+ pdb.set_trace()
class Saver(object):
def __init__(
@@ -137,7 +150,7 @@ class Saver(object):
builder=None,
defer_build=False,
allow_empty=False,
- write_version=tf.train.SaverDef.V2,
+ write_version=tf.compat.v1.train.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None):
@@ -159,6 +172,7 @@ class Saver(object):
self.checkpoints = []
def restore(self, sess, save_path):
+ print('Restoring from {}...'.format(save_path))
if save_path.endswith('.ckpt'):
load_snapshot(save_path, session=sess, var_list=self.var_list, reshape=self.reshape)
elif save_path.endswith('.hdf5'):
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment