#include "pls/internal/scheduling/scheduler.h"

#include "context_switcher/context_switcher.h"
#include "pls/internal/base/error_handling.h"

#include <thread>
#include <utility>

namespace pls::internal::scheduling {

// Convenience constructor: delegates to the full constructor, supplying a
// default-constructed base::mmap_stack_allocator as the stack allocator.
scheduler::scheduler(unsigned int num_threads,
                     size_t computation_depth,
                     size_t stack_size,
                     bool reuse_thread) : scheduler(num_threads,
                                                    computation_depth,
                                                    stack_size,
                                                    reuse_thread,
                                                    base::mmap_stack_allocator{}) {}

// Shuts the scheduler down via terminate() (which is a no-op if it already ran).
scheduler::~scheduler() {
  terminate();
}

// Entry point of a worker thread. Each round waits on sync_barrier_ to be
// triggered, leaves the loop once terminated_ is set, otherwise runs one work
// section and then waits on the barrier again to sync back with the main thread.
void scheduler::work_thread_main_loop() {
  auto &scheduler = thread_state::get().get_scheduler();
  while (true) {
    // Wait to be triggered
    scheduler.sync_barrier_.wait();

    // Check for shutdown
    if (scheduler.terminated_) {
      return;
    }

    scheduler.work_thread_work_section();

    // Sync back with main thread
    scheduler.sync_barrier_.wait();
  }
}

// Body of one work section for the calling thread.
// Thread 0 first runs main_thread_starter_function_. Afterwards every thread
// repeatedly tries to steal a task from a randomly chosen *other* thread until
// work_section_done_ is set. A successful steal resumes the stolen task's
// continuation; after num_threads consecutive failed steals the thread yields
// its time slice between attempts.
void scheduler::work_thread_work_section() {
  thread_state &my_state = thread_state::get();
  unsigned const num_threads = my_state.get_scheduler().num_threads();

  if (my_state.get_thread_id() == 0) {
    // Main Thread, kick off by executing the user's main code block.
    main_thread_starter_function_->run();
  }

  unsigned int failed_steals = 0;
  while (!work_section_done_) {
    PLS_ASSERT(check_task_chain(*my_state.get_active_task()), "Must start stealing with a clean task chain.");

    // Without a second thread there is no steal victim: the selection loop below
    // would spin forever (the only candidate is ourselves), and with zero threads
    // the modulo would divide by zero. Back off and re-check work_section_done_.
    if (num_threads < 2) {
      std::this_thread::yield();
      continue;
    }

    // Pick a random victim that is not ourselves.
    size_t target;
    do {
      target = my_state.get_rand() % num_threads;
    } while (target == my_state.get_thread_id());

    thread_state &target_state = my_state.get_scheduler().thread_state_for(target);
    base_task *stolen_task = target_state.get_task_manager().steal_task(my_state);
    if (stolen_task) {
      my_state.set_active_task(stolen_task);
      // TODO: Figure out how to model the 'steal' interaction.
      // The scheduler should decide on 'what to steal' and on how 'to manage the chains'.
      // The task_manager should perform the act of actually performing the steal/trade.
      // Maybe also give the chain management to the task_manager and associate resources with the traded tasks.
      PLS_ASSERT(check_task_chain_forward(*my_state.get_active_task()),
                 "We are sole owner of this chain, it has to be valid!");

      // Execute the stolen task by jumping to its continuation.
      PLS_ASSERT(stolen_task->continuation_.valid(),
                 "A task that we can steal must have a valid continuation for us to start working.");
      stolen_task->is_synchronized_ = false;
      context_switcher::switch_context(std::move(stolen_task->continuation_));
      // We will continue execution in this line when we finished the stolen work.
      failed_steals = 0;
    } else {
      failed_steals++;
      if (failed_steals >= num_threads) {
        std::this_thread::yield();
      }
    }
  }
}

// Waits for the work spawned from the calling thread's active task.
// Fast path: if the active task is already marked synchronized, return at once.
// Slow path: run the spawned task (active_task->next_) via run_as_task. The
// continuation of this sync point is stored in the active task, and
// slow_return() decides where execution continues. Control returns here only
// after the sync has completed.
void scheduler::sync() {
  thread_state &syncing_state = thread_state::get();

  base_task *active_task = syncing_state.get_active_task();
  base_task *spawned_task = active_task->next_;

  if (active_task->is_synchronized_) {
    return; // We are already the sole owner of last_task
  } else {
    auto continuation =
        spawned_task->run_as_task([active_task, spawned_task, &syncing_state](context_switcher::continuation cont) {
          active_task->continuation_ = std::move(cont);
          syncing_state.set_active_task(spawned_task);

          return slow_return(syncing_state);
        });

    PLS_ASSERT(!continuation.valid(),
               "We only return to a sync point, never jump to it directly. "
               "This must therefore never return an unfinished fiber/continuation.");

    return; // We cleanly synced to the last one finishing work on last_task
  }
}

// Finishes the calling thread's active task (a child of its prev_ task, here
// called last_task) without knowing whether it is the last child to finish.
// Returns the continuation to jump to next:
//  - If the task manager hands out a clean chain for last_task, that chain is
//    linked in as this_task's predecessor. Its depth-0 task becomes active, and
//    the thread's main continuation (the scheduling loop) is returned.
//  - Otherwise this thread owns the whole chain. last_task becomes active and
//    is marked synchronized, and its stored continuation is returned.
context_switcher::continuation scheduler::slow_return(thread_state &calling_state) {
  base_task *this_task = calling_state.get_active_task();
  PLS_ASSERT(this_task->depth_ > 0,
             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");
  base_task *last_task = this_task->prev_;

  // Slow return means we try to finish the child 'this_task' of 'last_task' and we
  // do not know if we are the last child to finish.
  // If we are not the last one, we get a spare task chain for our resources and can return to the main scheduling loop.
  base_task *pop_result = calling_state.get_task_manager().pop_clean_task_chain(last_task);

  if (pop_result != nullptr) {
    base_task *clean_chain = pop_result;

    // We got a clean chain to fill up our resources.
    PLS_ASSERT(last_task->depth_ == clean_chain->depth_,
               "Resources must only reside in the correct depth!");
    PLS_ASSERT(last_task != clean_chain,
               "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
    PLS_ASSERT(check_task_chain_backward(*clean_chain),
               "Can only acquire clean chains for clean returns!");

    // Acquire it/merge it with our task chain.
    this_task->prev_ = clean_chain;
    clean_chain->next_ = this_task;

    // The new active task is the depth-0 root of the merged chain.
    base_task *active_task = clean_chain;
    while (active_task->depth_ > 0) {
      active_task = active_task->prev_;
    }
    calling_state.set_active_task(active_task);

    // Jump back to the continuation in main scheduling loop.
    context_switcher::continuation result_cont = std::move(thread_state::get().main_continuation());
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
    return result_cont;
  } else {
    // Make sure that we are owner of this full continuation/task chain.
    last_task->next_ = this_task;

    // We are the last one working on this task. Thus the sync must be finished, continue working.
    calling_state.set_active_task(last_task);
    last_task->is_synchronized_ = true;

    // Jump to parent task and continue working on it.
    context_switcher::continuation result_cont = std::move(last_task->continuation_);
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
    return result_cont;
  }
}

// Returns the task at `depth` in the calling thread's task chain, found by
// walking prev_/next_ links starting from the currently active task.
base_task &scheduler::task_chain_at(unsigned int depth, thread_state &calling_state) {
  // TODO: possible optimize with cache array at steal events
  base_task *result = calling_state.get_active_task();
  while (result->depth_ > depth) {
    result = result->prev_;
  }
  while (result->depth_ < depth) {
    result = result->next_;
  }

  return *result;
}

// Shuts the scheduler down. Sets terminated_ and releases the threads waiting
// on sync_barrier_; worker threads then observe terminated_ and leave
// work_thread_main_loop. Finally joins the worker threads, skipping thread 0
// when reuse_thread_ is set. Returns immediately if already terminated.
void scheduler::terminate() {
  if (terminated_) {
    return;
  }

  terminated_ = true;
  sync_barrier_.wait();

  for (unsigned int i = 0; i < num_threads_; i++) {
    if (reuse_thread_ && i == 0) {
      continue;
    }
    worker_threads_[i].join();
  }
}

// Walks next_ links from start_task. Returns false if any successor's prev_
// does not point back to its predecessor, true otherwise.
bool scheduler::check_task_chain_forward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->next_) {
    if (current->next_->prev_ != current) {
      return false;
    }
    current = current->next_;
  }
  return true;
}

// Walks prev_ links from start_task. Returns false if any predecessor's next_
// does not point forward to its successor, true otherwise.
bool scheduler::check_task_chain_backward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->prev_) {
    if (current->prev_->next_ != current) {
      return false;
    }
    current = current->prev_;
  }
  return true;
}

// True iff the chain through start_task is consistently linked in both directions.
bool scheduler::check_task_chain(base_task &start_task) {
  return check_task_chain_backward(start_task) && check_task_chain_forward(start_task);
}

}