Skip to content

Instantly share code, notes, and snippets.

@bartvm
Last active July 15, 2017 23:26
Show Gist options
  • Save bartvm/f87965f902a17c3a9e80b5bfafa3fc97 to your computer and use it in GitHub Desktop.
Save bartvm/f87965f902a17c3a9e80b5bfafa3fc97 to your computer and use it in GitHub Desktop.
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index f1e09b0e..d78c03e8 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -110,7 +110,7 @@ Engine::~Engine() = default;
auto Engine::thread_main(std::shared_ptr<ReadyQueue> queue, int device) -> void {
THInferNumThreads();
AutoGPU guard(device);
- while (1) {
+ while (!exit.back().load()) {
FunctionTask task = queue->pop_back();
if (!task.base->has_error.load()) {
try {
@@ -124,6 +124,7 @@ auto Engine::thread_main(std::shared_ptr<ReadyQueue> queue, int device) -> void
task.base->not_done.notify_all();
}
}
+ exit.pop_back();
}
auto Engine::thread_on_exception(FunctionTask& task, std::exception& e) -> void {
@@ -299,7 +300,7 @@ auto Engine::execute(const function_list& input_roots,
variable_list& inputs,
bool keep_graph,
const callback_map& callbacks) -> void {
- std::call_once(start_threads_flag, &Engine::start_threads, this);
+ start_threads();
// Callbacks are only valid for the duration of this run and should always be cleared
ClearCallbacks _cb_guard(post_callbacks, post_callbacks_lock);
@@ -351,6 +352,8 @@ auto Engine::execute(const function_list& input_roots,
post_callbacks[i]();
cb_lock.lock();
}
+
+ exit.back().store(true);
}
void Engine::queue_callback(std::function<void()> callback) {
@@ -371,6 +374,7 @@ auto Engine::start_threads() -> void {
num_devices = 0;
}
#endif
+ exit.emplace_back(false);
int num_threads = num_devices + 1;
ready_queues = std::vector<std::shared_ptr<ReadyQueue>>(num_threads);
for (int i = 0; i < num_threads; ++i) {
diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h
index a0308f7d..1722dd92 100644
--- a/torch/csrc/autograd/engine.h
+++ b/torch/csrc/autograd/engine.h
@@ -4,6 +4,7 @@
// to "root" variables (variables created by the user with requires_grad=True).
#include <Python.h>
+#include <atomic>
#include <deque>
#include <memory>
#include <unordered_map>
@@ -55,8 +56,8 @@ protected:
virtual void thread_main(std::shared_ptr<ReadyQueue> queue, int device);
virtual void thread_on_exception(FunctionTask& task, std::exception& e);
- std::once_flag start_threads_flag;
std::vector<std::shared_ptr<ReadyQueue>> ready_queues;
+ std::deque<std::atomic_bool> exit;
std::vector<std::function<void()>> post_callbacks;
std::mutex post_callbacks_lock;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment