#ifndef PLS_SCHEDULER_IMPL_H
#define PLS_SCHEDULER_IMPL_H

#include <atomic>
#include <memory>
#include <thread>
#include <utility>

#include "context_switcher/context_switcher.h"
#include "context_switcher/continuation.h"

#include "pls/internal/scheduling/task_manager.h"
#include "pls/internal/scheduling/base_task.h"
#include "base_task.h"

#include "pls/internal/profiling/dag_node.h"

namespace pls::internal::scheduling {

template <typename ALLOC>
scheduler::scheduler(unsigned int num_threads,
                     size_t computation_depth,
                     size_t stack_size,
                     bool reuse_thread,
                     size_t serial_stack_size,
                     ALLOC &&stack_allocator)
    : num_threads_{num_threads},
      reuse_thread_{reuse_thread},
      sync_barrier_{num_threads + 1 - reuse_thread},
      worker_threads_{},
      thread_states_{},
      main_thread_starter_function_{nullptr},
      work_section_done_{false},
      terminated_{false},
      stack_allocator_{std::make_shared<ALLOC>(std::forward<ALLOC>(stack_allocator))},
      serial_stack_size_{serial_stack_size}
#if PLS_PROFILING_ENABLED
    , profiler_{num_threads}
#endif
{
  worker_threads_.reserve(num_threads);
  task_managers_.reserve(num_threads);
  thread_states_.reserve(num_threads);

  std::atomic<unsigned int> num_spawned{0};
  for (unsigned int i = 0; i < num_threads_; i++) {
    auto &this_task_manager =
        task_managers_.emplace_back(std::make_unique<task_manager>(i, computation_depth, stack_size, stack_allocator_));
    auto &this_thread_state =
        thread_states_.emplace_back(std::make_unique<thread_state>(*this, i, *this_task_manager, stack_allocator_, serial_stack_size));

    if (reuse_thread && i == 0) {
      worker_threads_.emplace_back();
      num_spawned++;
      continue; // Skip over the first/main thread when re-using the user's thread, as it replaces the first worker.
    }

    auto *this_thread_state_pointer = this_thread_state.get();
    worker_threads_.emplace_back([this_thread_state_pointer, &num_spawned] {
      thread_state::set(this_thread_state_pointer);
      num_spawned++;
      work_thread_main_loop();
    });
  }

  // Wait until every worker has published its thread_state before returning from the constructor.
  while (num_spawned < num_threads) {
    std::this_thread::yield();
  }
}

class scheduler::init_function {
 public:
  virtual void run() = 0;
};

template <typename F>
class scheduler::init_function_impl : public init_function {
 public:
  explicit init_function_impl(F &function) : function_{function} {}

  void run() override {
    base_task *root_task = thread_state::get().get_active_task();
    root_task->attached_resources_.store(nullptr, std::memory_order_relaxed);

#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     root_task->profiling_node_);
    thread_state::get().get_scheduler().profiler_.task_prepare_stack_measure(thread_state::get().get_thread_id(),
                                                                             root_task->stack_memory_,
                                                                             root_task->stack_size_);
#endif

    root_task->run_as_task([root_task, this](auto cont) {
      root_task->is_synchronized_ = true;
      thread_state::get().main_continuation() = std::move(cont);
      function_();

      thread_state::get().get_scheduler().work_section_done_.store(true);
      PLS_ASSERT(thread_state::get().main_continuation().valid(), "Must return valid continuation from main task.");

#if PLS_SLEEP_WORKERS_ON_EMPTY
      thread_state::get().get_scheduler().empty_queue_counter_.store(-1000);
      thread_state::get().get_scheduler().empty_queue_decrease_counter_and_wake();
#endif

#if PLS_PROFILING_ENABLED
      thread_state::get().get_scheduler().profiler_.task_finish_stack_measure(thread_state::get().get_thread_id(),
                                                                              root_task->stack_memory_,
                                                                              root_task->stack_size_,
                                                                              root_task->profiling_node_);
      thread_state::get().get_scheduler().profiler_.task_stop_running(thread_state::get().get_thread_id(),
                                                                      root_task->profiling_node_);
#endif

      return std::move(thread_state::get().main_continuation());
    });
  }

 private:
  F &function_;
};

template <typename Function>
void
scheduler::perform_work(Function work_section) {
  // Prepare the main root task.
  init_function_impl<Function> starter_function{work_section};
  main_thread_starter_function_ = &starter_function;

#if PLS_PROFILING_ENABLED
  auto *root_task = thread_state_for(0).get_active_task();
  auto *root_node = profiler_.start_profiler_run();
  root_task->profiling_node_ = root_node;
#endif
#if PLS_SLEEP_WORKERS_ON_EMPTY
  empty_queue_counter_.store(0);
  for (auto &thread_state : thread_states_) {
    thread_state->get_queue_empty_flag().store({EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY});
  }
#endif

  work_section_done_ = false;
  if (reuse_thread_) {
    auto &my_state = thread_state_for(0);
    thread_state::set(&my_state); // Make THIS thread become the main worker.

    sync_barrier_.wait();       // Trigger the worker threads to wake up.
    work_thread_work_section(); // Also perform the work section on this (the main) thread.
    sync_barrier_.wait();       // Wait for the worker threads to finish.
  } else {
    // Only trigger the others to do the work; this thread sleeps/waits for the time being.
    sync_barrier_.wait(); // Trigger the worker threads to wake up.
    sync_barrier_.wait(); // Wait for the worker threads to finish.
  }

#if PLS_PROFILING_ENABLED
  profiler_.stop_profiler_run();
#endif
}

template <typename Function>
void scheduler::spawn_internal(Function &&lambda) {
  if (thread_state::is_scheduler_active()) {
    thread_state &spawning_state = thread_state::get();

    base_task *last_task = spawning_state.get_active_task();
    base_task *spawned_task = last_task->next_;

    // We are at the end of our allocated tasks, go over to serial execution.
    if (spawned_task == nullptr) {
      scheduler::serial_internal(std::forward<Function>(lambda));
      return;
    }
    // We are in a serial section.
    // For now we simply stay serial. Later versions could add nested parallel sections.
    if (last_task->is_serial_section_) {
      lambda();
      return;
    }

    // Carry over the resources allocated by parents.
    auto *attached_resources = last_task->attached_resources_.load(std::memory_order_relaxed);
    spawned_task->attached_resources_.store(attached_resources, std::memory_order_relaxed);

#if PLS_PROFILING_ENABLED
    spawning_state.get_scheduler().profiler_.task_prepare_stack_measure(spawning_state.get_thread_id(),
                                                                        spawned_task->stack_memory_,
                                                                        spawned_task->stack_size_);
    auto *child_dag_node = spawning_state.get_scheduler().profiler_.task_spawn_child(spawning_state.get_thread_id(),
                                                                                     last_task->profiling_node_);
    spawned_task->profiling_node_ = child_dag_node;
#endif

    auto continuation = spawned_task->run_as_task([last_task, spawned_task, lambda, &spawning_state](auto cont) {
      // Allow stealing threads to continue the last task.
      last_task->continuation_ = std::move(cont);

      // We are now executing the new task, allow others to steal the last task's continuation.
      spawned_task->is_synchronized_ = true;
      spawning_state.set_active_task(spawned_task);
      spawning_state.get_task_manager().push_local_task(last_task);

#if PLS_SLEEP_WORKERS_ON_EMPTY
      data_structures::stamped_integer
          queue_empty_flag = spawning_state.get_queue_empty_flag().load(std::memory_order_relaxed);
      switch (queue_empty_flag.value_) {
        case EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY: {
          // The queue was not found empty, ignore it.
          break;
        }
        case EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY: {
          // Someone tries to mark us empty and might be re-stealing right now.
          data_structures::stamped_integer
              queue_non_empty_flag{queue_empty_flag.stamp_++, EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY};
          auto actual_empty_flag =
              spawning_state.get_queue_empty_flag().exchange(queue_non_empty_flag, std::memory_order_acq_rel);
          if (actual_empty_flag.value_ == EMPTY_QUEUE_STATE::QUEUE_EMPTY) {
            spawning_state.get_scheduler().empty_queue_decrease_counter_and_wake();
          }
          break;
        }
        case EMPTY_QUEUE_STATE::QUEUE_EMPTY: {
          // Someone already marked the queue empty, we must revert their action on the central queue.
          data_structures::stamped_integer
              queue_non_empty_flag{queue_empty_flag.stamp_++, EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY};
          spawning_state.get_queue_empty_flag().store(queue_non_empty_flag, std::memory_order_release);
          spawning_state.get_scheduler().empty_queue_decrease_counter_and_wake();
          break;
        }
      }
#endif

      // Execute the lambda itself, which could lead to a different thread returning.
#if PLS_PROFILING_ENABLED
      spawning_state.get_scheduler().profiler_.task_stop_running(spawning_state.get_thread_id(),
                                                                 last_task->profiling_node_);
      spawning_state.get_scheduler().profiler_.task_start_running(spawning_state.get_thread_id(),
                                                                  spawned_task->profiling_node_);
#endif
      lambda();

      thread_state &syncing_state = thread_state::get();
      PLS_ASSERT(syncing_state.get_active_task() == spawned_task,
                 "Task manager must always point its active task onto what is executing.");

      // Try to pop a task off the syncing task manager.
      // Possible outcomes:
      // - this is a different task manager, it must have an empty deque and the pop fails
      // - this is the same task manager and someone stole the last task, thus the pop fails
      // - this is the same task manager and no one stole the last task, thus the pop succeeds
      base_task *popped_task = syncing_state.get_task_manager().pop_local_task();
      if (popped_task) {
        // Fast path, simply continue execution where we left off before the spawn.
        PLS_ASSERT(popped_task == last_task,
                   "Fast path, nothing can have changed until here.");
        PLS_ASSERT(&spawning_state == &syncing_state,
                   "Fast path, we must only return if the task has not been stolen/moved to another thread.");
        PLS_ASSERT(last_task->continuation_.valid(),
                   "Fast path, no one can have continued working on the last task.");

        syncing_state.set_active_task(last_task);
#if PLS_PROFILING_ENABLED
        syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                          spawned_task->stack_memory_,
                                                                          spawned_task->stack_size_,
                                                                          spawned_task->profiling_node_);
        syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                  spawned_task->profiling_node_);
#endif
        return std::move(last_task->continuation_);
      } else {
        // Slow path, the last task was stolen. This path is common to sync() events.
#if PLS_PROFILING_ENABLED
        syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                          spawned_task->stack_memory_,
                                                                          spawned_task->stack_size_,
                                                                          spawned_task->profiling_node_);
        syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                  spawned_task->profiling_node_);
#endif
        auto continuation = slow_return(syncing_state, false);
        return continuation;
      }
    });

#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     last_task->profiling_node_);
#endif

    if (continuation.valid()) {
      // We jumped in here from the main loop, keep track!
      thread_state::get().main_continuation() = std::move(continuation);
    }
  } else {
    // Scheduler not active, simply run the lambda sequentially.
    lambda();
  }
}

template <typename Function>
void scheduler::spawn_and_sync_internal(Function &&lambda) {
  thread_state &spawning_state = thread_state::get();

  base_task *last_task = spawning_state.get_active_task();
  base_task *spawned_task = last_task->next_;

  // We are at the end of our allocated tasks, go over to serial execution.
  if (spawned_task == nullptr) {
    scheduler::serial_internal(std::forward<Function>(lambda));
    return;
  }
  // We are in a serial section.
  // For now we simply stay serial. Later versions could add nested parallel sections.
  if (last_task->is_serial_section_) {
    lambda();
    return;
  }

  // Carry over the resources allocated by parents.
  auto *attached_resources = last_task->attached_resources_.load(std::memory_order_relaxed);
  spawned_task->attached_resources_.store(attached_resources, std::memory_order_relaxed);

#if PLS_PROFILING_ENABLED
  spawning_state.get_scheduler().profiler_.task_prepare_stack_measure(spawning_state.get_thread_id(),
                                                                      spawned_task->stack_memory_,
                                                                      spawned_task->stack_size_);
  auto *child_dag_node = spawning_state.get_scheduler().profiler_.task_spawn_child(spawning_state.get_thread_id(),
                                                                                   last_task->profiling_node_);
  spawned_task->profiling_node_ = child_dag_node;
#endif

  auto continuation = spawned_task->run_as_task([last_task, spawned_task, lambda, &spawning_state](auto cont) {
    // Allow stealing threads to continue the last task.
    last_task->continuation_ = std::move(cont);

    // We are now executing the new task, allow others to steal the last task's continuation.
    spawned_task->is_synchronized_ = true;
    spawning_state.set_active_task(spawned_task);

    // Execute the lambda itself, which could lead to a different thread returning.
#if PLS_PROFILING_ENABLED
    spawning_state.get_scheduler().profiler_.task_finish_stack_measure(spawning_state.get_thread_id(),
                                                                       last_task->stack_memory_,
                                                                       last_task->stack_size_,
                                                                       last_task->profiling_node_);
    spawning_state.get_scheduler().profiler_.task_stop_running(spawning_state.get_thread_id(),
                                                               last_task->profiling_node_);
    auto *next_dag_node =
        spawning_state.get_scheduler().profiler_.task_sync(spawning_state.get_thread_id(), last_task->profiling_node_);
    last_task->profiling_node_ = next_dag_node;
    spawning_state.get_scheduler().profiler_.task_start_running(spawning_state.get_thread_id(),
                                                                spawned_task->profiling_node_);
#endif
    lambda();

    thread_state &syncing_state = thread_state::get();
    PLS_ASSERT(syncing_state.get_active_task() == spawned_task,
               "Task manager must always point its active task onto what is executing.");

    if (&syncing_state == &spawning_state && last_task->is_synchronized_) {
      // Fast path, simply continue execution where we left off before the spawn.
      PLS_ASSERT(last_task->continuation_.valid(),
                 "Fast path, no one can have continued working on the last task.");

      syncing_state.set_active_task(last_task);
#if PLS_PROFILING_ENABLED
      syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                        spawned_task->stack_memory_,
                                                                        spawned_task->stack_size_,
                                                                        spawned_task->profiling_node_);
      syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                spawned_task->profiling_node_);
#endif
      return std::move(last_task->continuation_);
    } else {
      // Slow path, the last task was stolen. This path is common to sync() events.
#if PLS_PROFILING_ENABLED
      syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                        spawned_task->stack_memory_,
                                                                        spawned_task->stack_size_,
                                                                        spawned_task->profiling_node_);
      syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                spawned_task->profiling_node_);
#endif
      auto continuation = slow_return(syncing_state, false);
      return continuation;
    }
  });

#if PLS_PROFILING_ENABLED
  thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                   last_task->profiling_node_);
#endif

  PLS_ASSERT(!continuation.valid(), "We must not jump to a not-published (synced) spawn.");
}

template <typename Function>
void scheduler::serial_internal(Function &&lambda) {
  thread_state &spawning_state = thread_state::get();
  base_task *active_task = spawning_state.get_active_task();

  if (active_task->is_serial_section_) {
    lambda();
  } else {
    active_task->is_serial_section_ = true;
    context_switcher::enter_context(spawning_state.get_serial_call_stack(),
                                    spawning_state.get_serial_call_stack_size(),
                                    [&](auto cont) {
                                      lambda();
                                      return cont;
                                    });
    active_task->is_serial_section_ = false;
  }
}

} // namespace pls::internal::scheduling

#endif // PLS_SCHEDULER_IMPL_H