Commit 6ee522a3 by FritzFlorian

WIP: clean through scheduler code and fix obvious issues.

We still see very sporadic crashes; however, the current version is at least a starting point for refactoring and debugging. The next step is to re-enable tooling support (i.e. add code to let sanitizers do their work).
parent 2adb2d16
Pipeline #1391 failed in 27 seconds
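For context, "let sanitizers do their work" usually means annotating the stack/fiber switches that the context_switcher performs. Below is a sketch of what such an annotation could look like using AddressSanitizer's fiber hooks from <sanitizer/common_interface_defs.h>; the wrapper name and the way the stack bounds are passed in are assumptions, not part of this commit.

#include <sanitizer/common_interface_defs.h>
#include <cstddef>

// Sketch only: annotate a stack switch for AddressSanitizer. 'switch_impl' stands in
// for the real context switch (e.g. the context_switcher's assembly routine).
template <typename SwitchImpl>
void annotated_switch(SwitchImpl &&switch_impl,
                      const void *target_stack_bottom, std::size_t target_stack_size) {
  void *fake_stack = nullptr;
  // Tell ASan we are about to leave the current stack for the target fiber's stack.
  __sanitizer_start_switch_fiber(&fake_stack, target_stack_bottom, target_stack_size);

  switch_impl();  // the actual register/stack switch happens here

  // Executed once control switches back to this fiber: complete ASan's bookkeeping.
  __sanitizer_finish_switch_fiber(fake_stack, nullptr, nullptr);
}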
#include "pls/internal/scheduling/scheduler.h" #include "pls/internal/scheduling/scheduler.h"
#include "pls/internal/scheduling/parallel_result.h" #include "pls/internal/scheduling/static_scheduler_memory.h"
#include "pls/internal/scheduling/scheduler_memory.h"
#include "pls/internal/helpers/profiler.h" #include "pls/internal/helpers/profiler.h"
using namespace pls::internal::scheduling; using namespace pls::internal::scheduling;
...@@ -14,33 +13,35 @@ using namespace pls::internal::scheduling; ...@@ -14,33 +13,35 @@ using namespace pls::internal::scheduling;
using namespace comparison_benchmarks::base; using namespace comparison_benchmarks::base;
-parallel_result<short> conquer(fft::complex_vector::iterator data, int n) {
+void conquer(fft::complex_vector::iterator data, int n) {
   if (n < 2) {
-    return parallel_result<short>{0};
+    return;
   }

   fft::divide(data, n);
   if (n <= fft::RECURSIVE_CUTOFF) {
     fft::conquer(data, n / 2);
     fft::conquer(data + n / 2, n / 2);
-    fft::combine(data, n);
-    return parallel_result<short>{0};
   } else {
-    return scheduler::par([=]() {
-      return conquer(data, n / 2);
-    }, [=]() {
-      return conquer(data + n / 2, n / 2);
-    }).then([=](int, int) {
-      fft::combine(data, n);
-      return parallel_result<short>{0};
-    });
+    scheduler::spawn([data, n]() {
+      conquer(data, n / 2);
+    });
+    scheduler::spawn([data, n]() {
+      conquer(data + n / 2, n / 2);
+    });
+    scheduler::sync();
   }
+
+  fft::combine(data, n);
 }
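The rewrite above is the pattern this commit moves all call sites to: instead of composing parallel_result values through scheduler::par(...).then(...), children are forked with scheduler::spawn and joined with scheduler::sync, and results flow through ordinary variables. As a sketch of the same shape, pls_fib from the fib benchmark further down plausibly reads as follows (only its closing "return a + b;" is visible in this diff, so the exact body is an assumption):

int pls_fib(int n) {
  if (n <= 1) {
    return n;
  }

  int a, b;
  scheduler::spawn([n, &a]() {
    a = pls_fib(n - 1);
  });
  scheduler::spawn([n, &b]() {
    b = pls_fib(n - 2);
  });
  scheduler::sync();

  return a + b;
}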
 constexpr int MAX_NUM_THREADS = 8;
-constexpr int MAX_NUM_TASKS = 64;
-constexpr int MAX_NUM_CONTS = 64;
-constexpr int MAX_CONT_SIZE = 256;
+constexpr int MAX_NUM_TASKS = 32;
+constexpr int MAX_STACK_SIZE = 1024 * 8;
+
+static_scheduler_memory<MAX_NUM_THREADS,
+                        MAX_NUM_TASKS,
+                        MAX_STACK_SIZE> global_scheduler_memory;

 int main(int argc, char **argv) {
   int num_threads;
@@ -53,39 +54,21 @@ int main(int argc, char **argv) {
   fft::complex_vector data = fft::generate_input();

-  static_scheduler_memory<MAX_NUM_THREADS,
-                          MAX_NUM_TASKS,
-                          MAX_NUM_CONTS,
-                          MAX_CONT_SIZE> static_scheduler_memory;
-  scheduler scheduler{static_scheduler_memory, (unsigned int) num_threads};
+  scheduler scheduler{global_scheduler_memory, (unsigned) num_threads};

-  for (int i = 0; i < fft::NUM_WARMUP_ITERATIONS; i++) {
-    scheduler.perform_work([&]() {
-      return scheduler::par([&]() {
-        return conquer(data.begin(), fft::SIZE);
-      }, []() {
-        return parallel_result<short>{0};
-      }).then([&](short, short) {
-        return parallel_result<int>{0};
-      });
-    });
-  }
+  scheduler.perform_work([&]() {
+    for (int i = 0; i < fft::NUM_WARMUP_ITERATIONS; i++) {
+      conquer(data.begin(), fft::SIZE);
+    }
+  });

-  for (int i = 0; i < fft::NUM_ITERATIONS; i++) {
-    scheduler.perform_work([&]() {
-      runner.start_iteration();
-      return scheduler::par([&]() {
-        return conquer(data.begin(), fft::SIZE);
-      }, []() {
-        return parallel_result<short>{0};
-      }).then([&](short, short) {
-        runner.end_iteration();
-        return parallel_result<int>{0};
-      });
-    });
-  }
+  scheduler.perform_work([&]() {
+    for (int i = 0; i < fft::NUM_ITERATIONS; i++) {
+      runner.start_iteration();
+      conquer(data.begin(), fft::SIZE);
+      runner.end_iteration();
+    }
+  });

   runner.commit_results(true);
   return 0;
...
@@ -31,9 +31,9 @@ int pls_fib(int n) {
   return a + b;
 }

-constexpr int MAX_NUM_THREADS = 4;
+constexpr int MAX_NUM_THREADS = 8;
 constexpr int MAX_NUM_TASKS = 32;
-constexpr int MAX_STACK_SIZE = 1024 * 32;
+constexpr int MAX_STACK_SIZE = 1024 * 1;

 static_scheduler_memory<MAX_NUM_THREADS,
                         MAX_NUM_TASKS,
@@ -48,7 +48,7 @@ int main(int argc, char **argv) {
   string full_directory = directory + "/PLS_v3/";
   benchmark_runner runner{full_directory, test_name};

-  scheduler scheduler{global_scheduler_memory, (unsigned) num_threads, false};
+  scheduler scheduler{global_scheduler_memory, (unsigned) num_threads};

   volatile int res;
   scheduler.perform_work([&]() {
@@ -58,7 +58,7 @@ int main(int argc, char **argv) {
   });

   scheduler.perform_work([&]() {
-    for (int i = 0; i < fib::NUM_ITERATIONS * 100; i++) {
+    for (int i = 0; i < fib::NUM_ITERATIONS; i++) {
       runner.start_iteration();
       res = pls_fib(fib::INPUT_N);
       runner.end_iteration();
...
@@ -16,6 +16,6 @@
 void pls_error(const char *msg);

 // TODO: Distinguish between debug/internal asserts and production asserts.
-#define PLS_ASSERT(cond, msg) if (!(cond)) { pls_error(msg); }
+#define PLS_ASSERT(cond, msg) // if (!(cond)) { pls_error(msg); }

 #endif //PLS_ERROR_HANDLING_H
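The change above temporarily disables every assert by commenting the macro body out. As a sketch (not part of this commit) of how the TODO could eventually be addressed, internal asserts could compile away under NDEBUG while production asserts always stay active; the macro names below are hypothetical:

// Sketch only: one possible split between internal and production asserts.
#ifdef NDEBUG
#define PLS_INTERNAL_ASSERT(cond, msg) ((void)0)
#else
#define PLS_INTERNAL_ASSERT(cond, msg) if (!(cond)) { pls_error(msg); }
#endif

// Production asserts guard user-facing invariants and are always compiled in.
#define PLS_PRODUCTION_ASSERT(cond, msg) if (!(cond)) { pls_error(msg); }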
@@ -46,6 +46,15 @@ class external_trading_deque {
     thread_id_ = id;
   }

+  static optional<task *> peek_traded_object(task *target_task) {
+    traded_cas_field current_cas = target_task->external_trading_deque_cas_.load();
+    if (current_cas.is_filled_with_object()) {
+      return optional<task *>{current_cas.get_trade_object()};
+    } else {
+      return optional<task *>{};
+    }
+  }
+
   static optional<task *> get_trade_object(task *target_task) {
     traded_cas_field current_cas = target_task->external_trading_deque_cas_.load();
     if (current_cas.is_filled_with_object()) {
@@ -129,7 +138,7 @@ class external_trading_deque {
   struct peek_result {
     peek_result(optional<task *> top_task, stamped_integer top_pointer) : top_task_{std::move(top_task)},
                                                                           top_pointer_{top_pointer} {};
     optional<task *> top_task_;
     stamped_integer top_pointer_;
   };
...
@@ -56,7 +56,7 @@ struct alignas(base::system_details::CACHE_LINE_SIZE) task {
   // Work-Stealing
   std::atomic<traded_cas_field> external_trading_deque_cas_{};
-  task *resource_stack_next_{};
+  std::atomic<task *> resource_stack_next_{};
   std::atomic<data_structures::stamped_integer> resource_stack_root_{{0, 0}};
   bool clean_;
...
@@ -52,7 +52,7 @@ class task_manager {
   void spawn_child(F &&lambda);
   void sync();

-  bool steal_task(task_manager &stealing_task_manager);
+  task* steal_task(task_manager &stealing_task_manager);

   bool try_clean_return(context_switcher::continuation &result_cont);
...
@@ -19,18 +19,29 @@ namespace scheduling {
 template<typename F>
 void task_manager::spawn_child(F &&lambda) {
   auto *spawning_task_manager = this;
+  auto *last_task = spawning_task_manager->active_task_;
+  auto *spawned_task = spawning_task_manager->active_task_->next_;
+
   auto continuation =
-      active_task_->next_->run_as_task([lambda, spawning_task_manager](context_switcher::continuation cont) {
-        auto *last_task = spawning_task_manager->active_task_;
-        auto *this_task = spawning_task_manager->active_task_->next_;
+      spawned_task->run_as_task([=](context_switcher::continuation cont) {
+        // allow stealing threads to continue the last task.
         last_task->continuation_ = std::move(cont);
-        spawning_task_manager->active_task_ = this_task;
+
+        // we are now executing the new task, allow others to steal the last task continuation.
+        spawning_task_manager->active_task_ = spawned_task;
         spawning_task_manager->deque_.push_bot(last_task);

+        // execute the lambda itself, which could lead to a different thread returning.
         lambda();
+
         auto *syncing_task_manager = &thread_state::get().get_task_manager();
+        PLS_ASSERT(syncing_task_manager->active_task_ == spawned_task,
+                   "Task manager must always point its active task onto what's executing.");
+
+        // try to pop a task off the syncing task manager.
+        // possible outcomes:
+        //  - this is a different task manager, it must have an empty deque and fail
+        //  - this is the same task manager and someone stole the last task, thus this will fail
+        //  - this is the same task manager and no one stole the last task, thus this will succeed
         auto pop_result = syncing_task_manager->deque_.pop_bot();
         if (pop_result) {
           // Fast path, simply continue execution where we left off before spawn.
@@ -44,7 +55,7 @@ void task_manager::spawn_child(F &&lambda) {
           syncing_task_manager->active_task_ = last_task;
           return std::move(last_task->continuation_);
         } else {
-          // Slow path, the continuation was stolen.
+          // Slow path, the last task was stolen. Sync using the resource stack.
           context_switcher::continuation result_cont;
           if (syncing_task_manager->try_clean_return(result_cont)) {
             // We return back to the main scheduling loop
...
@@ -62,7 +62,6 @@ void scheduler::work_thread_work_section() {
   auto &my_task_manager = my_state.get_task_manager();

   auto const num_threads = my_state.get_scheduler().num_threads();
-  auto const my_id = my_state.get_id();

   if (my_state.get_id() == 0) {
     // Main Thread, kick off by executing the user's main code block.
@@ -72,60 +71,46 @@ void scheduler::work_thread_work_section() {
   while (!work_section_done_) {
     PLS_ASSERT(my_task_manager.check_task_chain(), "Must start stealing with a clean task chain.");

-    // Steal Routine (will be continuously executed when there are no more fall through's).
-    // TODO: move into separate function
-    const size_t offset = my_state.get_rand() % num_threads;
-    const size_t max_tries = num_threads;
-    for (size_t i = 0; i < max_tries; i++) {
-      // Perform steal
-      size_t target = (offset + i) % num_threads;
-      auto &target_state = my_state.get_scheduler().thread_state_for(target);
-      bool steal_success = target_state.get_task_manager().steal_task(my_task_manager);
-
-      if (steal_success) {
-        // The stealing procedure correctly changed our chain and active task.
-        // Now we need to perform the 'post steal' actions (manage resources and execute the stolen task).
-        PLS_ASSERT(my_task_manager.check_task_chain_forward(&my_task_manager.get_active_task()),
-                   "We are sole owner of this chain, it has to be valid!");
-
-        // Move the traded in resource of this active task over to the stack of resources.
-        auto *stolen_task = &my_task_manager.get_active_task();
-        traded_cas_field stolen_task_cas = stolen_task->external_trading_deque_cas_.load();
-        if (stolen_task_cas.is_filled_with_object()) {
-          // Push the traded in resource on the resource stack to clear the traded_field for later steals/spawns.
-          auto *exchanged_task = stolen_task_cas.get_trade_object();
-          my_task_manager.push_resource_on_task(stolen_task, exchanged_task);
-
-          traded_cas_field empty_field;
-          traded_cas_field expected_field;
-          expected_field.fill_with_trade_object(exchanged_task);
-          if (stolen_task->external_trading_deque_cas_.compare_exchange_strong(expected_field, empty_field)) {
-            // All good, nothing more to do
-          } else {
-            // The last other active thread took it as its spare resource...
-            // ...remove our traded object from the stack again (it must be empty now and no one must access it anymore).
-            PLS_ASSERT(expected_field.is_empty(),
-                       "Must be empty, as otherwise no one will steal the 'spare traded task'.");
-            auto current_root = stolen_task->resource_stack_root_.load();
-            current_root.stamp++;
-            current_root.value = 0;
-            stolen_task->resource_stack_root_.store(current_root);
-          }
-        }
-
-        // Execute the stolen task by jumping to its continuation.
-        PLS_ASSERT(stolen_task->continuation_.valid(),
-                   "A task that we can steal must have a valid continuation for us to start working.");
-        context_switcher::switch_context(std::move(stolen_task->continuation_));
-
-        // ...now we are done with this steal attempt, loop over.
-        break;
-      }
+    // TODO: move steal routine into separate function
+    const size_t target = my_state.get_rand() % num_threads;
+    if (target == my_state.get_id()) {
+      continue;
     }
-    // if (!my_cont_manager.falling_through()) {
-    //   base::this_thread::sleep(5);
-    // }
+
+    auto &target_state = my_state.get_scheduler().thread_state_for(target);
+    task *traded_task = target_state.get_task_manager().steal_task(my_task_manager);
+
+    if (traded_task != nullptr) {
+      // The stealing procedure correctly changed our chain and active task.
+      // Now we need to perform the 'post steal' actions (manage resources and execute the stolen task).
+      PLS_ASSERT(my_task_manager.check_task_chain_forward(&my_task_manager.get_active_task()),
+                 "We are sole owner of this chain, it has to be valid!");
+
+      // Move the traded in resource of this active task over to the stack of resources.
+      auto *stolen_task = &my_task_manager.get_active_task();
+      // Push the traded in resource on the resource stack to clear the traded_field for later steals/spawns.
+      my_task_manager.push_resource_on_task(stolen_task, traded_task);
+
+      auto optional_exchanged_task = external_trading_deque::get_trade_object(stolen_task);
+      if (optional_exchanged_task) {
+        // All good, we pushed the task over to the stack, nothing more to do
+        PLS_ASSERT(*optional_exchanged_task == traded_task,
+                   "We are currently executing this, no one else can put another task in this field!");
+      } else {
+        // The last other active thread took it as its spare resource...
+        // ...remove our traded object from the stack again (it must be empty now and no one must access it anymore).
+        auto current_root = stolen_task->resource_stack_root_.load();
+        current_root.stamp++;
+        current_root.value = 0;
+        stolen_task->resource_stack_root_.store(current_root);
+      }
+
+      // Execute the stolen task by jumping to its continuation.
+      PLS_ASSERT(stolen_task->continuation_.valid(),
+                 "A task that we can steal must have a valid continuation for us to start working.");
+      context_switcher::switch_context(std::move(stolen_task->continuation_));
+
+      // We will continue execution in this line when we finished the stolen work.
+    }
   }
 }
...
@@ -32,48 +32,46 @@ static task *find_task(unsigned id, unsigned depth) {
   return thread_state::get().get_scheduler().thread_state_for(id).get_task_manager().get_this_thread_task(depth);
 }

-bool task_manager::steal_task(task_manager &stealing_task_manager) {
+task *task_manager::steal_task(task_manager &stealing_task_manager) {
   PLS_ASSERT(stealing_task_manager.active_task_->depth_ == 0, "Must only steal with clean task chain.");
+  PLS_ASSERT(stealing_task_manager.check_task_chain(), "Must only steal with clean task chain.");

   auto peek = deque_.peek_top();
-  auto optional_target_task = peek.top_task_;
-  auto target_top = peek.top_pointer_;
-
-  if (optional_target_task) {
-    PLS_ASSERT(stealing_task_manager.check_task_chain(), "We are stealing, must not have a bad chain here!");
-
+  if (peek.top_task_) {
     // search for the task we want to trade in
-    task *target_task = *optional_target_task;
+    task *stolen_task = *peek.top_task_;
     task *traded_task = stealing_task_manager.active_task_;
-    for (unsigned i = 0; i < target_task->depth_; i++) {
+    for (unsigned i = 0; i < stolen_task->depth_; i++) {
       traded_task = traded_task->next_;
     }

     // keep a reference to the rest of the task chain that we keep
     task *next_own_task = traded_task->next_;
-    // 'unchain' the traded tasks (to help us find bugs only)
+    // 'unchain' the traded tasks (to help us find bugs)
     traded_task->next_ = nullptr;

-    auto optional_result_task = deque_.pop_top(traded_task, target_top);
-    if (optional_result_task) {
-      PLS_ASSERT(target_task->thread_id_ != traded_task->thread_id_,
+    // perform the actual pop operation
+    auto pop_result_task = deque_.pop_top(traded_task, peek.top_pointer_);
+    if (pop_result_task) {
+      PLS_ASSERT(stolen_task->thread_id_ != traded_task->thread_id_,
                  "It is impossible to steal an task we already own!");
-      PLS_ASSERT(*optional_result_task == target_task,
+      PLS_ASSERT(*pop_result_task == stolen_task,
                  "We must only steal the task that we peeked at!");

       // the steal was a success, link the chain so we own the stolen part
-      target_task->next_ = next_own_task;
-      next_own_task->prev_ = target_task;
-      stealing_task_manager.set_active_task(target_task);
+      stolen_task->next_ = next_own_task;
+      next_own_task->prev_ = stolen_task;
+      stealing_task_manager.active_task_ = stolen_task;

-      return true;
+      return traded_task;
     } else {
       // the steal failed, reset our chain to its old, clean state (re-link what we have broken)
       traded_task->next_ = next_own_task;
-      return false;
+      return nullptr;
     }
   } else {
-    return false;
+    return nullptr;
   }
 }
@@ -92,11 +90,11 @@ void task_manager::push_resource_on_task(task *target_task, task *spare_task_chain) {
     if (current_root.value == 0) {
       // Empty, simply push in with no successor
-      spare_task_chain->resource_stack_next_ = nullptr;
+      spare_task_chain->resource_stack_next_.store(nullptr, std::memory_order_relaxed);
     } else {
       // Already an entry. Find it's corresponding task and set it as our successor.
       auto *current_root_task = find_task(current_root.value - 1, target_task->depth_);
-      spare_task_chain->resource_stack_next_ = current_root_task;
+      spare_task_chain->resource_stack_next_.store(current_root_task, std::memory_order_relaxed);
     }
   } while (!target_task->resource_stack_root_.compare_exchange_strong(current_root, target_root));
@@ -108,15 +106,15 @@ task *task_manager::pop_resource_from_task(task *target_task) {
   task *output_task;
   do {
     current_root = target_task->resource_stack_root_.load();
-    target_root.stamp = current_root.stamp + 1;
     if (current_root.value == 0) {
       // Empty...
       return nullptr;
     } else {
       // Found something, try to pop it
       auto *current_root_task = find_task(current_root.value - 1, target_task->depth_);
-      auto *next_stack_task = current_root_task->resource_stack_next_;
+      auto *next_stack_task = current_root_task->resource_stack_next_.load(std::memory_order_relaxed);

+      target_root.stamp = current_root.stamp + 1;
       target_root.value = next_stack_task != nullptr ? next_stack_task->thread_id_ + 1 : 0;

       output_task = current_root_task;
@@ -124,19 +122,21 @@ task *task_manager::pop_resource_from_task(task *target_task) {
   } while (!target_task->resource_stack_root_.compare_exchange_strong(current_root, target_root));

   PLS_ASSERT(check_task_chain_backward(output_task), "Must only pop proper task chains.");
+  output_task->resource_stack_next_.store(nullptr, std::memory_order_relaxed);
   return output_task;
 }
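push_resource_on_task and pop_resource_from_task above implement a per-task, lock-free stack of spare task chains: the root holds a {stamp, thread_id + 1} pair (0 meaning empty), the stamp is bumped on every update to defeat the ABA problem, and entries are linked through resource_stack_next_ and located again via find_task(id, depth). For readers unfamiliar with the pattern, here is a stand-alone sketch of the same idea with hypothetical names (stamped_stack, node) and a fixed per-thread node array in place of the scheduler's task chains; it is an illustration, not the library's code:

#include <array>
#include <atomic>
#include <cstdint>

// {stamp, value} packed into one atomically exchangeable word; value == thread_id + 1, 0 == empty.
struct stamped_index {
  std::uint32_t stamp;
  std::uint32_t value;
};

struct node {
  std::atomic<node *> next{nullptr};
  std::uint32_t thread_id{0};
};

class stamped_stack {
 public:
  stamped_stack() {
    for (std::uint32_t i = 0; i < nodes_.size(); i++) {
      nodes_[i].thread_id = i;  // each slot belongs to one thread, like find_task(id, depth)
    }
  }

  void push(node *n) {
    stamped_index current, target;
    do {
      current = root_.load();
      target.stamp = current.stamp + 1;
      target.value = n->thread_id + 1;
      // Link the new head to the old head (or nullptr when the stack was empty).
      node *old_head = current.value == 0 ? nullptr : &nodes_[current.value - 1];
      n->next.store(old_head, std::memory_order_relaxed);
    } while (!root_.compare_exchange_strong(current, target));
  }

  node *pop() {
    stamped_index current, target;
    node *result;
    do {
      current = root_.load();
      if (current.value == 0) {
        return nullptr;  // Empty...
      }
      result = &nodes_[current.value - 1];
      node *next = result->next.load(std::memory_order_relaxed);
      // Bumping the stamp makes a concurrent pop/push of the same index visible to our CAS.
      target.stamp = current.stamp + 1;
      target.value = next != nullptr ? next->thread_id + 1 : 0;
    } while (!root_.compare_exchange_strong(current, target));
    result->next.store(nullptr, std::memory_order_relaxed);
    return result;
  }

  node *node_for_thread(std::uint32_t thread_id) { return &nodes_[thread_id]; }

 private:
  std::atomic<stamped_index> root_{{0, 0}};
  std::array<node, 8> nodes_{};
};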
 void task_manager::sync() {
-  auto continuation = active_task_->next_->run_as_task([this](context_switcher::continuation cont) {
-    auto *last_task = active_task_;
-    auto *this_task = active_task_->next_;
+  auto *spawning_task_manager = this;
+  auto *last_task = spawning_task_manager->active_task_;
+  auto *spawned_task = spawning_task_manager->active_task_->next_;

+  auto continuation = spawned_task->run_as_task([=](context_switcher::continuation cont) {
     last_task->continuation_ = std::move(cont);
-    active_task_ = this_task;
+    spawning_task_manager->active_task_ = spawned_task;

     context_switcher::continuation result_cont;
-    if (try_clean_return(result_cont)) {
+    if (spawning_task_manager->try_clean_return(result_cont)) {
       // We return back to the main scheduling loop
       active_task_->clean_ = true;
       return result_cont;
@@ -147,21 +147,17 @@ void task_manager::sync() {
     }
   });

-  if (continuation.valid()) {
-    // We jumped in here from the main loop, keep track!
-    thread_state::get().set_main_continuation(std::move(continuation));
-  }
+  PLS_ASSERT(!continuation.valid(),
+             "We only return to a sync point, never jump to it directly."
+             "This must therefore never return an unfinished fiber/continuation.");
 }
 bool task_manager::try_clean_return(context_switcher::continuation &result_cont) {
   task *this_task = active_task_;
   task *last_task = active_task_->prev_;

-  if (last_task == nullptr) {
-    // We finished the final task of the computation, return to the scheduling loop.
-    result_cont = thread_state::get().get_main_continuation();
-    return true;
-  }
+  PLS_ASSERT(last_task != nullptr,
+             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");

   // Try to get a clean resource chain to go back to the main stealing loop
   task *clean_chain = pop_resource_from_task(last_task);
@@ -179,28 +175,30 @@ bool task_manager::try_clean_return(context_switcher::continuation &result_cont) {
     // We got a clean chain to continue working on.
     PLS_ASSERT(last_task->depth_ == clean_chain->depth_,
                "Resources must only reside in the correct depth!");
-    PLS_ASSERT(check_task_chain_backward(clean_chain), "Can only aquire clean chains for clean returns!");
+    PLS_ASSERT(clean_chain != last_task,
+               "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
+    PLS_ASSERT(check_task_chain_backward(clean_chain),
+               "Can only acquire clean chains for clean returns!");
     this_task->prev_ = clean_chain;
     clean_chain->next_ = this_task;

     // Walk back chain to make first task active
     active_task_ = clean_chain;
     while (active_task_->prev_ != nullptr) {
       active_task_ = active_task_->prev_;
     }
-    PLS_ASSERT(check_task_chain(), "We just aquired a clean chain...");

-    // jump back to continuation in main scheduling loop, time to steal some work
+    // jump back to the continuation in main scheduling loop, time to steal some work
     result_cont = thread_state::get().get_main_continuation();
     return true;
   } else {
+    // Make sure that we are owner of this full continuation/task chain.
+    last_task->next_ = this_task;
+    this_task->prev_ = last_task;
+
     // We are the last one working on this task. Thus the sync must be finished, continue working.
     active_task_ = last_task;
-    // Make sure that we are owner fo this full continuation/task chain.
-    active_task_->next_ = this_task;
-    this_task->prev_ = active_task_;

     result_cont = std::move(last_task->continuation_);
     return false;
   }
...