#include "pls/internal/scheduling/scheduler.h" #include "context_switcher/context_switcher.h" #include "pls/internal/scheduling/strain_local_resource.h" #include "pls/internal/build_flavour.h" #include "pls/internal/base/error_handling.h" #include "pls/internal/base/futex_wrapper.h" #include namespace pls::internal::scheduling { scheduler::scheduler(unsigned int num_threads, size_t computation_depth, size_t stack_size, bool reuse_thread, size_t serial_stack_size) : scheduler(num_threads, computation_depth, stack_size, reuse_thread, serial_stack_size, base::mmap_stack_allocator{}) {} scheduler::~scheduler() { terminate(); } void scheduler::work_thread_main_loop() { auto &scheduler = thread_state::get().get_scheduler(); while (true) { // Wait to be triggered scheduler.sync_barrier_.wait(); // Check for shutdown if (scheduler.terminated_) { return; } scheduler.work_thread_work_section(); // Sync back with main thread scheduler.sync_barrier_.wait(); } } void scheduler::work_thread_work_section() { thread_state &my_state = thread_state::get(); my_state.set_scheduler_active(true); unsigned const num_threads = my_state.get_scheduler().num_threads(); if (my_state.get_thread_id() == 0) { // Main Thread, kick off by executing the user's main code block. main_thread_starter_function_->run(); } unsigned int failed_steals = 0; while (!work_section_done_) { #if PLS_PROFILING_ENABLED my_state.get_scheduler().profiler_.stealing_start(my_state.get_thread_id()); #endif PLS_ASSERT_EXPENSIVE(check_task_chain(*my_state.get_active_task()), "Must start stealing with a clean task chain."); size_t target; do { target = my_state.get_rand() % num_threads; } while (target == my_state.get_thread_id()); thread_state &target_state = my_state.get_scheduler().thread_state_for(target); #if PLS_SLEEP_WORKERS_ON_EMPTY queue_empty_flag_retry_steal: data_structures::stamped_integer target_queue_empty_flag = target_state.get_queue_empty_flag().load(std::memory_order_relaxed); #endif base_task *stolen_task; base_task *chain_after_stolen_task; bool cas_success; do { #if PLS_PROFILING_ENABLED my_state.get_scheduler().profiler_.stealing_cas_op(my_state.get_thread_id()); #endif std::tie(stolen_task, chain_after_stolen_task, cas_success) = target_state.get_task_manager().steal_task(my_state); } while (!cas_success); // re-try on cas-conflicts, as in classic ws if (stolen_task) { // Keep task chain consistent. We want to appear as if we are working an a branch upwards of the stolen task. stolen_task->next_ = chain_after_stolen_task; chain_after_stolen_task->prev_ = stolen_task; my_state.set_active_task(stolen_task); // Keep locally owned resources consistent. auto *stolen_resources = stolen_task->attached_resources_.load(std::memory_order_relaxed); strain_local_resource::acquire_locally(stolen_resources, my_state.get_thread_id()); PLS_ASSERT_EXPENSIVE(check_task_chain_forward(*my_state.get_active_task()), "We are sole owner of this chain, it has to be valid!"); // Execute the stolen task by jumping to it's continuation. 
      PLS_ASSERT(stolen_task->continuation_.valid(),
                 "A task that we can steal must have a valid continuation for us to start working.");

      stolen_task->is_synchronized_ = false;
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), true);
      my_state.get_scheduler().profiler_.task_start_running(my_state.get_thread_id(),
                                                            stolen_task->profiling_node_);
#endif
      context_switcher::switch_context(std::move(stolen_task->continuation_));
      // We will continue execution in this line when we finished the stolen work.
      failed_steals = 0;
    } else {
      // TODO: tune value for when we start yielding
      const unsigned YIELD_AFTER = 1;

      failed_steals++;
      if (failed_steals >= YIELD_AFTER) {
        std::this_thread::yield();
#if PLS_PROFILING_ENABLED
        my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), false);
#endif
#if PLS_SLEEP_WORKERS_ON_EMPTY
        switch (target_queue_empty_flag.value_) {
          case EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY: {
            // We found the queue empty, but the flag says it should still be full.
            // We want to declare it empty, but we need to re-check the queue in a sub-step to avoid races.
            data_structures::stamped_integer
                maybe_empty_flag{target_queue_empty_flag.stamp_ + 1, EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY};
            if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag,
                                                                            maybe_empty_flag,
                                                                            std::memory_order_acq_rel)) {
              goto queue_empty_flag_retry_steal;
            }
            break;
          }
          case EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY: {
            // We found the queue empty and it was already marked as maybe empty.
            // We can safely mark it empty and increment the central counter.
            data_structures::stamped_integer
                empty_flag{target_queue_empty_flag.stamp_ + 1, EMPTY_QUEUE_STATE::QUEUE_EMPTY};
            if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag,
                                                                            empty_flag,
                                                                            std::memory_order_acq_rel)) {
              // We marked it empty, now it's our duty to modify the central counter.
              my_state.get_scheduler().empty_queue_increase_counter();
            }
            break;
          }
          case EMPTY_QUEUE_STATE::QUEUE_EMPTY: {
            // The queue was already marked empty, just do nothing.
            break;
          }
        }
        // Regardless of whether we found the target queue empty, check if we can put ourselves to sleep.
        my_state.get_scheduler().empty_queue_try_sleep_worker();
#endif
      }
    }
  }
  my_state.set_scheduler_active(false);
}

void scheduler::sync_internal() {
  thread_state &syncing_state = thread_state::get();

  base_task *active_task = syncing_state.get_active_task();
  base_task *spawned_task = active_task->next_;

#if PLS_PROFILING_ENABLED
  syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                    active_task->stack_memory_,
                                                                    active_task->stack_size_,
                                                                    active_task->profiling_node_);
  syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                            active_task->profiling_node_);
  auto *next_dag_node =
      syncing_state.get_scheduler().profiler_.task_sync(syncing_state.get_thread_id(), active_task->profiling_node_);
  active_task->profiling_node_ = next_dag_node;
#endif

  if (active_task->is_synchronized_) {
#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     thread_state::get().get_active_task()->profiling_node_);
#endif
    return; // We are already the sole owner of last_task.
  } else {
    auto continuation =
        spawned_task->run_as_task([active_task, spawned_task, &syncing_state](context_switcher::continuation cont) {
          active_task->continuation_ = std::move(cont);
          syncing_state.set_active_task(spawned_task);
          return slow_return(syncing_state, true);
        });

    PLS_ASSERT(!continuation.valid(),
               "We only return to a sync point, never jump to it directly. "
               "This must therefore never return an unfinished fiber/continuation.");

#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     thread_state::get().get_active_task()->profiling_node_);
#endif
    return; // We cleanly synced to the last one finishing work on last_task.
  }
}

context_switcher::continuation scheduler::slow_return(thread_state &calling_state, bool in_sync) {
  base_task *this_task = calling_state.get_active_task();
  PLS_ASSERT(this_task->depth_ > 0,
             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");
  base_task *last_task = this_task->prev_;

  // A slow return means we try to finish the child 'this_task' of 'last_task' and we
  // do not know if we are the last child to finish.
  // If we are not the last one, we get a spare task chain for our resources and can return to the main scheduling loop.
  base_task *pop_result = calling_state.get_task_manager().pop_clean_task_chain(last_task);

  if (pop_result != nullptr) {
    base_task *clean_chain = pop_result;

    // We got a clean chain to fill up our resources.
    PLS_ASSERT(last_task->depth_ == clean_chain->depth_,
               "Resources must only reside in the correct depth!");
    PLS_ASSERT(last_task != clean_chain,
               "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
    PLS_ASSERT_EXPENSIVE(check_task_chain_backward(*clean_chain),
                         "Can only acquire clean chains for clean returns!");

    // Acquire it/merge it with our task chain.
    this_task->prev_ = clean_chain;
    clean_chain->next_ = this_task;

    // Keep locally owned resources consistent.
    auto *clean_resources = clean_chain->attached_resources_.load(std::memory_order_relaxed);
    strain_local_resource::acquire_locally(clean_resources, calling_state.get_thread_id());

    base_task *active_task = clean_chain;
    while (active_task->depth_ > 0) {
      active_task = active_task->prev_;
    }
    calling_state.set_active_task(active_task);

    // Jump back to the continuation in the main scheduling loop.
    context_switcher::continuation result_cont = std::move(thread_state::get().main_continuation());
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation (to main).");
    return result_cont;
  } else {
    // Make sure that we are the owner of this full continuation/task chain.
    last_task->next_ = this_task;

    // We are the last one working on this task. Thus the sync must be finished, continue working.
    calling_state.set_active_task(last_task);
    last_task->is_synchronized_ = true;

    // Jump to the parent task and continue working on it.
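    // The continuation resumed below was captured when 'last_task' suspended (in sync_internal above for the
    // sync path, or at the corresponding spawn); 'in_sync' only selects the assertion message, the resume
    // path itself is identical.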
    if (in_sync) {
      PLS_ASSERT(last_task->continuation_.valid(),
                 "Must return a valid continuation (to last task) in sync.");
    } else {
      PLS_ASSERT(last_task->continuation_.valid(),
                 "Must return a valid continuation (to last task) in spawn.");
    }
    context_switcher::continuation result_cont = std::move(last_task->continuation_);
    return result_cont;
  }
}

base_task *scheduler::get_trade_task(base_task *stolen_task, thread_state &calling_state) {
  // Get the task itself.
  base_task *result = calling_state.get_active_task();
  while (result->depth_ > stolen_task->depth_) {
    result = result->prev_;
  }
  while (result->depth_ < stolen_task->depth_) {
    result = result->next_;
  }

  // Attach the other resources we need to trade to it.
  auto *stolen_resources = stolen_task->attached_resources_.load(std::memory_order_relaxed);
  auto *traded_resources = strain_local_resource::get_local_copy(stolen_resources, calling_state.get_thread_id());
  result->attached_resources_.store(traded_resources, std::memory_order_relaxed);

  return result;
}

void scheduler::terminate() {
  if (terminated_) {
    return;
  }

  terminated_ = true;
  sync_barrier_.wait();

  for (unsigned int i = 0; i < num_threads_; i++) {
    if (reuse_thread_ && i == 0) {
      continue;
    }
    worker_threads_[i].join();
  }
}

bool scheduler::check_task_chain_forward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->next_) {
    if (current->next_->prev_ != current) {
      return false;
    }
    current = current->next_;
  }
  return true;
}

bool scheduler::check_task_chain_backward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->prev_) {
    if (current->prev_->next_ != current) {
      return false;
    }
    current = current->prev_;
  }
  return true;
}

bool scheduler::check_task_chain(base_task &start_task) {
  return check_task_chain_backward(start_task) && check_task_chain_forward(start_task);
}

#if PLS_SLEEP_WORKERS_ON_EMPTY
// TODO: relax memory orderings

void scheduler::empty_queue_try_sleep_worker() {
  int32_t counter_value = empty_queue_counter_.load();
  if (counter_value == num_threads()) {
#if PLS_PROFILING_ENABLED
    get_profiler().sleep_start(thread_state::get().get_thread_id());
#endif
    threads_sleeping_++;
    base::futex_wait((int32_t *) &empty_queue_counter_, num_threads());
    threads_sleeping_--;
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
#if PLS_PROFILING_ENABLED
    get_profiler().sleep_stop(thread_state::get().get_thread_id());
#endif
  }
}

void scheduler::empty_queue_increase_counter() {
  empty_queue_counter_.fetch_add(1);
}

void scheduler::empty_queue_decrease_counter_and_wake() {
  empty_queue_counter_.fetch_sub(1);
  if (threads_sleeping_.load() > 0) {
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
  }
}
#endif

} // namespace pls::internal::scheduling
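// Usage sketch (not part of this translation unit): how the scheduling loop above is typically driven.
// This assumes the spawn()/sync()/perform_work() entry points declared in pls/internal/scheduling/scheduler.h;
// adjust names and constructor arguments if the public API differs.
//
//   #include "pls/internal/scheduling/scheduler.h"
//
//   using pls::internal::scheduling::scheduler;
//
//   int fib(int n) {
//     if (n <= 1) return n;
//     int a, b;
//     scheduler::spawn([&a, n] { a = fib(n - 1); }); // may be stolen by work_thread_work_section()
//     scheduler::spawn([&b, n] { b = fib(n - 2); });
//     scheduler::sync();                             // enters sync_internal()/slow_return() above
//     return a + b;
//   }
//
//   int main() {
//     scheduler s{8, 32, 4096 * 8};                  // num_threads, computation_depth, stack_size (rest defaulted)
//     int result = 0;
//     s.perform_work([&result] { result = fib(30); }); // runs as the main code block on thread 0
//     return result == 832040 ? 0 : 1;
//   }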