Skip to content

Instantly share code, notes, and snippets.

@JLDC

JLDC/mwe_flux.jl Secret

Last active July 21, 2022 06:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JLDC/66c8eed6c36cb9cda85ff6404284d841 to your computer and use it in GitHub Desktop.
Save JLDC/66c8eed6c36cb9cda85ff6404284d841 to your computer and use it in GitHub Desktop.
Flux vs. PyTorch computation times
using Pkg
Pkg.activate(".")
using BenchmarkTools
using Flux
using StatsBase
using Random
using MKL
"""
    julia_nn(lstm, X, Y, opt, θ, epochs=100)

Run `epochs` full-batch training steps of `lstm` with optimiser `opt`.

Each step resets the model with `Flux.reset!`, feeds every element of `X` except
the last through the model (outputs discarded), computes the mean squared error
between the model output for `X[end]` and `Y`, and applies the resulting
gradient to the parameters `θ`.

# Arguments
- `lstm`: the model; called as `lstm(x)` on each element of `X`.
- `X`: indexable sequence of inputs; `X[end]` is the one the loss is computed on.
- `Y`: target compared against `lstm(X[end])`.
- `opt`: optimiser passed to `Flux.update!`.
- `θ`: parameters to differentiate with respect to and to update.
- `epochs`: number of update steps (default `100`).

Returns `nothing`.
"""
function julia_nn(lstm, X, Y, opt, θ, epochs=100)
    for epoch in 1:epochs
        # Start every epoch from a freshly reset model.
        Flux.reset!(lstm)
        grads = gradient(θ) do
            [lstm(x) for x in X[1:end-1]] # Warm up
            Flux.Losses.mse(lstm(X[end]), Y) # MSE on last item only
        end
        Flux.update!(opt, θ, grads)
    end
    return nothing
end
"""
    main()

Benchmark entry point for the plain `Chain` model.

Seeds the global RNG, builds 20 random `100 × 1_000` `Float32` inputs and a
`1 × 1_000` target, constructs a two-layer LSTM followed by a `Dense` layer,
runs one untimed epoch so compilation is excluded, and then times the default
number of epochs of `julia_nn` with `@time`.
"""
function main()
    Random.seed!(72)
    X = [rand(Float32, 100, 1_000) for _ in 1:20]
    Y = rand(Float32, 1, 1_000)
    lstm = Chain(
        LSTM(100 => 128),
        LSTM(128 => 128),
        Dense(128 => 1),
    )
    opt = ADAM()
    θ = Flux.params(lstm)
    # Run one epoch outside of computation timings for compilation.
    # Going through `julia_nn` (instead of an inline copy of its body) means the
    # function that is timed below has itself already been compiled.
    julia_nn(lstm, X, Y, opt, θ, 1)
    Flux.reset!(lstm)
    @time julia_nn(lstm, X, Y, opt, θ)
end
main()
using Pkg
Pkg.activate(".")
using BenchmarkTools
using Flux
using StatsBase
using Random
using MKL
"""
    Seq2One(rnn, fc)

Container pairing a recurrent part `rnn` with a final layer `fc`.

The field types are type parameters, so each instance has concretely typed
fields instead of `Any`; the two-argument constructor is unchanged.
"""
struct Seq2One{R,F}
    rnn::R
    fc::F
end
Flux.@functor Seq2One

"""
    (m::Seq2One)(X)

Apply `m.rnn` to every element of `X` except the last (outputs discarded), then
return `m.fc` applied to `m.rnn(X[end])`.
"""
function (m::Seq2One)(X)
    # Run the leading elements through the recurrent part; only the calls
    # matter here, the collected outputs are not used.
    [m.rnn(x) for x in X[1:end-1]]
    return m.fc(m.rnn(X[end]))
end
"""
    julia_nn(lstm, X, Y, opt, θ, epochs=100)

Run `epochs` full-batch training steps of `lstm` with optimiser `opt`.

Each step resets the model with `Flux.reset!`, computes the mean squared error
between `lstm(X)` and `Y`, and applies the resulting gradient to the
parameters `θ`.

# Arguments
- `lstm`: the model; called once per step on the whole of `X`.
- `X`: input passed to `lstm`.
- `Y`: target compared against `lstm(X)`.
- `opt`: optimiser passed to `Flux.update!`.
- `θ`: parameters to differentiate with respect to and to update.
- `epochs`: number of update steps (default `100`).

Returns `nothing`.
"""
function julia_nn(lstm, X, Y, opt, θ, epochs=100)
    for epoch in 1:epochs
        # Start every epoch from a freshly reset model.
        Flux.reset!(lstm)
        grads = gradient(θ) do
            Flux.Losses.mse(lstm(X), Y) # MSE on last item only
        end
        Flux.update!(opt, θ, grads)
    end
    return nothing
end
"""
    main()

Benchmark entry point for the `Seq2One` model.

Seeds the global RNG, builds 20 random `100 × 1_000` `Float32` inputs and a
`1 × 1_000` target, wraps a two-layer LSTM chain and a `Dense` layer in
`Seq2One`, runs one untimed epoch so compilation is excluded, and then times
the default number of epochs of `julia_nn` with `@time`.
"""
function main()
    Random.seed!(72)
    X = [rand(Float32, 100, 1_000) for _ in 1:20]
    Y = rand(Float32, 1, 1_000)
    lstm = Seq2One(
        Chain(LSTM(100 => 128), LSTM(128 => 128)),
        Dense(128 => 1),
    )
    opt = ADAM()
    θ = Flux.params(lstm)
    # Run one epoch outside of computation timings for compilation.
    # Going through `julia_nn` (instead of an inline copy of its body) means the
    # function that is timed below has itself already been compiled.
    julia_nn(lstm, X, Y, opt, θ, 1)
    Flux.reset!(lstm)
    @time julia_nn(lstm, X, Y, opt, θ)
end
main()
using Pkg
Pkg.activate(".")
using BenchmarkTools
using StatsBase
using Random
using PyCall
# Python side of the benchmark. The code lives in a PyCall `py"""` string, so it
# must be valid, correctly indented Python; only `#` comments are used inside it
# because a triple-quoted Python docstring would end the Julia string literal.
py"""
import torch
import torch.nn as nn


class LSTM(nn.Module):
    # Two stacked LSTM layers followed by a linear layer.
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()  # Base class constructor
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2,
                            batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # nn.LSTM returns a tuple; element 0 is what gets fed to the linear layer.
        h = self.lstm(x)[0]
        return self.linear(h)


def train_network(X, Y, epochs=100):
    # X and Y arrive as NumPy arrays and are converted to tensors on `dev`.
    dev = torch.device("cpu")  # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    lstm = LSTM(100, 128, 1).to(dev)
    X, Y = torch.from_numpy(X).to(dev), torch.from_numpy(Y).to(dev)
    lstm.train()
    criterion = nn.MSELoss()
    opt = torch.optim.Adam(lstm.parameters())
    for epoch in range(epochs):
        output = lstm(X)
        # Keep only the last index along dimension 1 for the loss.
        output = output[:, -1, :]
        opt.zero_grad()
        loss = criterion(output, Y)
        loss.backward()
        opt.step()
"""
# Julia handle to the Python function defined above.
train_pytorch = py"train_network"
"""
    main()

Benchmark entry point for the PyTorch model.

Seeds the global RNG, draws a random `1_000 × 20 × 100` `Float32` input and a
`1_000 × 1` target, and times a single call to `train_pytorch` with `@time`.
"""
function main()
    Random.seed!(72)
    # Draw order matters for reproducibility: inputs first, then targets.
    inputs = rand(Float32, 1_000, 20, 100)
    targets = rand(Float32, 1_000, 1)
    @time train_pytorch(inputs, targets)
end

main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment