-
-
Save JLDC/66c8eed6c36cb9cda85ff6404284d841 to your computer and use it in GitHub Desktop.
Flux vs. PyTorch computation times
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Pkg | |
Pkg.activate(".") | |
using BenchmarkTools | |
using Flux | |
using StatsBase | |
using Random | |
using MKL | |
"""
    julia_nn(lstm, X, Y, opt, θ, epochs=100)

Train `lstm` in place for `epochs` epochs using the implicit parameters `θ`
and the optimiser `opt`.

Each epoch does the following:
1. Resets the recurrent state with `Flux.reset!`.
2. Feeds every element of `X` except the last through `lstm`, discarding the
   outputs (warm-up of the hidden state).
3. Takes the gradient of the mean-squared error between `lstm(X[end])` and `Y`.
4. Applies one `Flux.update!` step.

Returns `nothing`.
"""
function julia_nn(lstm, X, Y, opt, θ, epochs=100)
    # Loss closure over the fixed data; the same closure is reused every epoch.
    function last_step_mse()
        history = X[1:end-1]
        [lstm(xt) for xt ∈ history]              # Warm up: advance state, drop outputs
        return Flux.Losses.mse(lstm(X[end]), Y)  # MSE on last item only
    end

    for _ ∈ 1:epochs
        Flux.reset!(lstm)
        grads = gradient(last_step_mse, θ)
        Flux.update!(opt, θ, grads)
    end
    return nothing
end
"""
    main()

Benchmark `julia_nn` for a stacked `Chain(LSTM, LSTM, Dense)` model on random
`Float32` data, and print the `@time` of a 100-epoch training run.

`X` is a vector of 20 matrices of size 100×1_000. The first dimension matches
the input size of `LSTM(100 => 128)`; presumably the 1_000 columns are the
batch. `Y` is a 1×1_000 target.
"""
function main()
    Random.seed!(72)
    X = [rand(Float32, 100, 1_000) for _ ∈ 1:20]
    Y = rand(Float32, 1, 1_000)
    lstm = Chain(
        LSTM(100 => 128),
        LSTM(128 => 128),
        Dense(128 => 1),
    )
    opt = ADAM()
    θ = Flux.params(lstm)

    # Run one epoch outside of the timing, through `julia_nn` itself. This
    # compiles exactly the methods and the closure the timed call uses. A
    # gradient closure written inline here would be a different closure type,
    # so part of the compilation would still land inside `@time`.
    julia_nn(lstm, X, Y, opt, θ, 1)
    Flux.reset!(lstm)

    @time julia_nn(lstm, X, Y, opt, θ)
end

main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Pkg | |
Pkg.activate(".") | |
using BenchmarkTools | |
using Flux | |
using StatsBase | |
using Random | |
using MKL | |
"""
    Seq2One(rnn, fc)

Sequence-to-one model container.

- `rnn`: a recurrent layer (or chain of them) that the input sequence is fed
  through element by element.
- `fc`: an output head applied to the recurrent output for the final element.

The type parameters make both field types concrete. Calls to `rnn` and `fc`
can then be specialized by the compiler instead of going through `Any`-typed
fields. Construction is unchanged: `Seq2One(rnn, fc)` infers `R` and `F`.
"""
struct Seq2One{R, F}
    rnn::R
    fc::F
end
Flux.@functor Seq2One

"""
    (m::Seq2One)(X)

Run the sequence `X` through `m.rnn` one element at a time. Return `m.fc`
applied to the recurrent output for the final element `X[end]`.

Outputs for the earlier elements are computed only to advance the recurrent
state and are discarded.
"""
function (m::Seq2One)(X)
    # Kept as a comprehension: Flux's docs advise against map/broadcast
    # over stateful recurrent layers.
    _ = [m.rnn(xt) for xt ∈ X[1:end-1]]
    last_hidden = m.rnn(X[end])
    return m.fc(last_hidden)
end
"""
    julia_nn(lstm, X, Y, opt, θ, epochs=100)

Train `lstm` in place for `epochs` epochs.

Each epoch does the following:
1. Resets the recurrent state.
2. Takes the gradient of `Flux.Losses.mse(lstm(X), Y)` with respect to the
   implicit parameters `θ`.
3. Applies one step of optimiser `opt`.

Returns `nothing`.
"""
function julia_nn(lstm, X, Y, opt, θ, epochs=100)
    # With a Seq2One model, lstm(X) is the prediction for the last element only.
    objective = () -> Flux.Losses.mse(lstm(X), Y)
    for _ ∈ 1:epochs
        Flux.reset!(lstm)
        Flux.update!(opt, θ, gradient(objective, θ))
    end
    return nothing
end
"""
    main()

Benchmark `julia_nn` for a `Seq2One` model (two stacked LSTMs plus a `Dense`
head applied to the last step) on random `Float32` data. Prints the `@time`
of a 100-epoch training run.

`X` is a vector of 20 matrices of size 100×1_000. The first dimension matches
the input size of `LSTM(100 => 128)`; presumably the 1_000 columns are the
batch. `Y` is a 1×1_000 target.
"""
function main()
    Random.seed!(72)
    X = [rand(Float32, 100, 1_000) for _ ∈ 1:20]
    Y = rand(Float32, 1, 1_000)
    lstm = Seq2One(
        Chain(LSTM(100 => 128), LSTM(128 => 128)),
        Dense(128 => 1),
    )
    opt = ADAM()
    θ = Flux.params(lstm)

    # Run one epoch outside of the timing, through `julia_nn` itself. This
    # compiles exactly the methods and the closure the timed call uses. A
    # gradient closure written inline here would be a different closure type,
    # so part of the compilation would still land inside `@time`.
    julia_nn(lstm, X, Y, opt, θ, 1)
    Flux.reset!(lstm)

    @time julia_nn(lstm, X, Y, opt, θ)
end

main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Pkg | |
Pkg.activate(".") | |
using BenchmarkTools | |
using StatsBase | |
using Random | |
using PyCall | |
py"""
import torch
import torch.nn as nn


class LSTM(nn.Module):
    # Two stacked LSTM layers followed by a linear head applied at every step.
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()  # Base class constructor
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2,
                            batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # batch_first=True: x is (batch, seq, input_size).
        h = self.lstm(x)[0]  # top-layer outputs for every time step
        return self.linear(h)


def train_network(X, Y, epochs=100):
    # For GPU: torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    dev = torch.device("cpu")
    lstm = LSTM(100, 128, 1).to(dev)
    # PyCall passes Julia's column-major arrays to NumPy without copying, so
    # from_numpy yields Fortran-ordered (non-contiguous) tensors. Make them
    # contiguous once here rather than paying for the strided layout on every
    # forward pass inside the timed loop. Values are unchanged.
    X = torch.from_numpy(X).to(dev).contiguous()
    Y = torch.from_numpy(Y).to(dev).contiguous()
    lstm.train()
    criterion = nn.MSELoss()
    opt = torch.optim.Adam(lstm.parameters())
    for _ in range(epochs):
        output = lstm(X)[:, -1, :]  # prediction at the last time step only
        opt.zero_grad()
        loss = criterion(output, Y)
        loss.backward()
        opt.step()
"""
train_pytorch = py"train_network"
"""
    main()

Time one call of the PyTorch `train_network` (bound as `train_pytorch`) on
random `Float32` data.

The 3-D input is laid out as (samples, time steps, features). This matches the
`batch_first=True`, `input_size=100` LSTM on the Python side.
"""
function main()
    Random.seed!(72)
    n_samples, n_steps, n_features = 1_000, 20, 100
    X = rand(Float32, n_samples, n_steps, n_features)
    Y = rand(Float32, n_samples, 1)
    return @time train_pytorch(X, Y)
end

main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment