Skip to content

Instantly share code, notes, and snippets.

@JEM-Mosig
Created April 29, 2021 09:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JEM-Mosig/1a242a7faa5ff3a19c4141a3f1cd934c to your computer and use it in GitHub Desktop.
Save JEM-Mosig/1a242a7faa5ff3a19c4141a3f1cd934c to your computer and use it in GitHub Desktop.
pad_right for TensorFlow: pad a tensor on the right up to a target shape
import tensorflow as tf
from tensorflow import Tensor, TensorShape
from typing import Union
def pad_right(
    x: Tensor,
    target_shape: Union[TensorShape, Tensor, list, tuple],
    value: Union[int, float] = 0,
) -> Tensor:
    """Pads `x` with `value` at the end of every dimension up to `target_shape`.

    Each dimension `i` of `x` is extended from `tf.shape(x)[i]` to
    `target_shape[i]` by appending `value` after the existing entries. Nothing
    is inserted before them, so `x` occupies the leading slice of the result.

    Args:
        x: Any tensor.
        target_shape: Shape of the padded tensor, given as a fully defined
            `TensorShape`, a 1-D integer tensor, or a list/tuple of ints. It
            must have exactly one entry per dimension of `x`, and every entry
            must be at least as large as the corresponding dimension of `x`.
        value: Scalar used to fill the padded region. It is passed to
            `tf.pad` as `constant_values`, so it must be convertible to the
            dtype of `x`.

    Returns:
        A tensor of shape `target_shape` whose leading slice equals `x` and
        whose remaining entries are `value`.

    Raises:
        tf.errors.InvalidArgumentError: If `target_shape` does not have one
            entry per dimension of `x`, or if any entry is smaller than the
            corresponding dimension of `x`.
        ValueError: If `target_shape` is a partially known `TensorShape`,
            which cannot be converted to a tensor.
    """
    current_shape = tf.shape(x)
    # `tf.shape` is int32 by default. Cast the target to the same dtype so that
    # int64 targets do not fail in the subtraction below with a dtype mismatch.
    target = tf.cast(tf.convert_to_tensor(target_shape), current_shape.dtype)

    # Without this check, a length-1 `target_shape` would silently broadcast
    # against every dimension of `x` and produce a tensor of the wrong shape.
    tf.debugging.assert_equal(
        tf.size(target),
        tf.rank(x),
        message="target_shape must have exactly one entry per dimension of x",
    )

    right_padding = target - current_shape
    # Fail early, with a message phrased in terms of this function's
    # arguments, instead of letting `tf.pad` reject negative paddings.
    tf.debugging.assert_non_negative(
        right_padding,
        message="target_shape must be at least as large as x in every dimension",
    )

    # `tf.pad` expects a [rank, 2] tensor of (before, after) pairs: pad
    # nothing before the existing entries and everything after them.
    paddings = tf.stack([tf.zeros_like(right_padding), right_padding], axis=1)
    return tf.pad(x, paddings, "CONSTANT", constant_values=value)
##############################################################################
import pytest
import tensorflow as tf
import numpy as np
import rasa.utils.tensorflow.layers_utils as layers_utils
def test_pad_right():
    """`pad_right` appends the fill value after the existing entries of each dimension."""
    x = tf.ones([3, 2])

    # Default fill value (0): `x` ends up in the top-left corner.
    x_padded = layers_utils.pad_right(x, [5, 7])
    # `assert_array_equal` checks the shape as well as the values. Unlike
    # `assert np.all(a == b)`, it reports the mismatching entries on failure.
    np.testing.assert_array_equal(
        x_padded.numpy(),
        [
            [1, 1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0],
        ],
    )

    # Custom fill value: exercises the `value` parameter, which the default
    # case above cannot distinguish from zero-padding.
    x_padded_custom = layers_utils.pad_right(x, [4, 3], value=-1)
    np.testing.assert_array_equal(
        x_padded_custom.numpy(),
        [
            [1, 1, -1],
            [1, 1, -1],
            [1, 1, -1],
            [-1, -1, -1],
        ],
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment