Last active
June 7, 2018 07:52
-
-
Save shadiakiki1986/36546ad09ad1d824d66b4b2680417d39 to your computer and use it in GitHub Desktop.
striding a matrix for RNN input
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fiddle at
# https://pyfiddle.io/fiddle/a14865cf-38c9-48d6-b90a-f5bedc7a5b6e/?m=Saved%20fiddle
#
# Reshape an M x N matrix into overlapping windows of P consecutive rows,
# producing a 3-D array of shape (number of windows) x P x N.
# Useful as sequence input for an LSTM.
import numpy as np | |
import pandas as pd | |
def stride_group(group, lahead):
    """Stack overlapping windows of ``lahead`` consecutive rows of ``group``.

    Turns an M x N table into a 3-D array suitable as sequence input for an
    RNN/LSTM. Window ``k`` holds rows ``k .. k + lahead - 1`` of ``group``.
    Rows are ordered so that index 0 along axis 1 is the oldest row and
    index ``lahead - 1`` is the latest. Only complete windows are emitted,
    and the first one starts at row 0.

    Parameters
    ----------
    group : pandas.DataFrame or array-like
        Input table with M rows. A 1-D Series or array is also accepted and
        yields windows of shape ``(lahead,)``.
    lahead : int
        Window length (number of time steps per sample). Must be >= 1.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(max(M - lahead + 1, 0), lahead, N)``, or
        ``(max(M - lahead + 1, 0), lahead)`` for 1-D input. The dtype of the
        input values is preserved.

    Raises
    ------
    ValueError
        If ``lahead`` is less than 1.
    """
    if lahead < 1:
        raise ValueError("lahead must be >= 1, got {!r}".format(lahead))

    values = np.asarray(group)
    # A series of M rows has exactly M - lahead + 1 complete windows.
    # Clamp at 0 so that lahead > M yields an empty result.
    n_windows = max(len(values) - lahead + 1, 0)

    # Index matrix whose row k is [k, k + 1, ..., k + lahead - 1]. One
    # fancy-indexing gather then builds every window, oldest row first.
    # Unlike a shift()-based approach, no NaN padding is introduced, so
    # integer input is not promoted to float.
    window_idx = np.arange(n_windows)[:, None] + np.arange(lahead)[None, :]
    return values[window_idx]
if __name__ == "__main__":
    # Demo: stride a 4 x 2 frame into windows of 2 consecutive rows.
    # The guard keeps importing this module free of printing side effects.
    df_in = pd.DataFrame({'A': [1, 2, 3, 4], 'B': [5, 6, 7, 8]})
    df_out = stride_group(df_in, 2)
    print(df_in)
    print(df_out)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment