Test of incremental inference with channels_last for CasualConv1D https://qiita.com/r9y9/items/665f84929994c05c6d06
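# Overview: CasualConv1D pads only the past side of the time axis before a
# 'valid' Conv1D, so output[t] never depends on input[t+1:]. For incremental
# (one frame at a time) inference it keeps a queue of the last
# kw + (kw - 1) * (dilation - 1) input frames and computes each output frame
# with a single matmul against a linearized copy of the kernel, following the
# incremental inference approach described in the linked Qiita article.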
import tensorflow as tf
import numpy as np


class CasualConv1D(tf.keras.layers.Wrapper):
    def __init__(self, filters, kernel_size=1, strides=1, data_format='channels_last',
                 dilation_rate=1, activation=None, use_bias=True, kernel_initializer=None,
                 bias_initializer=tf.zeros_initializer(), trainable=True, name=None,
                 **kwargs):
        layer = tf.layers.Conv1D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding='valid',
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            trainable=trainable,
            name=name, **kwargs
        )
        super().__init__(layer, name=name, **kwargs)
        self.filters = filters
        self.kw = kernel_size
        self.dilation_rate = dilation_rate
        self.scope = 'CausalConv1D' if name is None else name
    def build(self, input_shape):
        input_shape = tf.TensorShape(input_shape).as_list()
        self.input_spec = tf.layers.InputSpec(shape=input_shape)
        self.layer.build(input_shape)

        if self.layer.data_format == 'channels_last':
            in_channels = input_shape[2]
        else:
            in_channels = input_shape[1]

        # Precompute the flattened kernel used by the incremental path.
        self.linearized_weights = self._get_linearized_weight(in_channels)
        super().build()
    def call(self, inputs, is_incremental=False, queue=None):
        with tf.variable_scope(self.scope) as scope:
            # Normal run
            if not is_incremental:
                padding_size = (self.kw - 1) * self.dilation_rate
                # Causal convolution: pad only the beginning of the time axis
                if self.layer.data_format == 'channels_last':
                    inputs_ = tf.pad(inputs, tf.constant([(0, 0), (padding_size, 0), (0, 0)]))
                else:
                    assert self.layer.data_format == 'channels_first'
                    inputs_ = tf.pad(inputs, tf.constant([(0, 0), (0, 0), (padding_size, 0)]))

                outputs = self.layer(inputs_)
                return outputs

            # Incremental run: shift the queue left by one frame and append the
            # newest input frame at the end.
            batch_size = tf.shape(inputs)[0]
            if self.kw > 1:
                queue = queue[:, 1:, :]
                queue = tf.concat([queue, tf.expand_dims(inputs[:, -1, :], axis=1)], axis=1)
                if self.dilation_rate > 1:
                    # Keep every dilation_rate-th frame of the receptive field.
                    inputs = queue[:, 0::self.dilation_rate, :]
                else:
                    inputs = queue

            outputs = tf.matmul(tf.reshape(inputs, [batch_size, -1]), self.linearized_weights)
            if self.layer.use_bias:
                outputs = tf.nn.bias_add(outputs, self.layer.bias)

            # [batch_size, 1(time_len), channels]
            if queue is None:
                return tf.reshape(outputs, [batch_size, 1, self.layer.filters])
            else:
                return tf.reshape(outputs, [batch_size, 1, self.layer.filters]), queue
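
    # Incremental step in matrix form: with dilation d and kernel width kw, the
    # subsampled queue holds x[t-(kw-1)*d], ..., x[t-d], x[t], and
    #     y[t] = concat(frames) @ reshape(W, [kw * in_channels, filters])
    # which reproduces the 'valid' causal convolution output at time t.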
    def _get_linearized_weight(self, in_channels):
        # Compare static shapes; tf.shape() returns a tensor, which cannot be
        # used in a Python `if` (the original tensor comparison was always False).
        if self.layer.kernel.shape == (self.layer.filters, in_channels, self.kw):
            # [filters, in_channels, kernel_size] -> [kernel_size, in_channels, filters]
            weight = tf.transpose(self.layer.kernel, [2, 1, 0])
        else:
            # Already [kernel_size, in_channels, filters]
            weight = self.layer.kernel

        assert weight.shape == (self.kw, in_channels, self.layer.filters)
        self.in_channels = in_channels
        return tf.cast(tf.reshape(weight, [-1, self.layer.filters]), dtype=tf.float32)
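
    # Worked example (hypothetical numbers): kw=2, in_channels=1, filters=1,
    # dilation=2. The kernel [[w0], [w1]] linearizes to shape [2, 1]; a queue
    # [x[t-2], x[t-1], x[t]] subsampled by 2 gives [x[t-2], x[t]], and the
    # matmul yields w0*x[t-2] + w1*x[t], matching the dilated convolution.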


def test(kernel_size, dilation, T, B, C, data_format='channels_last'):
    init = tf.constant_initializer(2.0)
    conv = CasualConv1D(C * 2,
                        kernel_size=kernel_size,
                        data_format=data_format,
                        dilation_rate=dilation,
                        use_bias=False,
                        kernel_initializer=init,
                        name="conv_normal")
    # The incremental path assumes channels_last, so this layer is fixed to it.
    conv_incremental = CasualConv1D(C * 2,
                                    kernel_size=kernel_size,
                                    data_format='channels_last',
                                    dilation_rate=dilation,
                                    use_bias=False,
                                    kernel_initializer=init,
                                    name="conv_incremental")

    if data_format == 'channels_last':
        data = tf.Variable(tf.random.uniform([B, T, C], dtype=tf.dtypes.float32))
    else:
        assert data_format == 'channels_first'
        data = tf.Variable(tf.random.uniform([B, C, T], dtype=tf.dtypes.float32))

    output_conv = conv(data)
    # Remove future time stamps
    output_conv = output_conv[:, :T, :]

    if data_format == 'channels_first':
        data = tf.transpose(data, (0, 2, 1))

    # Append one zero frame so data[:, time, :] stays in range at time == T.
    data = tf.concat([data, tf.zeros((B, 1, C))], axis=1)

    initial_inputs = tf.expand_dims(data[:, 0, :], axis=1)
    initial_time = tf.constant(0, dtype=tf.int32)
    initial_outputs = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
    # Zero-filled queue covering the full receptive field, matching the left
    # zero padding of the normal path.
    initial_queue = tf.zeros((B,
                              conv.kw + (conv.kw - 1) * (conv.dilation_rate - 1),
                              C))
    time_len = T

    def cond(time, _outputs, _current_input, _queues):
        return tf.less(time, time_len)

    def body(time, outputs, inputs, queue):
        x, new_queue = conv_incremental(inputs, is_incremental=True, queue=queue)
        if len(x.shape) == 3:
            x = tf.squeeze(x, [1])
        outputs = outputs.write(time, x)
        time = time + 1
        next_inputs = tf.expand_dims(data[:, time, :], axis=1)
        return time, outputs, next_inputs, new_queue

    result = tf.while_loop(
        cond,
        body,
        loop_vars=[initial_time, initial_outputs, initial_inputs, initial_queue],
        parallel_iterations=32,
        swap_memory=False
    )

    # TensorArray.stack() yields [T, B, C']; transpose to [B, T, C'].
    outputs = result[1].stack()
    outputs = tf.transpose(outputs, [1, 0, 2])

    return output_conv, outputs
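

# The check below builds both graphs for several (T, C, kernel_size, dilation)
# combinations and asserts that frame-by-frame incremental inference matches
# the one-shot causal convolution.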
if __name__ == '__main__':
    with tf.Session() as sess:
        for B in [1]:
            for T in [10, 20, 30]:
                for C in [2, 4]:
                    for kernel_size in [3, 5]:
                        for dilation in [1, 2, 4]:
                            x, y = test(kernel_size=kernel_size, dilation=dilation,
                                        T=T, B=B, C=C, data_format='channels_last')
                            sess.run(tf.global_variables_initializer())
                            result_x, result_y = sess.run([x, y])
                            assert np.allclose(result_x, result_y)