Test of incremental inference for CasualConv1D with channels_last input. https://qiita.com/r9y9/items/665f84929994c05c6d06
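The idea under test: the output of a dilated causal convolution at time t depends only on the kw input frames, spaced dilation_rate apart, that end at t. Incremental inference can therefore cache those frames in a queue and produce each output step with one matmul against the flattened kernel. A minimal NumPy sketch of that equivalence (variable names here are illustrative, not part of the gist):

# Illustrative sketch: one step of a causal convolution equals a single
# matmul against the kernel flattened to [kernel_size * in_channels, filters].
import numpy as np

kw, cin, cout = 3, 4, 8
kernel = np.random.randn(kw, cin, cout)   # [kernel_size, in_channels, filters]
frames = np.random.randn(kw, cin)         # the kw taps visible at time t

step_direct = np.einsum('ki,kio->o', frames, kernel)
step_linearized = frames.reshape(1, -1) @ kernel.reshape(-1, cout)

assert np.allclose(step_direct, step_linearized.ravel())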
import tensorflow as tf
import numpy as np
class CasualConv1D(tf.keras.layers.Wrapper):
    def __init__(self, filters, kernel_size=1, strides=1, data_format='channels_last',
                 dilation_rate=1, activation=None, use_bias=True, kernel_initializer=None,
                 bias_initializer=tf.zeros_initializer(), trainable=True, name=None, **kwargs):
        layer = tf.layers.Conv1D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding='valid',
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            trainable=trainable,
            name=name, **kwargs
        )
        super().__init__(layer, name=name, **kwargs)
        self.filters = filters
        self.kw = kernel_size
        self.dilation_rate = dilation_rate
        self.scope = 'CausalConv1D' if name is None else name
    def build(self, input_shape):
        input_shape = tf.TensorShape(input_shape).as_list()
        self.input_spec = tf.layers.InputSpec(shape=input_shape)
        self.layer.build(input_shape)
        if self.layer.data_format == 'channels_last':
            in_channels = input_shape[2]
        else:
            in_channels = input_shape[1]
        # Precompute the flattened kernel used by the incremental path
        self.linearized_weights = self._get_linearized_weight(in_channels)
        super().build()
    def call(self, inputs, is_incremental=False, queue=None):
        with tf.variable_scope(self.scope) as scope:
            # Normal run
            if not is_incremental:
                padding_size = (self.kw - 1) * self.dilation_rate
                # For causal convolution, pad only the beginning of the time axis
                if self.layer.data_format == 'channels_last':
                    inputs_ = tf.pad(inputs, tf.constant([(0, 0), (padding_size, 0), (0, 0)]))
                else:
                    assert self.layer.data_format == 'channels_first'
                    inputs_ = tf.pad(inputs, tf.constant([(0, 0), (0, 0), (padding_size, 0)]))
                outputs = self.layer(inputs_)
                return outputs
            # Incremental run
            batch_size = tf.shape(inputs)[0]
            if self.kw > 1:
                # Shift the queue one frame and append the newest input frame
                queue = queue[:, 1:, :]
                queue = tf.concat([queue, tf.expand_dims(inputs[:, -1, :], axis=1)], axis=1)
                if self.dilation_rate > 1:
                    # Keep every dilation_rate-th frame, i.e. the kw taps the kernel sees
                    inputs = queue[:, 0::self.dilation_rate, :]
                else:
                    inputs = queue

            # One output step as a single matmul against the linearized kernel
            outputs = tf.matmul(tf.reshape(inputs, [batch_size, -1]), self.linearized_weights)
            if self.layer.use_bias:
                outputs = tf.nn.bias_add(outputs, self.layer.bias)

            # [batch_size, 1(time_len), channels]
            if queue is None:
                return tf.reshape(outputs, [batch_size, 1, self.layer.filters])
            else:
                return tf.reshape(outputs, [batch_size, 1, self.layer.filters]), queue
    def _get_linearized_weight(self, in_channels):
        # Compare the kernel's static shape; tf.shape() returns a tensor, which
        # cannot be meaningfully compared to a Python tuple at graph-build time
        if self.layer.kernel.shape == (self.layer.filters, in_channels, self.kw):
            # [filters, in_channel, kernel_size]
            weight = tf.transpose(self.layer.kernel, [2, 1, 0])
        else:
            # [kernel_size, in_channel, filters]
            weight = self.layer.kernel

        # [kernel_size, in_channel, filters]
        assert weight.shape == (self.kw, in_channels, self.layer.filters)
        self.in_channels = in_channels
        return tf.cast(tf.reshape(weight, [-1, self.layer.filters]), dtype=tf.float32)
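
# Note on the queue length: one step of a dilated causal convolution reads kw
# taps spaced dilation_rate apart, so its receptive field spans
# (kw - 1) * dilation_rate + 1 frames. The queue allocated in test() below,
# conv.kw + (conv.kw - 1) * (conv.dilation_rate - 1), simplifies to the same
# expression, and the stride-dilation_rate slice in call() recovers the kw taps.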
def test(kernel_size, dilation, T, B, C, data_format='channels_last'):
    # Both layers share the same constant initializer so their kernels match
    init = tf.constant_initializer(2.0)
    conv = CasualConv1D(C * 2,
                        kernel_size=kernel_size,
                        data_format=data_format,
                        dilation_rate=dilation,
                        use_bias=False,
                        kernel_initializer=init,
                        name="conv_normal")
    # The incremental layer always consumes channels_last input (see the
    # transpose below), regardless of the data_format of the normal layer
    conv_incremental = CasualConv1D(C * 2,
                                    kernel_size=kernel_size,
                                    data_format='channels_last',
                                    dilation_rate=dilation,
                                    use_bias=False,
                                    kernel_initializer=init,
                                    name="conv_incremental")
    if data_format == 'channels_last':
        data = tf.Variable(tf.random.uniform([B, T, C], dtype=tf.dtypes.float32))
    else:
        assert data_format == 'channels_first'
        data = tf.Variable(tf.random.uniform([B, C, T], dtype=tf.dtypes.float32))

    output_conv = conv(data)
    # Remove future time stamps
    output_conv = output_conv[:, :T, :]

    if data_format == 'channels_first':
        data = tf.transpose(data, (0, 2, 1))
    # Append one dummy frame so data[:, time, :] stays in range on the final
    # loop iteration (the last next_inputs is computed but never used)
    data = tf.concat([data, tf.zeros((B, 1, C))], axis=1)

    initial_inputs = tf.expand_dims(data[:, 0, :], axis=1)
    initial_time = tf.constant(0, dtype=tf.int32)
    initial_outputs = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
    # The queue covers the receptive field: (kw - 1) * dilation_rate + 1 frames
    initial_queue = tf.zeros((B,
                              conv.kw + (conv.kw - 1) * (conv.dilation_rate - 1),
                              C))
    time_len = T
    def cond(time, _outputs, _current_input, _queues):
        return tf.less(time, time_len)

    def body(time, outputs, inputs, queue):
        # Run one step of incremental inference and collect the output frame
        x, new_queue = conv_incremental(inputs, is_incremental=True, queue=queue)
        if len(x.shape) == 3:
            x = tf.squeeze(x, [1])
        outputs = outputs.write(time, x)
        time = time + 1
        next_inputs = tf.expand_dims(data[:, time, :], axis=1)
        return time, outputs, next_inputs, new_queue

    result = tf.while_loop(
        cond,
        body,
        loop_vars=[initial_time, initial_outputs, initial_inputs, initial_queue],
        parallel_iterations=32,
        swap_memory=False
    )
    outputs = result[1].stack()
    # [T, B, C] -> [B, T, C]
    outputs = tf.transpose(outputs, [1, 0, 2])

    return output_conv, outputs
if __name__ == '__main__':
    with tf.Session() as sess:
        for B in [1]:
            for T in [10, 20, 30]:
                for C in [2, 4]:
                    for kernel_size in [3, 5]:
                        for dilation in [1, 2, 4]:
                            x, y = test(kernel_size=kernel_size, dilation=dilation,
                                        T=T, B=B, C=C, data_format='channels_last')
                            sess.run(tf.global_variables_initializer())
                            result_x, result_y = sess.run([x, y])
                            assert np.allclose(result_x, result_y)
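Note: the script targets the TensorFlow 1.x API (tf.layers, tf.variable_scope, tf.Session). A minimal sketch for running it under a TensorFlow 2.x install, assuming the v1 compatibility shim is acceptable, would swap the import at the top for:

# v1 compatibility shim: restores graph mode and the tf.layers / tf.Session API
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()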