Skip to content

Instantly share code, notes, and snippets.

@danmou
Created December 5, 2019 10:00
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save danmou/bafa5c80356fdb2c843eaf38c8597f84 to your computer and use it in GitHub Desktop.
Mixin for `tf.keras.layers.Layer`s and subclasses to automatically define input and output specs the first time the model is called.
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
import tensorflow as tf
from tensorflow.keras import layers
# Generic alias for an arbitrarily nested structure of values: a single `T`,
# a sequence of `T`, or a mapping with `T` values — mirroring the structures
# that the `tf.nest` utilities operate on.
T = TypeVar('T')
Nested = Union[T, Sequence[T], Mapping[Any, T]]
class AutoShapeMixin:
    """
    Mixin for `tf.keras.layers.Layer`s and subclasses to automatically define input and
    output specs the first time the layer is called.

    Must be listed before `tf.keras.layers.Layer` (or the relevant subclass) when
    subclassing, so that this mixin's `__call__` runs first in the MRO. Only works for
    models and layers with static input and output shapes. The first `batch_dims`
    dimensions (default 1) are assumed to be batch dimensions.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Pop our extra keyword before delegating so the base Layer never sees it.
        self.batch_dims: int = kwargs.pop('batch_dims', 1)
        super().__init__(*args, **kwargs)
        # A bare `assert` would be silently stripped under `python -O`; raise instead.
        if getattr(self, 'dynamic'):
            raise ValueError('AutoShapeMixin should not be used with dynamic layers!')
        self._input_spec: Optional[Nested[layers.InputSpec]] = None
        self._output_spec: Optional[Nested[layers.InputSpec]] = None
        self.built_with_input = False

    def build_with_input(self, input: Nested[tf.Tensor], *args: Any, **kwargs: Any) -> None:
        """Record input/output specs by running a dummy forward pass shaped like `input`.

        The leading `batch_dims` dimensions are replaced by `None` in the recorded
        specs; the remaining dimensions must be static.
        """
        bd = self.batch_dims
        self._input_spec = tf.nest.map_structure(
            lambda x: layers.InputSpec(shape=[None] * bd + x.shape[bd:], dtype=x.dtype), input)
        # NOTE(review): dummy batch dims use size 2 rather than 1 — presumably to avoid
        # any accidental special-casing of singleton batch dimensions during tracing.
        dummy_input = tf.nest.map_structure(
            lambda t: tf.zeros([2] * bd + t.shape[bd:], t.dtype), input)
        dummy_output = super().__call__(dummy_input, *args, **kwargs)
        self._output_spec = tf.nest.map_structure(
            lambda x: layers.InputSpec(shape=[None] * bd + x.shape[bd:], dtype=x.dtype),
            dummy_output)
        self.built_with_input = True

    def __call__(self, inputs: Nested[tf.Tensor], *args: Any, **kwargs: Any) -> Any:
        # Lazily trace specs on the very first real call, then delegate as usual.
        if not self.built_with_input:
            self.build_with_input(inputs, *args, **kwargs)
        return super().__call__(inputs, *args, **kwargs)

    @property
    def input_spec(self) -> Optional[Nested[layers.InputSpec]]:
        return self._input_spec

    @input_spec.setter
    def input_spec(self, value: Optional[layers.InputSpec]) -> None:
        self._input_spec = value

    @property
    def output_spec(self) -> Optional[Nested[layers.InputSpec]]:
        return self._output_spec

    @output_spec.setter
    def output_spec(self, value: Optional[layers.InputSpec]) -> None:
        self._output_spec = value

    @property
    def input_shape(self) -> Nested[tf.TensorShape]:
        """Nested structure of input shapes; requires a prior call/trace."""
        if self.input_spec is None:
            raise RuntimeError('build_with_input has not been called; input shape is not defined')
        return tf.nest.map_structure(lambda x: x.shape, self.input_spec)

    @property
    def output_shape(self) -> Nested[tf.TensorShape]:
        """Nested structure of output shapes; requires a prior call/trace."""
        if self.output_spec is None:
            raise RuntimeError('build_with_input has not been called; output shape is not defined')
        return tf.nest.map_structure(lambda x: x.shape, self.output_spec)

    @property
    def input_dtype(self) -> Nested[tf.TensorShape]:
        """Nested structure of input dtypes; requires a prior call/trace."""
        if self.input_spec is None:
            raise RuntimeError('build_with_input has not been called; input dtype is not defined')
        return tf.nest.map_structure(lambda x: x.dtype, self.input_spec)

    @property
    def output_dtype(self) -> Nested[tf.TensorShape]:
        """Nested structure of output dtypes; requires a prior call/trace."""
        if self.output_spec is None:
            raise RuntimeError('build_with_input has not been called; output dtype is not defined')
        return tf.nest.map_structure(lambda x: x.dtype, self.output_spec)

    def compute_output_shape(self, input_shape: Nested[tf.TensorShape]) -> Nested[tf.TensorShape]:
        """Map the recorded output shapes onto the batch dimensions of `input_shape`."""
        if self.output_spec is None:
            # Nothing traced yet; fall back to the base Layer implementation.
            return super().compute_output_shape(input_shape)
        batch_shape = tf.nest.flatten(input_shape)[0][:self.batch_dims]
        return tf.nest.map_structure(lambda x: batch_shape + x[self.batch_dims:], self.output_shape)
from typing import List
import tensorflow as tf
from tensorflow.keras import layers
from auto_shape_mixin import AutoShapeMixin
### To use the standard Keras layers with auto shape, redefine them like this:
# Auto-shape-enabled drop-in replacements for the standard Keras layers.
# `AutoShapeMixin` must be listed first so its `__call__` wraps the layer's own.
class Layer(AutoShapeMixin, layers.Layer):
    pass


class Dense(AutoShapeMixin, layers.Dense):
    pass


class Conv2D(AutoShapeMixin, layers.Conv2D):
    pass


class Flatten(AutoShapeMixin, layers.Flatten):
    pass


class Concatenate(AutoShapeMixin, layers.Concatenate):
    pass


class Model(AutoShapeMixin, tf.keras.Model):
    pass


class Sequential(AutoShapeMixin, tf.keras.Sequential):
    pass
# etc
### For your own layers simply inherit from one of the above classes and also use them for all sub-layers, e.g.:
class ExampleNetwork(Model):
    """Toy two-input network: a small conv encoder for the image input, whose
    output is concatenated with the remaining flat inputs and projected densely."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = Sequential([
            Conv2D(filters=32, kernel_size=3, strides=2, activation='relu'),
            Flatten(),
        ])
        self.concat = Concatenate(axis=-1)
        self.dense = Dense(units=100)

    def call(self, inputs: List[tf.Tensor]) -> tf.Tensor:
        # First input is the image; everything after it is concatenated as-is.
        image, *extras = inputs
        features = [self.encoder(image), *extras]
        return self.dense(self.concat(features))
# After the first time you call your model with an input, its input and output shapes and dtypes will be defined and `summary` will work as expected.
# Demo: after the first call with a real batch, the model's input/output shapes
# and dtypes are defined, so `summary` works as expected.
model = ExampleNetwork()
first_batch = [tf.zeros((1, 64, 64, 3)), tf.zeros((1, 10))]
model(first_batch)
model.summary()
# Model: "example_network_1"
# _________________________________________________________________
# Layer (type) Output Shape Param #
# =================================================================
# sequential_1 (Sequential) (None, 30752) 896
# _________________________________________________________________
# concatenate (Concatenate) (None, 30762) 0
# _________________________________________________________________
# dense (Dense) (None, 100) 3076300
# =================================================================
# Total params: 3,077,196
# Trainable params: 3,077,196
# Non-trainable params: 0
# _________________________________________________________________
print(model.input_shape)
# ListWrapper([TensorShape([None, 64, 64, 3]), TensorShape([None, 10])])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment