Created
August 9, 2023 22:39
-
-
Save pranavsharma/c3275863291b20b538cf0cb3265ef069 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
Program to demonstrate using one session and multiple threads to call Run on that session.
Build (adjust the ONNX Runtime include/lib/rpath paths to your install):
g++ -std=c++17 -o test_ort_one_session_multiple_threads test_ort_one_session_multiple_threads.cc -I onnxruntime-linux-x64-1.15.1/include/ -lonnxruntime -Lonnxruntime-linux-x64-1.15.1/lib/ -lpthread -Wl,-rpath,/home/pranav/onnxruntime-linux-x64-1.15.1/lib/
Author: GitHub id: pranavsharma
*/
#include <atomic>
#include <cstdint>
#include <exception>
#include <functional>
#include <iostream>
#include <mutex>
#include <numeric>
#include <string>
#include <thread>
#include <vector>

#include <onnxruntime_cxx_api.h>
int main() | |
{ | |
Ort::Env env; | |
const char *model_path = "mul_1.onnx"; | |
Ort::Session sess{env, model_path, Ort::SessionOptions{nullptr}}; // Note session is created only once for both the threads | |
Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtDeviceAllocator, OrtMemType::OrtMemTypeDefault); | |
std::vector<int64_t> shape{3, 2}; | |
int num_inputs = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); | |
std::vector<const char *> input_names{"X"}; | |
std::vector<const char *> output_names{"Y"}; | |
std::thread t1([&]() | |
{ | |
auto id = std::this_thread::get_id(); | |
std::cout << "Inside thread " << id << "\n"; | |
std::vector<float> input(num_inputs, 3.0); | |
auto ort_value = Ort::Value::CreateTensor(mem_info, input.data(), input.size(), shape.data(), shape.size()); | |
auto output = sess.Run(Ort::RunOptions{nullptr}, input_names.data(), &ort_value, input_names.size(), output_names.data(), output_names.size()); }); | |
std::thread t2([&]() | |
{ | |
auto id = std::this_thread::get_id(); | |
std::cout << "Inside thread " << id << "\n"; | |
std::vector<float> input(num_inputs, 6.0); | |
auto ort_value = Ort::Value::CreateTensor(mem_info, input.data(), input.size(), shape.data(), shape.size()); | |
auto output = sess.Run(Ort::RunOptions{nullptr}, input_names.data(), &ort_value, input_names.size(), output_names.data(), output_names.size()); }); | |
t1.join(); | |
t2.join(); | |
std::cout << "End of program\n"; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
It will crash when using the DML EP or the CUDA EP.