Skip to content

Instantly share code, notes, and snippets.

@iCorv
Last active March 15, 2021 11:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save iCorv/1ad195b9510c2a3918506580af5a4adf to your computer and use it in GitHub Desktop.
Save iCorv/1ad195b9510c2a3918506580af5a4adf to your computer and use it in GitHub Desktop.
A Keras Subpixel1D layer for upsampling audio and other time-series data in neural networks.
class Subpixel1D(tf.keras.layers.Layer):
    """Sub-pixel upsampling layer for 1-D sequences (audio, time series).

    Trades channels for time steps: an input of shape
    ``(batch, samples, channels)`` becomes
    ``(batch, samples * r, channels // r)``. The rearrangement is done by
    ``tf.batch_to_space`` on a transposed view; per its documented semantics,
    with ``C = channels``::

        output[b, i * r + j, c] == input[b, i, j * (C // r) + c]

    for ``0 <= j < r`` and ``0 <= c < C // r``.

    NOTE(review): this class relies on ``tensorflow`` being imported as
    ``tf`` at module level; no such import is visible alongside it.

    Args:
        r: Upsampling factor. Must be a positive integer that evenly divides
            the number of input channels.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``r`` is not a positive integer.
    """

    def __init__(self, r, **kwargs):
        # Fail fast on factors batch_to_space cannot use (0, negatives, 2.5)
        # instead of surfacing a ZeroDivisionError / TF error much later.
        if int(r) != r or r < 1:
            raise ValueError(
                f'The upsampling factor r must be a positive integer, '
                f'received r={r!r}.'
            )
        super().__init__(**kwargs)
        # Store a plain Python int so the value placed in get_config() is a
        # builtin type regardless of what numeric type the caller passed.
        self.r = int(r)

    def build(self, input_shape):
        """Validate the input shape before the layer is first used.

        Args:
            input_shape: Shape of the input, expected to be
                ``(batch, samples, channels)``; any dimension may be ``None``.

        Raises:
            ValueError: If the input rank is known and is not 3, or if the
                channel count is known and not evenly divisible by ``r``.
        """
        shape = tf.TensorShape(input_shape)
        # An input of unknown rank cannot be validated here.
        if shape.rank is not None:
            if shape.rank != 3:
                raise ValueError(
                    f'Subpixel1D expects a rank-3 input of shape '
                    f'(batch, samples, channels), but received an input of '
                    f'shape {shape.as_list()}.'
                )
            dims = shape.as_list()
            # check if channels are evenly divisible for subpixel1d to work!
            # An unknown (None) channel count is skipped rather than crashing
            # with a TypeError on `None % r`.
            if dims[2] is not None and dims[2] % self.r != 0:
                raise ValueError(
                    f'The number of input channels must be evenly divisible by the upsampling '
                    f'factor r. Received r={self.r}, but the input has {dims[2]} channels '
                    f'(full input shape is {dims}).'
                )
        super().build(input_shape)

    def call(self, inputs):
        """Upsample ``inputs`` along the time axis by a factor of ``r``.

        Args:
            inputs: Tensor of shape ``(batch, samples, channels)``.

        Returns:
            Tensor of shape ``(batch, samples * r, channels // r)``.
        """
        # (batch, samples, channels) -> (channels, samples, batch), so the
        # channel axis sits where batch_to_space expects its "batch" axis.
        outputs = tf.transpose(inputs, [2, 1, 0])
        # (channels, samples, batch) -> (channels/r, r*samples, batch)
        outputs = tf.batch_to_space(outputs, block_shape=[self.r], crops=[[0, 0]])
        # (channels/r, r*samples, batch) -> (batch, r*samples, channels/r)
        return tf.transpose(outputs, [2, 1, 0])

    def compute_output_shape(self, input_shape):
        """Return ``(batch, samples * r, channels // r)`` for ``input_shape``.

        Unknown (``None``) dimensions propagate as ``None``, so variable-length
        inputs such as ``(None, None, channels)`` are supported.
        """
        batch, samples, channels = tf.TensorShape(input_shape).as_list()
        return (
            batch,
            None if samples is None else samples * self.r,
            None if channels is None else channels // self.r,
        )

    def get_config(self):
        """Return the base layer config extended with ``r`` for serialization."""
        return {**super().get_config(), 'r': self.r}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment