Skip to content

Instantly share code, notes, and snippets.

@innat
Last active September 22, 2023 20:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save innat/1ce396dd46496ce9aac2fce384226211 to your computer and use it in GitHub Desktop.
# Ref: https://gist.github.com/Rocketknight1/efc47242914788def0144b341b1ad638
import math
import tensorflow as tf
from tensorflow.keras import layers
class TFAdaptiveAveragePooling(layers.Layer):
    """Base layer for TensorFlow ports of PyTorch's ``AdaptiveAvgPoolNd``.

    Output position ``i`` of a pooled axis averages the input range
    ``[floor(i * in / out), ceil((i + 1) * in / out))``. Those windows take at
    most two consecutive sizes (see ``_get_pooling_params``), so subclasses
    average-pool once per size with stride 1 and ``_compute_window`` gathers,
    for every output position, the pooled value with the right window.
    Subclasses must implement ``call``.

    Args:
        output_size: Target length of each pooled axis, as an int or a
            list/tuple of ints. A bare int is stored as a 1-tuple.
        **kwargs: Forwarded to ``layers.Layer``.

    Raises:
        ValueError: If ``output_size`` is empty or has a non-positive entry.
    """

    def __init__(self, output_size, **kwargs):
        super().__init__(**kwargs)
        if not isinstance(output_size, (list, tuple)):
            output_size = (output_size,)
        # Normalise to plain Python ints (e.g. NumPy integers) for tf.range etc.
        output_size = tuple(int(size) for size in output_size)
        if not output_size or min(output_size) < 1:
            raise ValueError(
                "output_size must be a positive int or a non-empty sequence of "
                f"positive ints, got {output_size}."
            )
        self.output_size = output_size

    def get_config(self):
        """Include ``output_size`` so Keras can serialize and rebuild the layer."""
        config = super().get_config()
        config.update({"output_size": self.output_size})
        return config

    def _get_pooling_params(self, input_dim, output_dim):
        """Return ``(small_window, big_window)``, the two candidate window sizes.

        ``small_window`` is ``ceil(input_dim / output_dim)`` and ``big_window``
        is one larger.
        """
        small_window = math.ceil(input_dim / output_dim)
        big_window = small_window + 1
        return small_window, big_window

    def _compute_window(self, pools, windows, input_size, target_size, axis=-1):
        """Pick, per output position, the small- or big-window pooled value.

        Args:
            pools: ``[small_pool, big_pool]``, stride-1 ``VALID`` average pools
                along ``axis`` with the small and big window respectively.
            windows: ``[small_window, big_window]`` sizes used for ``pools``.
            input_size: Length of ``axis`` before pooling (int or scalar tensor).
            target_size: Desired length of ``axis`` after pooling.
            axis: Axis being pooled.

        Returns:
            Tensor whose ``axis`` has length ``target_size``.
        """
        small_pool, big_pool = pools
        small_window, _ = windows
        both_pool = tf.concat([small_pool, big_pool], axis=axis)
        # Window bounds in exact int64 arithmetic. A float32 version can
        # mis-round once i * input_size exceeds 2**24, shifting a window by one.
        input_size = tf.cast(input_size, tf.int64)
        positions = tf.range(target_size, dtype=tf.int64)
        window_starts = (positions * input_size) // target_size
        # Ceil division: ceil(a / b) == (a + b - 1) // b for positive a, b.
        window_ends = (
            (positions + 1) * input_size + target_size - 1
        ) // target_size
        # Non-zero (-> True) where the window is one longer than small_window.
        pool_selector = tf.cast(
            window_ends - window_starts - small_window, tf.bool
        )
        small_indices = window_starts
        # Big-pool entries follow the small-pool entries inside `both_pool`.
        big_indices = window_starts + small_pool.shape[axis]
        gather_indices = tf.where(pool_selector, big_indices, small_indices)
        return tf.gather(both_pool, gather_indices, axis=axis)

    def call(self, inputs):
        raise NotImplementedError(
            "This method should be implemented by subclasses."
        )
class TFAdaptiveAveragePooling1D(TFAdaptiveAveragePooling):
    """TensorFlow equivalent of ``torch.nn.AdaptiveAvgPool1d``.

    Pools axis 1 of an ``NWC`` tensor (``(batch, steps, channels)``) to
    ``output_size[0]`` positions.
    """

    def call(self, inputs):
        """Adaptively average-pool axis 1 of ``inputs`` to ``output_size[0]``.

        Raises:
            ValueError: If the length of axis 1 is not statically known.
        """
        # Read the static length, as the 2D/3D layers do. The window sizes fed
        # to avg_pool1d must be Python ints, and deriving them from tf.shape()
        # via math.ceil fails on symbolic tensors (tf.function / compiled models).
        input_dim = inputs.shape[1]
        if input_dim is None:
            raise ValueError(
                f"{type(self).__name__} requires a statically known length for "
                f"axis 1, got input shape {inputs.shape}."
            )
        output_size = self.output_size[0]
        if input_dim == output_size:
            # Every window has size 1, so pooling would return the input as-is.
            return inputs
        small_window, big_window = self._get_pooling_params(input_dim, output_size)
        small_pool, big_pool = (
            tf.nn.avg_pool1d(
                inputs,
                ksize=window,
                strides=1,
                padding="VALID",
                data_format="NWC",
            )
            for window in (small_window, big_window)
        )
        return self._compute_window(
            [small_pool, big_pool],
            [small_window, big_window],
            input_size=input_dim,
            target_size=output_size,
            axis=1,
        )
class TFAdaptiveAveragePooling2D(TFAdaptiveAveragePooling):
    """TensorFlow equivalent of ``torch.nn.AdaptiveAvgPool2d``.

    Expects ``NHWC`` input and pools height (axis 1), then width (axis 2), one
    axis at a time, to ``output_size``. A single int (or 1-element sequence) is
    applied to both axes, matching PyTorch's ``AdaptiveAvgPool2d(5)``.
    """

    # Spatial axes of an NHWC tensor, in the order they are pooled.
    _SPATIAL_AXES = (1, 2)

    def _pseudo_pool(self, inputs, output_size, axis=-1):
        """Adaptively average-pool ``inputs`` along a single spatial ``axis``.

        Args:
            inputs: ``NHWC`` tensor.
            output_size: Target length for ``axis``.
            axis: 1 (height) or 2 (width).

        Returns:
            ``inputs`` with ``axis`` resized to ``output_size``.

        Raises:
            ValueError: If ``axis`` is not a spatial axis or its length is not
                statically known.
        """
        if axis not in self._SPATIAL_AXES:
            raise ValueError(
                f"axis must be one of {self._SPATIAL_AXES}, got {axis}."
            )
        input_dim = inputs.shape[axis]
        if input_dim is None:
            raise ValueError(
                f"{type(self).__name__} requires a statically known length for "
                f"axis {axis}, got input shape {inputs.shape}."
            )
        if input_dim == output_size:
            # Every window has size 1, so pooling would return the input as-is.
            return inputs
        small_window, big_window = self._get_pooling_params(input_dim, output_size)

        def pool(window):
            # Window of `window` along `axis`, 1 along the other spatial axis.
            ksize = tuple(window if ax == axis else 1 for ax in self._SPATIAL_AXES)
            return tf.nn.avg_pool2d(
                inputs,
                ksize=ksize,
                strides=1,
                padding="VALID",
                data_format="NHWC",
            )

        return self._compute_window(
            [pool(small_window), pool(big_window)],
            [small_window, big_window],
            input_size=input_dim,
            target_size=output_size,
            axis=axis,
        )

    def call(self, inputs):
        """Pool height and width of an ``NHWC`` tensor to ``output_size``."""
        sizes = tuple(self.output_size)
        if len(sizes) == 1:
            # PyTorch semantics: a single size applies to every spatial axis.
            sizes *= len(self._SPATIAL_AXES)
        if len(sizes) < len(self._SPATIAL_AXES):
            raise ValueError(
                f"output_size needs 1 or {len(self._SPATIAL_AXES)} entries, "
                f"got {self.output_size}."
            )
        x = inputs
        for axis, size in zip(self._SPATIAL_AXES, sizes):
            x = self._pseudo_pool(x, output_size=size, axis=axis)
        return x
class TFAdaptiveAveragePooling3D(TFAdaptiveAveragePooling):
    """TensorFlow equivalent of ``torch.nn.AdaptiveAvgPool3d``.

    Expects ``NDHWC`` input and pools depth (axis 1), height (axis 2) and width
    (axis 3), one axis at a time, to ``output_size``. A single int (or 1-element
    sequence) is applied to all three axes, matching PyTorch's
    ``AdaptiveAvgPool3d(5)``.
    """

    # Spatial axes of an NDHWC tensor, in the order they are pooled.
    _SPATIAL_AXES = (1, 2, 3)

    def _pseudo_pool(self, inputs, output_size, axis=-1):
        """Adaptively average-pool ``inputs`` along a single spatial ``axis``.

        Args:
            inputs: ``NDHWC`` tensor.
            output_size: Target length for ``axis``.
            axis: 1 (depth), 2 (height) or 3 (width).

        Returns:
            ``inputs`` with ``axis`` resized to ``output_size``.

        Raises:
            ValueError: If ``axis`` is not a spatial axis or its length is not
                statically known.
        """
        if axis not in self._SPATIAL_AXES:
            raise ValueError(
                f"axis must be one of {self._SPATIAL_AXES}, got {axis}."
            )
        input_dim = inputs.shape[axis]
        if input_dim is None:
            raise ValueError(
                f"{type(self).__name__} requires a statically known length for "
                f"axis {axis}, got input shape {inputs.shape}."
            )
        if input_dim == output_size:
            # Every window has size 1, so pooling would return the input as-is.
            return inputs
        small_window, big_window = self._get_pooling_params(input_dim, output_size)

        def pool(window):
            # Window of `window` along `axis`, 1 along the other spatial axes.
            ksize = tuple(window if ax == axis else 1 for ax in self._SPATIAL_AXES)
            return tf.nn.avg_pool3d(
                inputs,
                ksize=ksize,
                strides=1,
                padding="VALID",
                data_format="NDHWC",
            )

        return self._compute_window(
            [pool(small_window), pool(big_window)],
            [small_window, big_window],
            input_size=input_dim,
            target_size=output_size,
            axis=axis,
        )

    def call(self, inputs):
        """Pool depth, height and width of an ``NDHWC`` tensor to ``output_size``."""
        sizes = tuple(self.output_size)
        if len(sizes) == 1:
            # PyTorch semantics: a single size applies to every spatial axis.
            sizes *= len(self._SPATIAL_AXES)
        if len(sizes) < len(self._SPATIAL_AXES):
            raise ValueError(
                f"output_size needs 1 or {len(self._SPATIAL_AXES)} entries, "
                f"got {self.output_size}."
            )
        x = inputs
        for axis, size in zip(self._SPATIAL_AXES, sizes):
            x = self._pseudo_pool(x, output_size=size, axis=axis)
        return x
@innat
Copy link
Author

innat commented Sep 16, 2023

# Parity check: compare the TF layers above against PyTorch's AdaptiveAvgPoolNd.
# Needs torch and numpy in addition to tensorflow.
import numpy as np
import torch
from torch import nn

# Fixed seed so any mismatch is reproducible.
torch.manual_seed(0)


def _assert_matches_torch(tf_layer, torch_layer, torch_input):
    """Assert that ``tf_layer`` reproduces ``torch_layer`` on ``torch_input``.

    PyTorch pools channels-first tensors ``(batch, channel, *spatial)`` while
    the TF layers expect channels-last ``(batch, *spatial, channel)``, so both
    the TF input and PyTorch's output are moved to channels-last first.
    """
    to_channels_last = (0, *range(2, torch_input.ndim), 1)
    expected = torch_layer(torch_input).detach().numpy().transpose(to_channels_last)
    actual = tf_layer(torch_input.detach().numpy().transpose(to_channels_last)).numpy()
    assert actual.shape == expected.shape, (actual.shape, expected.shape)
    np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)


# 1D, target output size of 5.
# torch input: [bs, channel, dim]; TF input MUST be [bs, seq, channel].
pool_size1d = 5
_assert_matches_torch(
    TFAdaptiveAveragePooling1D(output_size=pool_size1d),
    nn.AdaptiveAvgPool1d(pool_size1d),
    torch.randn(1, 64, 8),
)

# 2D: torch input [bs, channel, h, w]; TF input MUST be [bs, h, w, channel].
pool_size2d = (5, 7)
_assert_matches_torch(
    TFAdaptiveAveragePooling2D(output_size=pool_size2d),
    nn.AdaptiveAvgPool2d(pool_size2d),
    torch.randn(1, 64, 8, 9),
)

# 3D: torch input [bs, channel, depth, h, w];
# TF input MUST be [bs, depth, h, w, channel].
pool_size3d = (5, 7, 9)
_assert_matches_torch(
    TFAdaptiveAveragePooling3D(output_size=pool_size3d),
    nn.AdaptiveAvgPool3d(pool_size3d),
    torch.randn(1, 64, 8, 9, 10),
)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment