#include "pls/internal/scheduling/scheduler.h" #include "pls/internal/scheduling/thread_state.h" #include "pls/internal/scheduling/task.h" #include "pls/internal/base/error_handling.h" namespace pls { namespace internal { namespace scheduling { scheduler::scheduler(scheduler_memory &memory, const unsigned int num_threads, bool reuse_thread) : num_threads_{num_threads}, reuse_thread_{reuse_thread}, memory_{memory}, sync_barrier_{num_threads + 1 - reuse_thread}, terminated_{false} { if (num_threads_ > memory.max_threads()) { PLS_ERROR("Tried to create scheduler with more OS threads than pre-allocated memory."); } for (unsigned int i = 0; i < num_threads_; i++) { // Placement new is required, as the memory of `memory_` is not required to be initialized. memory.thread_state_for(i).scheduler_ = this; memory.thread_state_for(i).id_ = i; if (reuse_thread && i == 0) { continue; // Skip over first/main thread when re-using the users thread, as this one will replace the first one. } memory.thread_for(i) = base::thread(&scheduler::work_thread_main_loop, &memory_.thread_state_for(i)); } } scheduler::~scheduler() { terminate(); } void scheduler::work_thread_main_loop() { auto &scheduler = thread_state::get().get_scheduler(); while (true) { // Wait to be triggered scheduler.sync_barrier_.wait(); // Check for shutdown if (scheduler.terminated_) { return; } scheduler.work_thread_work_section(); // Sync back with main thread scheduler.sync_barrier_.wait(); } } void scheduler::work_thread_work_section() { auto &my_state = thread_state::get(); my_state.reset(); auto &my_cont_manager = my_state.get_cont_manager(); auto const num_threads = my_state.get_scheduler().num_threads(); auto const my_id = my_state.get_id(); if (my_state.get_id() == 0) { // Main Thread, kick off by executing the user's main code block. main_thread_starter_function_->run(); } do { // Work off pending continuations we need to execute locally while (my_cont_manager.falling_through()) { my_cont_manager.execute_fall_through_code(); } // Steal Routine (will be continuously executed when there are no more fall through's). // TODO: move into separate function const size_t offset = my_state.random_() % num_threads; const size_t max_tries = num_threads - 1; for (size_t i = 0; i < max_tries; i++) { size_t target = (offset + i) % num_threads; // Skip our self for stealing target = ((target == my_id) + target) % num_threads; auto &target_state = my_state.get_scheduler().thread_state_for(target); PLS_ASSERT(my_cont_manager.is_clean(), "Only steal with clean chain!"); auto *stolen_task = target_state.get_task_manager().steal_remote_task(my_cont_manager); if (stolen_task != nullptr) { my_state.parent_cont_ = stolen_task->get_cont(); my_state.right_spawn_ = true; stolen_task->execute(); if (my_cont_manager.falling_through()) { break; } else { my_cont_manager.fall_through_and_notify_cont(stolen_task->get_cont(), true); break; } } } } while (!work_section_done_); PLS_ASSERT(my_cont_manager.is_clean(), "Only finish work section with clean chain!"); } void scheduler::terminate() { if (terminated_) { return; } terminated_ = true; sync_barrier_.wait(); for (unsigned int i = 0; i < num_threads_; i++) { if (reuse_thread_ && i == 0) { continue; } memory_.thread_for(i).join(); } } thread_state &scheduler::thread_state_for(size_t id) { return memory_.thread_state_for(id); } } } }