#ifndef PLS_SCHEDULER_IMPL_H #define PLS_SCHEDULER_IMPL_H #include #include "context_switcher/context_switcher.h" #include "context_switcher/continuation.h" #include "pls/internal/scheduling/task_manager.h" #include "pls/internal/scheduling/base_task.h" #include "base_task.h" #include "pls/internal/profiling/dag_node.h" namespace pls::internal::scheduling { template scheduler::scheduler(unsigned int num_threads, size_t computation_depth, size_t stack_size, bool reuse_thread, ALLOC &&stack_allocator) : num_threads_{num_threads}, reuse_thread_{reuse_thread}, sync_barrier_{num_threads + 1 - reuse_thread}, worker_threads_{}, thread_states_{}, main_thread_starter_function_{nullptr}, work_section_done_{false}, terminated_{false}, stack_allocator_{std::make_shared(std::forward(stack_allocator))} #if PLS_PROFILING_ENABLED , profiler_{num_threads} #endif { worker_threads_.reserve(num_threads); task_managers_.reserve(num_threads); thread_states_.reserve(num_threads); for (unsigned int i = 0; i < num_threads_; i++) { auto &this_task_manager = task_managers_.emplace_back(std::make_unique(i, computation_depth, stack_size, stack_allocator_)); auto &this_thread_state = thread_states_.emplace_back(std::make_unique(*this, i, *this_task_manager)); if (reuse_thread && i == 0) { worker_threads_.emplace_back(); continue; // Skip over first/main thread when re-using the users thread, as this one will replace the first one. } auto *this_thread_state_pointer = this_thread_state.get(); worker_threads_.emplace_back([this_thread_state_pointer] { thread_state::set(this_thread_state_pointer); work_thread_main_loop(); }); } } class scheduler::init_function { public: virtual void run() = 0; }; template class scheduler::init_function_impl : public init_function { public: explicit init_function_impl(F &function) : function_{function} {} void run() override { base_task *root_task = thread_state::get().get_active_task(); #if PLS_PROFILING_ENABLED thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(), root_task->profiling_node_); thread_state::get().get_scheduler().profiler_.task_prepare_stack_measure(thread_state::get().get_thread_id(), root_task->stack_memory_, root_task->stack_size_); #endif root_task->run_as_task([root_task, this](auto cont) { root_task->is_synchronized_ = true; thread_state::get().main_continuation() = std::move(cont); function_(); thread_state::get().get_scheduler().work_section_done_.store(true); PLS_ASSERT(thread_state::get().main_continuation().valid(), "Must return valid continuation from main task."); #if PLS_PROFILING_ENABLED thread_state::get().get_scheduler().profiler_.task_finish_stack_measure(thread_state::get().get_thread_id(), root_task->stack_memory_, root_task->stack_size_, root_task->profiling_node_); thread_state::get().get_scheduler().profiler_.task_stop_running(thread_state::get().get_thread_id(), root_task->profiling_node_); #endif return std::move(thread_state::get().main_continuation()); }); } private: F &function_; }; template void scheduler::perform_work(Function work_section) { // Prepare main root task init_function_impl starter_function{work_section}; main_thread_starter_function_ = &starter_function; #if PLS_PROFILING_ENABLED auto *root_task = thread_state_for(0).get_active_task(); auto *root_node = profiler_.start_profiler_run(); root_task->profiling_node_ = root_node; #endif work_section_done_ = false; if (reuse_thread_) { auto &my_state = thread_state_for(0); thread_state::set(&my_state); // Make THIS THREAD become the main worker sync_barrier_.wait(); // Trigger threads to wake up work_thread_work_section(); // Simply also perform the work section on the main loop sync_barrier_.wait(); // Wait for threads to finish } else { // Simply trigger the others to do the work, this thread will sleep/wait for the time being sync_barrier_.wait(); // Trigger threads to wake up sync_barrier_.wait(); // Wait for threads to finish } #if PLS_PROFILING_ENABLED profiler_.stop_profiler_run(); #endif } template void scheduler::spawn(Function &&lambda) { thread_state &spawning_state = thread_state::get(); base_task *last_task = spawning_state.get_active_task(); base_task *spawned_task = last_task->next_; #if PLS_PROFILING_ENABLED spawning_state.get_scheduler().profiler_.task_prepare_stack_measure(spawning_state.get_thread_id(), spawned_task->stack_memory_, spawned_task->stack_size_); auto *child_dag_node = spawning_state.get_scheduler().profiler_.task_spawn_child(spawning_state.get_thread_id(), last_task->profiling_node_); spawned_task->profiling_node_ = child_dag_node; #endif auto continuation = spawned_task->run_as_task([last_task, spawned_task, lambda, &spawning_state](auto cont) { // allow stealing threads to continue the last task. last_task->continuation_ = std::move(cont); // we are now executing the new task, allow others to steal the last task continuation. spawned_task->is_synchronized_ = true; spawning_state.set_active_task(spawned_task); spawning_state.get_task_manager().push_local_task(last_task); // execute the lambda itself, which could lead to a different thread returning. #if PLS_PROFILING_ENABLED spawning_state.get_scheduler().profiler_.task_stop_running(spawning_state.get_thread_id(), last_task->profiling_node_); spawning_state.get_scheduler().profiler_.task_start_running(spawning_state.get_thread_id(), spawned_task->profiling_node_); #endif lambda(); thread_state &syncing_state = thread_state::get(); PLS_ASSERT(syncing_state.get_active_task() == spawned_task, "Task manager must always point its active task onto whats executing."); // try to pop a task of the syncing task manager. // possible outcomes: // - this is a different task manager, it must have an empty deque and fail // - this is the same task manager and someone stole last tasks, thus this will fail // - this is the same task manager and no one stole the last task, this this will succeed base_task *popped_task = syncing_state.get_task_manager().pop_local_task(); if (popped_task) { // Fast path, simply continue execution where we left of before spawn. PLS_ASSERT(popped_task == last_task, "Fast path, nothing can have changed until here."); PLS_ASSERT(&spawning_state == &syncing_state, "Fast path, we must only return if the task has not been stolen/moved to other thread."); PLS_ASSERT(last_task->continuation_.valid(), "Fast path, no one can have continued working on the last task."); syncing_state.set_active_task(last_task); #if PLS_PROFILING_ENABLED syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(), spawned_task->stack_memory_, spawned_task->stack_size_, spawned_task->profiling_node_); syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(), spawned_task->profiling_node_); #endif return std::move(last_task->continuation_); } else { // Slow path, the last task was stolen. This path is common to sync() events. #if PLS_PROFILING_ENABLED syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(), spawned_task->stack_memory_, spawned_task->stack_size_, spawned_task->profiling_node_); syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(), spawned_task->profiling_node_); #endif auto continuation = slow_return(syncing_state); return continuation; } }); #if PLS_PROFILING_ENABLED thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(), last_task->profiling_node_); #endif if (continuation.valid()) { // We jumped in here from the main loop, keep track! thread_state::get().main_continuation() = std::move(continuation); } } } #endif //PLS_SCHEDULER_IMPL_H