@JLDC
Created May 1, 2024 10:47
RNN Example
using Flux
n_timesteps = 1_000
n_sensors = 8
# An LSTM needs a fixed sequence length.
# (Anything above 200 is probably a bad choice; finding the right number
# requires domain knowledge about your problem, which I don't have.)
seq_len = 100
# This is the matrix as you currently have it (n_timesteps × n_sensors)
X = permutedims(repeat(reshape(1:n_timesteps, 1, :), n_sensors))
# In Julia (Flux), you want the rows of your matrix to be the features
# and the columns to be the timesteps, i.e., n_sensors × n_timesteps
X = permutedims(X) .+ reshape(0.1 * (1:n_sensors), :, 1)
# (I add 0.1 * i to sensor i to make it clearer later that they are not the same series)
# Now we reshape for RNN format.
# !!! Note that n_timesteps must be divisible by seq_len !!!
# Otherwise use padding to achieve it (without padding, batch_timeseries below
# drops the leading timesteps that do not fit).
# RNN format is a vector of length seq_len, where each element is of size (n_inputs, n_batches)
# The following function is taken straight from my blogpost
# https://www.jldc.ch/post/flux-batching/
# Create batches of a time series
function batch_timeseries(X, s::Int, r::Int)
    if X isa AbstractVector # If X is passed in format T×1, reshape it
        X = permutedims(X)
    end
    T = size(X, 2)
    @assert s ≤ T "s cannot be longer than the total series"
    X = X[:, ((T - s) % r)+1:end] # Ensure uniform sequence lengths
    [X[:, t:r:end-s+t] for t ∈ 1:s] # Output
end
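# The padding mentioned above can be sketched as follows. This is only a minimal
# sketch, assuming that left-padding with a constant (zero by default) is acceptable
# for your data. `pad_to_multiple` is a hypothetical helper, not part of Flux or of
# the blog post.
function pad_to_multiple(X::AbstractMatrix, s::Int; value = zero(eltype(X)))
    T = size(X, 2)
    n_pad = mod(-T, s) # Number of columns missing to reach a multiple of s
    n_pad == 0 && return X
    hcat(fill(value, size(X, 1), n_pad), X) # Prepend the padding columns
end
# Usage: size(pad_to_multiple(rand(8, 1_050), 100), 2) gives 1100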
X_rnn = batch_timeseries(X, seq_len, seq_len)
# Inspect the first two elements
X_rnn[1] # An 8×10 matrix (n_sensors × n_batches) with n_batches = n_timesteps ÷ seq_len
# Notice how the second column starts at 101.1, this is where our second batch starts
X_rnn[2] # The second element in the series, notice how each column contains the
# sensor values at the next timestep
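# To see the layout on something small, here is a toy version built with plain
# strided indexing (an illustration only; it mimics the layout described above
# for the case where the series length is divisible by the sequence length).
toy = reshape(collect(1:6), 1, :) # 1 feature × 6 timesteps
toy_seq_len = 3
toy_rnn = [toy[:, t:toy_seq_len:end] for t in 1:toy_seq_len]
# toy_rnn[1] == [1 4], toy_rnn[2] == [2 5], toy_rnn[3] == [3 6]
# Each column is one batch: column 1 walks through 1, 2, 3 and column 2 through 4, 5, 6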
# Define the model
model = Chain(
    LSTM(n_sensors => 128),
    Dense(128 => 1)
)
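# One pass through the model. This gives a vector of seq_len entries; each entry is a
# 1 × n_batches matrix (n_batches = n_timesteps ÷ seq_len) holding the predictions
# for that timestep across all batches.
# Assumption: this relies on the stateful recurrent API, where the LSTM carries its
# hidden state from one call to the next (as far as I know, Flux 0.14 and earlier).
# You will want to drop the first element for loss computation
Ys = [model(x) for x in X_rnn]
Ys[1] # 10 values: the first is the prediction for the first step of the first batch, the second is the first step of the second batch, etc.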
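# Dropping the first element for the loss could look like the sketch below.
# `seq_mse` and its `targets` argument are hypothetical: this example defines no
# targets, so I assume a vector shaped like the predictions (seq_len entries of
# 1 × n_batches each). It is plain Julia, not a Flux built-in.
function seq_mse(preds, targets)
    p, y = preds[2:end], targets[2:end] # Skip the first timestep
    sum(sum(abs2, a .- b) for (a, b) in zip(p, y)) / sum(length, y)
end
# Usage (with your own targets): seq_mse([model(x) for x in X_rnn], targets)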
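# Append all batches into a single vector if it is easier to work with
# (stacking gives a seq_len × n_batches matrix, so flattening it column by column
# returns the predictions in chronological order)
vec(reduce(vcat, Ys))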
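# A toy check that this flattening restores chronological order (illustration only,
# with made-up values standing in for the model output):
toy_Ys = [[1 4], [2 5], [3 6]] # 3 timesteps, each 1 × 2 batches
toy_flat = vec(reduce(vcat, toy_Ys)) # [1, 2, 3, 4, 5, 6]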