@bartvm
Created July 16, 2017 19:27
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index f1e09b0e..c50f507e 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -21,6 +21,10 @@
#include <THC/THC.h>
#endif
+void tid() {
+ printf("%d ", (int)std::hash<std::thread::id>()(std::this_thread::get_id()));
+}
+
using thpp::Tensor;
namespace torch { namespace autograd {
@@ -89,14 +93,21 @@ auto ReadyQueue::push_front(FunctionTask item) -> void {
{
std::lock_guard<std::mutex> lock(mutex);
++item.base->outstanding_tasks;
+ tid();
+ printf("Pushed task onto queue, %llu outstanding\n", item.base->outstanding_tasks.load());
queue.push_front(std::move(item));
}
not_empty.notify_one();
}
auto ReadyQueue::pop_back() -> FunctionTask {
+ tid();
+ printf("Getting lock\n");
std::unique_lock<std::mutex> lock(mutex);
- not_empty.wait(lock, [this]{ return !queue.empty(); });
+ printf("Waiting for a task\n");
+ if (queue.empty()) {
+ not_empty.wait(lock, [this]{ return !queue.empty(); });
+ }
auto task = std::move(queue.back()); queue.pop_back();
return task;
}
@@ -110,20 +121,35 @@ Engine::~Engine() = default;
auto Engine::thread_main(std::shared_ptr<ReadyQueue> queue, int device) -> void {
THInferNumThreads();
AutoGPU guard(device);
+ tid();
+ printf("Starting endless loop main thread\n");
while (1) {
+ tid();
+ printf("Trying to get next task\n");
FunctionTask task = queue->pop_back();
+ tid();
+ printf("Got a task\n");
if (!task.base->has_error.load()) {
try {
+ tid();
+ printf("About to evaluate function, %llu outstanding\n", task.base->outstanding_tasks.load());
evaluate_function(task);
} catch (std::exception& e) {
thread_on_exception(task, e);
}
}
+ tid();
+ printf("Evaluated function, %llu outstanding\n", task.base->outstanding_tasks.load() - 1);
if (--task.base->outstanding_tasks == 0) {
std::lock_guard<std::mutex> lock(task.base->mutex);
- task.base->not_done.notify_all();
+ task.base->not_done.notify_one();
+ tid();
+ printf("Breaking free!\n");
+ break;
}
}
+ tid();
+ printf("Ending main thread\n");
}
auto Engine::thread_on_exception(FunctionTask& task, std::exception& e) -> void {
@@ -299,7 +325,9 @@ auto Engine::execute(const function_list& input_roots,
variable_list& inputs,
bool keep_graph,
const callback_map& callbacks) -> void {
- std::call_once(start_threads_flag, &Engine::start_threads, this);
+ tid();
+ printf("Engine starting threads\n");
+ start_threads();
// Callbacks are only valid for the duration of this run and should always be cleared
ClearCallbacks _cb_guard(post_callbacks, post_callbacks_lock);
@@ -310,6 +338,8 @@ auto Engine::execute(const function_list& input_roots,
function_queue roots;
for (auto entry : input_roots) {
if (entry.first->is_executable) {
+ tid();
+ printf("Pushed first task to queue\n");
graph_task.has_any_work = true;
roots.push_back(graph_root.get());
ready_queue(-1).push_front(FunctionTask(&graph_task, graph_root, InputBuffer(0)));
@@ -329,9 +359,15 @@ auto Engine::execute(const function_list& input_roots,
compute_dependencies(std::move(roots), graph_task);
// Wait for all tasks to complete
- graph_task.not_done.wait(lock, [&graph_task]{
- return graph_task.outstanding_tasks.load() == 0;
- });
+ tid();
+ printf("Waiting for graph to complete!\n");
+ if (graph_task.outstanding_tasks.load() != 0) {
+ graph_task.not_done.wait(lock, [&graph_task]{
+ return graph_task.outstanding_tasks.load() == 0;
+ });
+ }
+ tid();
+ printf("Done waiting\n");
// Check for an exception while running backwards
if (graph_task.has_error.load()) {
@@ -372,6 +408,8 @@ auto Engine::start_threads() -> void {
}
#endif
int num_threads = num_devices + 1;
+ tid();
+ printf("Starting %d threads\n", num_threads);
ready_queues = std::vector<std::shared_ptr<ReadyQueue>>(num_threads);
for (int i = 0; i < num_threads; ++i) {
auto& queue = ready_queues[i];
diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h
index a0308f7d..66193b96 100644
--- a/torch/csrc/autograd/engine.h
+++ b/torch/csrc/autograd/engine.h
@@ -55,7 +55,6 @@ protected:
virtual void thread_main(std::shared_ptr<ReadyQueue> queue, int device);
virtual void thread_on_exception(FunctionTask& task, std::exception& e);
- std::once_flag start_threads_flag;
std::vector<std::shared_ptr<ReadyQueue>> ready_queues;
std::vector<std::function<void()>> post_callbacks;
std::mutex post_callbacks_lock;
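
Note: the tid() helper added to engine.cpp relies on std::hash, std::thread::id and printf, so it assumes <cstdio>, <functional> and <thread> are already visible in that translation unit; the patch itself does not add those includes. For trying the same tracing pattern outside the PyTorch tree, a minimal standalone sketch is below. The ToyQueue type, its int "task" payload and main() are illustrative stand-ins, not the real ReadyQueue or FunctionTask from the engine.

#include <cstdio>
#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>

// Print a hashed id for the calling thread, as in the patch above.
static void tid() {
  printf("%d ", (int)std::hash<std::thread::id>()(std::this_thread::get_id()));
}

// Toy queue mirroring ReadyQueue's push_front/pop_back locking pattern.
struct ToyQueue {
  std::mutex mutex;
  std::condition_variable not_empty;
  std::deque<int> queue;

  void push_front(int task) {
    {
      std::lock_guard<std::mutex> lock(mutex);
      queue.push_front(task);
      tid();
      printf("Pushed task %d onto queue\n", task);
    }
    not_empty.notify_one();
  }

  int pop_back() {
    std::unique_lock<std::mutex> lock(mutex);
    tid();
    printf("Waiting for a task\n");
    // The predicate form already returns immediately when the queue is
    // non-empty and handles spurious wakeups, so no separate empty() check
    // is needed.
    not_empty.wait(lock, [this] { return !queue.empty(); });
    int task = queue.back();
    queue.pop_back();
    return task;
  }
};

int main() {
  ToyQueue q;
  std::thread worker([&q] {
    int task = q.pop_back();
    tid();
    printf("Got task %d\n", task);
  });
  q.push_front(42);
  worker.join();
  return 0;
}

Compiling with something like g++ -std=c++11 -pthread should print a hashed thread-id prefix before each message, in the same style as the instrumented engine.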