Skip to content

Instantly share code, notes, and snippets.

@jperl
Last active February 7, 2018 06:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jperl/954631259eda0f81be750b67e25f9bc4 to your computer and use it in GitHub Desktop.
Save jperl/954631259eda0f81be750b67e25f9bc4 to your computer and use it in GitHub Desktop.
stack past tensor slices
import numpy as np
import tensorflow as tf
def _stack_past(x, steps):
"""Stack the past data for each step.
Ex. x = [0, ..., 60]. steps = [10, 20]
Result [x, x[:-10], x[:-20]] normalized to the same shape
"""
# Sort the steps in ascending order [10, 20]
sorted_steps = steps.copy()
sorted_steps.sort()
largest_step = sorted_steps[-1]
# Include the original data
stacks = [x[largest_step:]]
for step in sorted_steps:
# Normalize the shapes by skipping the difference from the largest step
# Which are rows without enough past data
skip = largest_step - step
stacks.append(x[skip:-step])
return stacks
def np_stack_past(x, steps):
    """Stack ``x`` and its lagged slices along a new trailing axis (NumPy).

    The slices come from ``_stack_past(x, steps)``; for positive steps and
    ``x`` of shape (n, ...), the result has shape
    (n - max(steps), ..., len(steps) + 1).
    """
    return np.stack(_stack_past(x, steps), axis=-1)
def tf_stack_past(tensor, steps):
    """Stack ``tensor`` and its lagged slices along a new trailing axis.

    TensorFlow counterpart of ``np_stack_past``: the slices come from
    ``_stack_past(tensor, steps)`` and are joined with ``tf.stack``.
    """
    lagged = _stack_past(tensor, steps)
    return tf.stack(values=lagged, axis=-1)
import logging
import numpy as np
from numpy.testing import assert_array_equal
logging.getLogger('tensorflow').disabled = True
import tensorflow as tf # noqa
from utils.transform.stack import np_stack_past, tf_stack_past # noqa
def example_matrix():
    """Build a 100 x 3 NumPy array of string labels for the stacking tests.

    Row ``i`` is ``[str(i) + 'a', str(i) + 'b', str(i) + 'c']``, so the array
    runs from ``['0a', '0b', '0c']`` to ``['99a', '99b', '99c']``. Every cell
    names its own row, which keeps lagged slices readable in assertions.
    """
    suffixes = ('a', 'b', 'c')
    rows = [[str(row) + suffix for suffix in suffixes] for row in range(100)]
    return np.stack(rows)
class StackTestCase(tf.test.TestCase):
    """Check np_stack_past / tf_stack_past against hand-written expectations."""

    def test_stack_past(self):
        """Lags [1, 2, 10, 20] over example_matrix() give the expected slices."""
        x = example_matrix()
        steps = [1, 2, 10, 20]
        past = np_stack_past(x, steps)

        # 100 input rows minus the 20 leading rows (0..19) that lack 20 rows
        # of history -> 80 rows; each is 3 columns x (1 current + 4 lagged).
        self.assertEqual(past.shape, (80, 3, 5))

        # First output row corresponds to input row 20.
        # Last axis:  0      -1     -2     -10    -20
        expected_first = [['20a', '19a', '18a', '10a', '0a'],
                          ['20b', '19b', '18b', '10b', '0b'],
                          ['20c', '19c', '18c', '10c', '0c']]
        assert_array_equal(past[0], expected_first)

        # Last output row corresponds to input row 99.
        # Last axis:  0      -1     -2     -10    -20
        expected_last = [['99a', '98a', '97a', '89a', '79a'],
                         ['99b', '98b', '97b', '89b', '79b'],
                         ['99c', '98c', '97c', '89c', '79c']]
        assert_array_equal(past[-1], expected_last)

        with self.test_session():
            x_tensor = tf.constant(x, dtype=tf.string)
            past_tensor = tf_stack_past(x_tensor, steps)
            # tf.string tensors evaluate to bytes; convert to str so they
            # compare equal to the expected Python strings.
            result = past_tensor.eval().astype('U13')
            self.assertEqual(result.shape, (80, 3, 5))
            self.assertAllEqual(result[0], expected_first)
            self.assertAllEqual(result[-1], expected_last)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment