Created May 1, 2024 10:47 (gist JLDC/40b03f209e3e45e5121bdead80f89fea)
RNN Example
using Flux  # assumes the Flux 0.14 recurrent API, in which LSTM keeps its hidden state between calls
n_timesteps = 1_000
n_sensors = 8
# You have to define a sequence length for an LSTM
# (anything above 200 is probably a bad choice; finding the right number
# requires domain knowledge about your problem, which I don't have)
seq_len = 100
# This is the matrix as you currently have it (n_timesteps × n_sensors);
# as toy data, every sensor simply reads the values 1:n_timesteps
X = permutedims(repeat(reshape(1:n_timesteps, 1, :), n_sensors))
# In Julia (Flux), you want your matrix to have the features as rows
# and the timesteps as columns, i.e., n_sensors × n_timesteps.
# Converting to Float32 matches the element type of Flux's parameters.
X = Float32.(permutedims(X) .+ reshape(0.1 * (1:n_sensors), :, 1))
# (I add 0.1 × i to sensor i to make it clearer later that they are not the same series)
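# Hypothetical mini-example (not part of the original script): the same two steps
# on made-up sizes, demo_T = 5 timesteps and demo_F = 3 sensors, small enough to
# inspect by eye. The demo_* names are illustrative only and do not touch X.
demo_T, demo_F = 5, 3
demo_tall = permutedims(repeat(reshape(1:demo_T, 1, :), demo_F))      # demo_T × demo_F, like the original X
demo_wide = permutedims(demo_tall) .+ reshape(0.1 * (1:demo_F), :, 1)  # demo_F × demo_T, the layout Flux expects
# Column t of demo_wide holds every sensor's reading at timestep t,
# e.g. demo_wide[:, 1] ≈ [1.1, 1.2, 1.3]; row i is the full series of sensor i.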
# Now we reshape into the RNN format.
# !!! n_timesteps should be divisible by seq_len !!!
# Otherwise batch_timeseries (below) silently drops the first n_timesteps % seq_len
# timesteps; pad the series beforehand if you need to keep them.
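# A minimal sketch of the padding option (not from the original script), assuming
# it is acceptable to repeat the first observation at the start of the series;
# zeros or another fill value would work the same way. `pad_front` is a hypothetical
# helper, not a Flux function. Padding at the front mirrors batch_timeseries, which
# trims from the front; exclude the padded timesteps when computing the loss.
function pad_front(X::AbstractMatrix, s::Int)
    n_pad = mod(-size(X, 2), s)        # extra columns needed to reach a multiple of s
    pad = repeat(X[:, 1:1], 1, n_pad)  # n_features × n_pad copies of the first column
    return hcat(pad, X)
end
# Usage: X_padded = pad_front(X, seq_len) has a multiple of seq_len columns.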
# The RNN format is a vector of length seq_len, where each element is a matrix of size (n_inputs, n_batches)
# The following function is taken straight from my blog post
# https://www.jldc.ch/post/flux-batching/
# Create batches of a time series
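# s: sequence length; r: step between the starts of consecutive sequences
# (r = s gives non-overlapping sequences)
function batch_timeseries(X, s::Int, r::Int)
    if isa(X, AbstractVector) # If X is a vector of length T, make it a 1×T matrix
        X = permutedims(X)
    end
    T = size(X, 2)
    @assert s ≤ T "s cannot be longer than the total series"
    X = X[:, ((T - s) % r)+1:end] # Drop the first (T - s) % r timesteps so every sequence has length s
    [X[:, t:r:end-s+t] for t ∈ 1:s] # Element t holds timestep t of every sequence (n_features × n_batches)
end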
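# Sketch of the arithmetic behind batch_timeseries (derived from the code above, not
# from the blog post): it drops the first (T - s) % r timesteps, then starts a new
# sequence every r timesteps, so it returns s matrices of size
# n_features × ((T - s) ÷ r + 1). With r == s the sequences do not overlap; with r < s
# they do. `batch_layout` is a hypothetical helper for checking these numbers without data.
function batch_layout(T::Int, s::Int, r::Int)
    @assert s ≤ T "s cannot be longer than the total series"
    n_dropped = (T - s) % r          # leading timesteps discarded
    n_batches = (T - s) ÷ r + 1      # columns in each of the s matrices
    return (n_dropped = n_dropped, n_batches = n_batches)
end
batch_layout(1_000, 100, 100)  # (n_dropped = 0, n_batches = 10): the setting used in this script
batch_layout(1_005, 100, 100)  # (n_dropped = 5, n_batches = 10)
batch_layout(1_000, 100, 50)   # (n_dropped = 0, n_batches = 19): overlapping sequences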
X_rnn = batch_timeseries(X, seq_len, seq_len)  # r = seq_len: non-overlapping sequences
# Inspect the first two elements
X_rnn[1] # An 8×10 matrix (n_sensors × n_batches) with n_batches = n_timesteps ÷ seq_len
# Notice how the second column starts at 101.1: this is where our second batch starts
X_rnn[2] # The second element in the series; notice how each column contains the
# sensor values one timestep after the corresponding column of X_rnn[1]
# Define the model
model = Chain(
    LSTM(n_sensors => 128),
    Dense(128 => 1)
)
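# Back-of-the-envelope size of this model (a sketch, not a Flux API call): assuming
# the LSTM uses, for each of its 4 gates, an input weight matrix, a recurrent weight
# matrix and one bias vector, it has 4 * (h*i + h*h + h) weights for i inputs and
# h hidden units; Dense(h => 1) adds h + 1. Any learnable initial state is not counted.
# `lstm_weight_count` and `dense_weight_count` are hypothetical helpers.
lstm_weight_count(i::Int, h::Int) = 4 * (h * i + h * h + h)
dense_weight_count(i::Int, o::Int) = o * i + o
lstm_weight_count(8, 128) + dense_weight_count(128, 1)  # 70144 + 129 = 70273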
# One pass through the model. Reset the hidden state first so that every pass starts
# from the same initial state (with Flux 0.14, the LSTM carries its state between calls).
# This gives a vector of seq_len entries, one per timestep; entry t is a
# 1 × n_batches matrix with the prediction at step t of every sequence.
# You will want to drop the first element for loss computation (e.g., use Ys[2:end])
Flux.reset!(model)
Ys = [model(x) for x in X_rnn]
Ys[1] # 10 values: the first is the prediction for the first step of the first batch,
# the second is the first step of the second batch, etc.
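# A minimal sketch of a loss that skips warm-up steps (not from the original script),
# assuming the targets are stored like Ys: a vector with one 1 × n_batches matrix per
# timestep. `sequence_mse` and its `warmup` keyword are hypothetical; warmup = 1 drops
# the first element as suggested above. In a Flux 0.14 training loop you would call
# Flux.reset!(model) and compute Ys inside the gradient call before applying this.
function sequence_mse(preds::AbstractVector, targets::AbstractVector; warmup::Int = 1)
    @assert length(preds) == length(targets) > warmup
    kept = (warmup + 1):length(preds)                             # timesteps used in the loss
    total = sum(sum(abs2, preds[t] .- targets[t]) for t in kept)  # summed squared error
    n = sum(length(targets[t]) for t in kept)                     # number of scalar targets
    return total / n
end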
# Concatenate all batches into a single vector if that is easier to work with;
# since the sequences do not overlap (r = seq_len), the result is in chronological order
reshape(vcat(Ys...), :)
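# Sketch of the inverse operation for the non-overlapping case used here (r == s and
# no leading timesteps dropped): rebuild the n_features × T matrix from the RNN format.
# `unbatch_timeseries` is a hypothetical helper, not part of Flux or the blog post.
# Step t of sequence b corresponds to original timestep (b - 1) * s + t, which is also
# why reshape(vcat(Ys...), :) comes out in chronological order.
function unbatch_timeseries(batches::AbstractVector{<:AbstractMatrix})
    s = length(batches)                          # sequence length
    n_features, n_batches = size(first(batches))
    out = similar(first(batches), n_features, s * n_batches)
    for b in 1:n_batches, t in 1:s
        out[:, (b - 1) * s + t] = batches[t][:, b]
    end
    return out
end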