-
-
Save KeitaTakenouchi/9b063231ee893ff6e3be1a6dec3394c6 to your computer and use it in GitHub Desktop.
Performance comparison: ordinary Python model vs. TorchScript from Python vs. TorchScript from Java vs. C++ TorchScript vs. C++ LibTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <memory>

#include <torch/torch.h>
struct MyNN : torch::nn::Module | |
{ | |
MyNN() | |
{ | |
fc1 = register_module("fc1", torch::nn::Linear(2, 8)); | |
fc2 = register_module("fc2", torch::nn::Linear(8, 8)); | |
fc3 = register_module("fc3", torch::nn::Linear(8, 1)); | |
} | |
torch::Tensor forward(torch::Tensor x) | |
{ | |
// Use one of many tensor manipulation functions. | |
x = torch::relu(fc1->forward(x)); | |
x = torch::relu(fc2->forward(x)); | |
x = fc3->forward(x); | |
return x; | |
} | |
// Use one of many "standard library" modules. | |
torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr}; | |
}; | |
// for training | |
int main() | |
{ | |
const int size = 1000; | |
const int batch_size = 100; | |
torch::Tensor input_values = torch::ones({size, 2}); | |
torch::Tensor output_values = torch::ones(size); | |
for (size_t i = 0; i < size; i++) | |
{ | |
float x1 = -10 + 20 * (rand() / float(RAND_MAX)); | |
float x2 = -10 + 20 * (rand() / float(RAND_MAX)); | |
input_values[i][0] = x1; | |
input_values[i][1] = x2; | |
output_values[i] = abs(x1) + x2; | |
} | |
at::Tensor X = input_values.reshape({size / batch_size, batch_size, 2}); | |
at::Tensor Y = output_values.reshape({size / batch_size, batch_size}); | |
auto model = std::make_shared<MyNN>(); | |
torch::optim::SGD optimizer(model->parameters(), 0.001); | |
int n_epochs = 100; | |
for (size_t epoch = 1; epoch <= n_epochs; epoch++) | |
{ | |
for (size_t i = 0; i < size / batch_size; i++) | |
{ | |
torch::Tensor x = X[i]; | |
torch::Tensor y = Y[i]; | |
optimizer.zero_grad(); | |
torch::Tensor predicted = model->forward(x).reshape({batch_size}); | |
torch::Tensor loss = torch::mse_loss(predicted, y); | |
loss.backward(); | |
optimizer.step(); | |
} | |
} | |
torch::save(model, "../model_cpp.pt"); | |
} | |
// for inference | |
int main() | |
{ | |
c10::InferenceMode guard(true); | |
auto model = std::make_shared<MyNN>(); | |
torch::load(model, "../model_cpp.pt"); | |
const int size = 1000000; | |
torch::Tensor input_values = torch::ones({size, 2}); | |
for (size_t i = 0; i < size; i++) | |
{ | |
float x1 = -10 + 20 * (rand() / float(RAND_MAX)); | |
float x2 = -10 + 20 * (rand() / float(RAND_MAX)); | |
input_values[i][0] = x1; | |
input_values[i][1] = x2; | |
} | |
std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); | |
for (size_t i = 0; i < size; i++) | |
{ | |
model->forward(input_values[i]); | |
} | |
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); | |
std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - begin).count() << "[ms]" << std::endl; | |
// -> 9 seconds. This is very fast! | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.example; | |
import org.pytorch.IValue; | |
import org.pytorch.Module; | |
import org.pytorch.Tensor; | |
import java.util.Random; | |
public class Main { | |
public static void main(String[] args) { | |
Module mod = Module.load("../../model_torchscript.pt"); | |
Random r = new Random(42); | |
int size = 1000000; | |
float[][] data = new float[size][2]; | |
for (int i = 0; i < size; i++) { | |
data[i][0] = -10 + 20 * r.nextFloat(); | |
data[i][1] = -10 + 20 * r.nextFloat(); | |
} | |
long start = System.currentTimeMillis(); | |
for (int i = 0; i < size; i++) { | |
Tensor inputs = Tensor.fromBlob( | |
data[i], | |
new long[]{2} // shape | |
); | |
mod.forward(IValue.from(inputs)); | |
} | |
long elapsed = (System.currentTimeMillis() - start); | |
System.out.println(elapsed); // -> 36 seconds. Faster! | |
System.exit(0); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
import random | |
# Define my model | |
class MyNN(torch.nn.Module):
    """Fully connected network 2 -> 8 -> 8 -> 1 with ReLU after the first two layers."""

    def __init__(self):
        super().__init__()
        # Attribute names double as state-dict key prefixes ("a", "b", "c").
        self.a = nn.Linear(2, 8)
        self.b = nn.Linear(8, 8)
        self.c = nn.Linear(8, 1)

    def forward(self, x):
        """Apply a -> ReLU -> b -> ReLU -> c and return the last layer's output."""
        hidden = torch.relu(self.a(x))
        hidden = torch.relu(self.b(hidden))
        return self.c(hidden)
def abs(x):
    """Return x unchanged when x >= 0, otherwise its negation."""
    return x if x >= 0 else -x
# code for training
import time  # used by the timing code below; not in the file's import block

device = torch.device('cpu')
size = 1000
batch_size = 100

# Training set: x1 and x2 are each uniform in [-10, 10]; the target is |x1| + x2.
input_values = [
    [random.uniform(-10, 10), random.uniform(-10, 10)]
    for _ in range(size)
]
outputs_values = [[abs(x[0]) + x[1]] for x in input_values]

# Split into mini-batches of batch_size rows.
X = torch.tensor(input_values).to(device).split(batch_size)
Y = torch.tensor(outputs_values).to(device).split(batch_size)

model = MyNN().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
n_epochs = 100

model.train()
for _ in range(n_epochs):
    for x, y in zip(X, Y):
        optimizer.zero_grad()
        predicted = model(x)
        loss = loss_fn(predicted, y)
        loss.backward()
        optimizer.step()

# save the model (by a traditional way)
torch.save(model.state_dict(), "model_py.pt")

# save as TorchScript (traced with a single 2-element example input)
example = torch.tensor([random.uniform(-10, 10), random.uniform(-10, 10)])
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("../model_torchscript.pt")

# load the (traditional) model and measure inference time
size = 1_000_000
input_values = [
    [random.uniform(-10, 10), random.uniform(-10, 10)]
    for _ in range(size)
]
X = torch.tensor(input_values)

model = MyNN()
model.load_state_dict(torch.load("model_py.pt"))

# perf_counter is a monotonic clock, so the interval is not affected by
# system clock adjustments during the run.
start = time.perf_counter()
for x in X:
    _ = model(x)
print(time.perf_counter() - start)  # -> 119 seconds in my environment

# load TorchScript model and measure inference time
model = torch.jit.load("../model_torchscript.pt")

start = time.perf_counter()
for x in X:
    _ = model(x)
print(time.perf_counter() - start)  # -> 51 seconds. Faster!
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <torch/script.h> | |
#include <chrono> | |
#include <iostream> | |
#include <memory> | |
#include <chrono> | |
int main(int argc, const char *argv[]) | |
{ | |
if (argc != 2) | |
{ | |
std::cerr << "usage: example-app <path-to-exported-script-module>\n"; | |
return -1; | |
} | |
torch::jit::script::Module model; | |
try | |
{ | |
// Deserialize the ScriptModule from a file using torch::jit::load(). | |
model = torch::jit::load(argv[1]); | |
} | |
catch (const c10::Error &e) | |
{ | |
return -1; | |
} | |
int size = 1000000; | |
float data[size][2]; | |
for (size_t i = 0; i < size; i++) | |
{ | |
data[i][0] = -10 + 20 * (rand() / float(RAND_MAX)); | |
data[i][1] = -10 + 20 * (rand() / float(RAND_MAX)); | |
} | |
std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); | |
for (size_t i = 0; i < size; i++) | |
{ | |
int64_t shape[] = {2}; | |
std::vector<torch::jit::IValue> inputs{torch::from_blob(data[i], shape)}; | |
model.forward(inputs); | |
} | |
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); | |
std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - begin).count() << "[ms]" << std::endl; | |
// -> 44 seconds. This is not so fast. | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment