Commit 2adb2d16 by FritzFlorian

WIP: Add simple external trading deque test.

The current version has race conditions and is hard to debug (especially because of the fibers; if the wrong thread executes on a fiber we get segfaults very fast). To combat this mess we now refactor the code bit by bit while also adding tests where it can be done with reasonable effort.
parent 22f4c598
Pipeline #1390 failed with stages
in 31 seconds
...@@ -32,8 +32,8 @@ int pls_fib(int n) { ...@@ -32,8 +32,8 @@ int pls_fib(int n) {
} }
constexpr int MAX_NUM_THREADS = 4; constexpr int MAX_NUM_THREADS = 4;
constexpr int MAX_NUM_TASKS = 64; constexpr int MAX_NUM_TASKS = 32;
constexpr int MAX_STACK_SIZE = 4096 * 8; constexpr int MAX_STACK_SIZE = 1024 * 32;
static_scheduler_memory<MAX_NUM_THREADS, static_scheduler_memory<MAX_NUM_THREADS,
MAX_NUM_TASKS, MAX_NUM_TASKS,
...@@ -48,7 +48,7 @@ int main(int argc, char **argv) { ...@@ -48,7 +48,7 @@ int main(int argc, char **argv) {
string full_directory = directory + "/PLS_v3/"; string full_directory = directory + "/PLS_v3/";
benchmark_runner runner{full_directory, test_name}; benchmark_runner runner{full_directory, test_name};
scheduler scheduler{global_scheduler_memory, (unsigned) num_threads}; scheduler scheduler{global_scheduler_memory, (unsigned) num_threads, false};
volatile int res; volatile int res;
scheduler.perform_work([&]() { scheduler.perform_work([&]() {
...@@ -58,7 +58,7 @@ int main(int argc, char **argv) { ...@@ -58,7 +58,7 @@ int main(int argc, char **argv) {
}); });
scheduler.perform_work([&]() { scheduler.perform_work([&]() {
for (int i = 0; i < fib::NUM_ITERATIONS; i++) { for (int i = 0; i < fib::NUM_ITERATIONS * 100; i++) {
runner.start_iteration(); runner.start_iteration();
res = pls_fib(fib::INPUT_N); res = pls_fib(fib::INPUT_N);
runner.end_iteration(); runner.end_iteration();
......
#include <utility> #include <utility>
#include <cstdio> #include <cstdio>
#include <chrono>
#include "context_switcher/context_switcher.h" #include "context_switcher/context_switcher.h"
using namespace context_switcher; using namespace context_switcher;
using namespace std; using namespace std;
const size_t NUM_RUNS = 1000;
// Memory for custom stack and continuation semantics // Memory for custom stack and continuation semantics
const size_t STACK_SIZE = 512 * 1; const size_t STACK_SIZE = 512 * 16;
const size_t NUM_STACKS = 64; const size_t NUM_STACKS = 4;
char custom_stacks[NUM_STACKS][STACK_SIZE]; char custom_stacks[NUM_STACKS][STACK_SIZE];
int fib(int n) { volatile int result;
if (n <= 1) {
return 1; int main() {
context_switcher::continuation outer_cont = enter_context(custom_stacks[0], STACK_SIZE, [](continuation &&main_cont) {
enter_context(custom_stacks[1], STACK_SIZE, [&main_cont](continuation &&middle_cont) {
enter_context(custom_stacks[2], STACK_SIZE, [&main_cont](continuation &&inner_cont) {
for (int i = 0; i < 10; i++) {
printf("Inner %d\n", i);
main_cont = context_switcher::switch_context(std::move(main_cont));
} }
int a, b; return std::move(inner_cont);
enter_context(custom_stacks[n], STACK_SIZE, [n, &a](continuation &&cont) {
a = fib(n - 1);
return std::move(cont);
}); });
enter_context(custom_stacks[n], STACK_SIZE, [n, &b](continuation &&cont) {
b = fib(n - 2); return std::move(middle_cont);
return std::move(cont);
}); });
return a + b; return std::move(main_cont);
} });
volatile int result; for (int i = 0; i < 10; i++) {
int main() { printf("Outer %d\n", i);
auto start_time = chrono::steady_clock::now(); outer_cont = context_switcher::switch_context(std::move(outer_cont));
for (unsigned int i = 0; i < NUM_RUNS; i++) {
result = fib(18);
} }
auto end_time = chrono::steady_clock::now();
auto time = chrono::duration_cast<chrono::microseconds>(end_time - start_time).count();
printf("%f", (float) time / NUM_RUNS);
return 0; return 0;
} }
...@@ -17,10 +17,11 @@ __cs_enter_context: ...@@ -17,10 +17,11 @@ __cs_enter_context:
# Variables # Variables
# r12 = temporary for the old stack pointer # r12 = temporary for the old stack pointer
############### Save State ############### pushq %rbp
# Make space for all register state we will store. movq %rsp, %rbp
leaq -0x38(%rsp), %rsp subq $0x38, %rsp
############### Save State ###############
# Store calee saved general registers. # Store calee saved general registers.
movq %r12, 0x00(%rsp) movq %r12, 0x00(%rsp)
movq %r13, 0x08(%rsp) movq %r13, 0x08(%rsp)
...@@ -42,6 +43,13 @@ __cs_enter_context: ...@@ -42,6 +43,13 @@ __cs_enter_context:
# Switch to new stack pointer. # Switch to new stack pointer.
movq %rdi, %rsp movq %rdi, %rsp
# Init the new stack (in our case we want the stack trace to 'stick' to where it was created).
# This will not necessarily be valid all the time (thus returning into it is not ok), but
# we only use it for debugging, as we explicitly state that throwing exceptions beyond the fiber is not ok.
# pushq 0x8(%rbp)
# pushq 0x0(%rbp)
# movq %rsp, %rbp
# Perform actual function call, this will now be on the new stack # Perform actual function call, this will now be on the new stack
# rdi = first parametor to callback (continuation) # rdi = first parametor to callback (continuation)
# rsi = second parameter to callback (arbetary pointer) # rsi = second parameter to callback (arbetary pointer)
...@@ -65,14 +73,12 @@ __cs_enter_context: ...@@ -65,14 +73,12 @@ __cs_enter_context:
ldmxcsr 0x30(%rsp) ldmxcsr 0x30(%rsp)
# restore x87 control-word # restore x87 control-word
fldcw 0x34(%rsp) fldcw 0x34(%rsp)
# Free space for restored state
leaq 0x38(%rsp), %rsp
############ Restore State ############ ############ Restore State ############
# TODO: Maybe look into a 'cleanup' hook for freeing the stack space here.
# Just return back from the call. # Just return back from the call.
# This is the end of a fiber, so we have no continuation. # This is the end of a fiber, so we have no continuation.
xor %rax, %rax xor %rax, %rax
movq %rbp, %rsp
popq %rbp
ret ret
...@@ -11,10 +11,11 @@ __cs_switch_context: ...@@ -11,10 +11,11 @@ __cs_switch_context:
# Return # Return
# rax = continuation that returned control back to the caller (null if fallthrough) # rax = continuation that returned control back to the caller (null if fallthrough)
############### Save State ############### pushq %rbp
# Make space for all register state we will store. movq %rsp, %rbp
leaq -0x38(%rsp), %rsp subq $0x38, %rsp
############### Save State ###############
# Store calee saved general registers. # Store calee saved general registers.
movq %r12, 0x00(%rsp) movq %r12, 0x00(%rsp)
movq %r13, 0x08(%rsp) movq %r13, 0x08(%rsp)
...@@ -46,11 +47,10 @@ __cs_switch_context: ...@@ -46,11 +47,10 @@ __cs_switch_context:
ldmxcsr 0x30(%rsp) ldmxcsr 0x30(%rsp)
# restore x87 control-word # restore x87 control-word
fldcw 0x34(%rsp) fldcw 0x34(%rsp)
# Free space for restored state
leaq 0x38(%rsp), %rsp
############ Restore State ############ ############ Restore State ############
# Return the context we came from as a continuation. # Return the context we came from as a continuation.
# rax has already the correct value # rax has already the correct value
movq %rbp, %rsp
popq %rbp
ret ret
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
#include "assembly_bindings.h" #include "assembly_bindings.h"
#include <cstdlib>
#include <cstdio>
namespace context_switcher { namespace context_switcher {
/** /**
...@@ -38,6 +41,10 @@ struct continuation { ...@@ -38,6 +41,10 @@ struct continuation {
} }
assembly_bindings::continuation_t consume() { assembly_bindings::continuation_t consume() {
if (cont_pointer_ == nullptr) {
printf("Error!\n");
}
auto tmp = cont_pointer_; auto tmp = cont_pointer_;
cont_pointer_ = nullptr; cont_pointer_ = nullptr;
return tmp; return tmp;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#define PLS_INTERNAL_SCHEDULING_TASK_TRADING_DEQUE_H_ #define PLS_INTERNAL_SCHEDULING_TASK_TRADING_DEQUE_H_
#include <atomic> #include <atomic>
#include <utility>
#include <tuple> #include <tuple>
#include "pls/internal/base/error_handling.h" #include "pls/internal/base/error_handling.h"
...@@ -18,6 +19,8 @@ namespace pls { ...@@ -18,6 +19,8 @@ namespace pls {
namespace internal { namespace internal {
namespace scheduling { namespace scheduling {
using namespace data_structures;
struct trading_deque_entry { struct trading_deque_entry {
std::atomic<task *> traded_task_{nullptr}; std::atomic<task *> traded_task_{nullptr};
std::atomic<unsigned long> forwarding_stamp_{}; std::atomic<unsigned long> forwarding_stamp_{};
...@@ -35,116 +38,171 @@ struct trading_deque_entry { ...@@ -35,116 +38,171 @@ struct trading_deque_entry {
* As each task is associated with memory this suffices to exchange memory blocks needed for execution. * As each task is associated with memory this suffices to exchange memory blocks needed for execution.
*/ */
class external_trading_deque { class external_trading_deque {
public: public:
external_trading_deque(trading_deque_entry *entries, size_t num_entries) : external_trading_deque(trading_deque_entry *entries, size_t num_entries) :
entries_{entries}, num_entries_{num_entries} {}; entries_{entries}, num_entries_{num_entries} {};
void set_thread_id(unsigned id) {
thread_id_ = id;
}
static optional<task *> get_trade_object(task *target_task) {
traded_cas_field current_cas = target_task->external_trading_deque_cas_.load();
if (current_cas.is_filled_with_object()) {
task *result = current_cas.get_trade_object();
traded_cas_field empty_cas;
if (target_task->external_trading_deque_cas_.compare_exchange_strong(current_cas, empty_cas)) {
return optional<task *>{result};
}
}
return optional<task *>{};
}
/** /**
* Pushes a task on the bottom of the deque. * Pushes a task on the bottom of the deque.
* The task itself wil be filled with the unique, synchronizing cas word. * The task itself wil be filled with the unique, synchronizing cas word.
* *
* @param published_task The task to publish on the bottom of the deque. * @param published_task The task to publish on the bottom of the deque.
* @return The content of the cas word, can be later used to check if it changed.
*/ */
traded_cas_field push_bot(task *published_task) { void push_bot(task *published_task) {
auto expected_stamp = bot_internal_.stamp; auto expected_stamp = bot_internal_.stamp;
auto &current_entry = entries_[bot_internal_.value]; auto &current_entry = entries_[bot_internal_.value];
// Store stealing information in the task and deque. // Field that all threads synchronize on.
// Relaxed is fine for this, as adding elements is synced over the bot pointer. // This happens not in the deque itself, but in the published task.
current_entry.forwarding_stamp_.store(expected_stamp, std::memory_order_relaxed); traded_cas_field sync_cas_field;
sync_cas_field.fill_with_stamp(expected_stamp, thread_id_);
published_task->external_trading_deque_cas_.store(sync_cas_field);
traded_cas_field new_cas_field; // Publish the prepared task in the deque.
new_cas_field.fill_with_stamp(expected_stamp, deque_id_); current_entry.forwarding_stamp_.store(expected_stamp);
published_task->traded_field_.store(new_cas_field, std::memory_order_relaxed); current_entry.traded_task_.store(published_task);
current_entry.traded_task_.store(published_task, std::memory_order_relaxed);
// Advance the bot pointer. Linearization point for making the task public. // Advance the bot pointer. Linearization point for making the task public.
bot_internal_.stamp++; bot_internal_.stamp++;
bot_internal_.value++; bot_internal_.value++;
bot_.store(bot_internal_.value, std::memory_order_release); bot_.store(bot_internal_.value);
}
void reset_bot_and_top() {
bot_internal_.value = 0;
bot_internal_.stamp++;
return new_cas_field; bot_.store(0);
top_.store({bot_internal_.stamp, 0});
} }
void popped_bot() { void decrease_bot() {
bot_internal_.value--; bot_internal_.value--;
bot_.store(bot_internal_.value);
}
bot_.store(bot_internal_.value, std::memory_order_relaxed); /**
* Tries to pop the last task on the deque.
*
* @return optional<task*> holding the popped task if successful.
*/
optional<task *> pop_bot() {
if (bot_internal_.value == 0) { if (bot_internal_.value == 0) {
bot_internal_.stamp++; reset_bot_and_top();
top_.store({bot_internal_.stamp, 0}, std::memory_order_release); return optional<task *>{};
}
} }
decrease_bot();
void empty_deque() { auto &current_entry = entries_[bot_internal_.value];
bot_internal_.value = 0; auto *popped_task = current_entry.traded_task_.load();
bot_internal_.stamp++; auto expected_stamp = current_entry.forwarding_stamp_.load();
// TODO: We might be able to relax memory orderings... // We know what value must be in the cas field if no other thread stole it.
bot_.store(bot_internal_.value); traded_cas_field expected_sync_cas_field;
top_.store(bot_internal_); expected_sync_cas_field.fill_with_stamp(expected_stamp, thread_id_);
traded_cas_field empty_cas_field;
if (popped_task->external_trading_deque_cas_.compare_exchange_strong(expected_sync_cas_field, empty_cas_field)) {
return optional<task *>{popped_task};
} else {
reset_bot_and_top();
return optional<task *>{};
} }
}
struct peek_result {
peek_result(optional<task *> top_task, stamped_integer top_pointer) : top_task_{std::move(top_task)},
top_pointer_{top_pointer} {};
optional<task *> top_task_;
stamped_integer top_pointer_;
};
std::tuple<data_structures::optional<task *>, data_structures::stamped_integer> peek_top() { /**
* Peek at the current task on top of the deque.
* This is required, as we need to look at the task to figure out what we trade in for it.
* (Note: we could go without this by doing some tricks with top/bot pointers, but this
* is simpler and also more flexible if the traded objects are not as trivial as currently).
*
* @return a peek result containing the optional top task (if present) and the current head pointer.
*/
peek_result peek_top() {
auto local_top = top_.load(); auto local_top = top_.load();
auto local_bot = bot_.load(); auto local_bot = bot_.load();
if (local_top.value >= local_bot) { if (local_top.value < local_bot) {
return std::make_tuple(data_structures::optional<task *>{}, local_top); return peek_result{optional<task *>{entries_[local_top.value].traded_task_}, local_top};
} else { } else {
return std::make_tuple(data_structures::optional<task *>{entries_[local_top.value].traded_task_}, local_top); return peek_result{optional<task *>{}, local_top};
} }
} }
data_structures::optional<task *> pop_top(task *trade_offer, data_structures::stamped_integer local_top) { optional<task *> pop_top(task *offered_task, stamped_integer expected_top) {
auto local_bot = bot_.load(); auto local_bot = bot_.load();
if (local_top.value >= local_bot) { if (expected_top.value >= local_bot) {
return data_structures::optional<task *>{}; return data_structures::optional<task *>{};
} }
unsigned long expected_top_stamp = local_top.stamp; auto &target_entry = entries_[expected_top.value];
auto &target_entry = entries_[local_top.value];
// Read our potential result // Read our potential result
task *result = target_entry.traded_task_.load(std::memory_order_relaxed); task *result = target_entry.traded_task_.load();
unsigned long forwarding_stamp = target_entry.forwarding_stamp_.load(std::memory_order_relaxed); unsigned long forwarding_stamp = target_entry.forwarding_stamp_.load();
// Try to get it by CAS with the expected field entry, giving up our offered_task for it
traded_cas_field expected_sync_cas_field;
expected_sync_cas_field.fill_with_stamp(expected_top.stamp, thread_id_);
// Try to get it by CAS with the expected field entry, giving up our offered_object for it
traded_cas_field expected_field;
expected_field.fill_with_stamp(expected_top_stamp, deque_id_);
traded_cas_field offered_field; traded_cas_field offered_field;
offered_field.fill_with_trade_object(trade_offer); offered_field.fill_with_trade_object(offered_task);
if (result->traded_field_.compare_exchange_strong(expected_field, offered_field, std::memory_order_acq_rel)) { if (result->external_trading_deque_cas_.compare_exchange_strong(expected_sync_cas_field, offered_field)) {
// We got it, for sure move the top pointer forward. // We got it, for sure move the top pointer forward.
top_.compare_exchange_strong(local_top, {local_top.stamp + 1, local_top.value + 1}); top_.compare_exchange_strong(expected_top, {expected_top.stamp + 1, expected_top.value + 1});
// Return the stolen task // Return the stolen task
return data_structures::optional<task *>{result}; return data_structures::optional<task *>{result};
} else { } else {
// We did not get it...help forwarding the top pointer anyway. // We did not get it...help forwarding the top pointer anyway.
if (expected_top_stamp == forwarding_stamp) { if (expected_top.stamp == forwarding_stamp) {
// ...move the pointer forward if someone else put a valid trade object in there. // ...move the pointer forward if someone else put a valid trade object in there.
top_.compare_exchange_strong(local_top, {local_top.stamp + 1, local_top.value + 1}); top_.compare_exchange_strong(expected_top, {expected_top.stamp + 1, expected_top.value + 1});
} else { } else {
// ...we failed because the top tag lags behind...try to fix it. // ...we failed because the top tag lags behind...try to fix it.
// This means only updating the tag, as this location can still hold data we need. // This means only updating the tag, as this location can still hold data we need.
top_.compare_exchange_strong(local_top, {forwarding_stamp, local_top.value}); top_.compare_exchange_strong(expected_top, {forwarding_stamp, expected_top.value});
} }
return data_structures::optional<task *>{}; return data_structures::optional<task *>{};
} }
} }
private: private:
trading_deque_entry *entries_; // info on this deque
size_t num_entries_; trading_deque_entry *const entries_;
const size_t num_entries_;
unsigned thread_id_;
unsigned deque_id_; // fields for stealing/interacting
stamped_integer bot_internal_{0, 0};
alignas(base::system_details::CACHE_LINE_SIZE) std::atomic<data_structures::stamped_integer> top_{{0, 0}}; alignas(base::system_details::CACHE_LINE_SIZE) std::atomic<stamped_integer> top_{{0, 0}};
alignas(base::system_details::CACHE_LINE_SIZE) std::atomic<size_t> bot_{0}; alignas(base::system_details::CACHE_LINE_SIZE) std::atomic<size_t> bot_{0};
data_structures::stamped_integer bot_internal_{0, 0};
}; };
template<size_t SIZE> template<size_t SIZE>
......
...@@ -55,7 +55,7 @@ struct alignas(base::system_details::CACHE_LINE_SIZE) task { ...@@ -55,7 +55,7 @@ struct alignas(base::system_details::CACHE_LINE_SIZE) task {
context_switcher::continuation continuation_; context_switcher::continuation continuation_;
// Work-Stealing // Work-Stealing
std::atomic<traded_cas_field> traded_field_{}; std::atomic<traded_cas_field> external_trading_deque_cas_{};
task *resource_stack_next_{}; task *resource_stack_next_{};
std::atomic<data_structures::stamped_integer> resource_stack_root_{{0, 0}}; std::atomic<data_structures::stamped_integer> resource_stack_root_{{0, 0}};
bool clean_; bool clean_;
......
...@@ -38,6 +38,7 @@ class task_manager { ...@@ -38,6 +38,7 @@ class task_manager {
for (size_t i = 0; i < num_tasks_; i++) { for (size_t i = 0; i < num_tasks_; i++) {
this_thread_tasks_[i].thread_id_ = id; this_thread_tasks_[i].thread_id_ = id;
} }
deque_.set_thread_id(id);
} }
task &get_active_task() { task &get_active_task() {
......
...@@ -27,30 +27,33 @@ void task_manager::spawn_child(F &&lambda) { ...@@ -27,30 +27,33 @@ void task_manager::spawn_child(F &&lambda) {
last_task->continuation_ = std::move(cont); last_task->continuation_ = std::move(cont);
spawning_task_manager->active_task_ = this_task; spawning_task_manager->active_task_ = this_task;
traded_cas_field expected_cas_value = spawning_task_manager->deque_.push_bot(last_task); spawning_task_manager->deque_.push_bot(last_task);
traded_cas_field empty_cas;
lambda(); lambda();
auto *syncing_task_manager = &thread_state::get().get_task_manager(); auto *syncing_task_manager = &thread_state::get().get_task_manager();
if (last_task->traded_field_.compare_exchange_strong(expected_cas_value, empty_cas)) { auto pop_result = syncing_task_manager->deque_.pop_bot();
if (pop_result) {
// Fast path, simply continue execution where we left of before spawn. // Fast path, simply continue execution where we left of before spawn.
// This requires no coordination with the resource stack. PLS_ASSERT(*pop_result == last_task,
"Fast path, nothing can have changed until here.");
PLS_ASSERT(spawning_task_manager == syncing_task_manager,
"Fast path, nothing can have changed here.");
PLS_ASSERT(last_task->continuation_.valid(),
"Fast path, no one can have continued working on the last task.");
syncing_task_manager->active_task_ = last_task; syncing_task_manager->active_task_ = last_task;
syncing_task_manager->deque_.popped_bot();
return std::move(last_task->continuation_); return std::move(last_task->continuation_);
} else { } else {
// Slow path, the continuation was stolen. // Slow path, the continuation was stolen.
// First empty our own deque (everything below must have been stolen already).
syncing_task_manager->deque_.empty_deque();
context_switcher::continuation result_cont; context_switcher::continuation result_cont;
if (syncing_task_manager->try_clean_return(result_cont)) { if (syncing_task_manager->try_clean_return(result_cont)) {
// We return back to the main scheduling loop // We return back to the main scheduling loop
return result_cont; PLS_ASSERT(result_cont.valid(), "Must only return valid continuations...");
return std::move(result_cont);
} else { } else {
// We finish up the last task and are the sole owner again // We finish up the last task and are the sole owner again
return result_cont; PLS_ASSERT(result_cont.valid(), "Must only return valid continuations...");
return std::move(result_cont);
} }
} }
}); });
......
...@@ -45,17 +45,17 @@ struct traded_cas_field { ...@@ -45,17 +45,17 @@ struct traded_cas_field {
public: public:
void fill_with_stamp(unsigned long stamp, unsigned long deque_id) { void fill_with_stamp(unsigned long stamp, unsigned long deque_id) {
cas_integer_ = (((stamp << STAMP_SHIFT) & STAMP_BITS) | ((stamp << ID_SHIFT) & ID_BITS) | STAMP_TAG); cas_integer_ = (((stamp << STAMP_SHIFT) & STAMP_BITS) | ((deque_id << ID_SHIFT) & ID_BITS) | STAMP_TAG);
} }
unsigned long get_stamp() { unsigned long get_stamp() {
PLS_ASSERT(is_filled_with_tag(), "Must only read out the tag when the traded field contains one."); PLS_ASSERT(is_filled_with_stamp(), "Must only read out the tag when the traded field contains one.");
return (((unsigned long) cas_integer_) & STAMP_BITS) >> STAMP_SHIFT; return (((unsigned long) cas_integer_) & STAMP_BITS) >> STAMP_SHIFT;
} }
unsigned long get_deque_id() { unsigned long get_deque_id() {
PLS_ASSERT(is_filled_with_tag(), "Must only read out the tag when the traded field contains one."); PLS_ASSERT(is_filled_with_stamp(), "Must only read out the tag when the traded field contains one.");
return (((unsigned long) cas_integer_) & ID_BITS) >> ID_SHIFT; return (((unsigned long) cas_integer_) & ID_BITS) >> ID_SHIFT;
} }
bool is_filled_with_tag() { bool is_filled_with_stamp() {
return (((unsigned long) cas_integer_) & TAG_BITS) == STAMP_TAG; return (((unsigned long) cas_integer_) & TAG_BITS) == STAMP_TAG;
} }
......
...@@ -90,17 +90,16 @@ void scheduler::work_thread_work_section() { ...@@ -90,17 +90,16 @@ void scheduler::work_thread_work_section() {
// Move the traded in resource of this active task over to the stack of resources. // Move the traded in resource of this active task over to the stack of resources.
auto *stolen_task = &my_task_manager.get_active_task(); auto *stolen_task = &my_task_manager.get_active_task();
traded_cas_field stolen_task_cas = stolen_task->traded_field_.load(); traded_cas_field stolen_task_cas = stolen_task->external_trading_deque_cas_.load();
if (stolen_task_cas.is_filled_with_object()) { if (stolen_task_cas.is_filled_with_object()) {
// Push the traded in resource on the resource stack to clear the traded_field for later steals/spawns. // Push the traded in resource on the resource stack to clear the traded_field for later steals/spawns.
auto *exchanged_task = stolen_task_cas.get_trade_object(); auto *exchanged_task = stolen_task_cas.get_trade_object();
my_task_manager.push_resource_on_task(stolen_task, exchanged_task); my_task_manager.push_resource_on_task(stolen_task, exchanged_task);
traded_cas_field empty_field; traded_cas_field empty_field;
traded_cas_field expected_field; traded_cas_field expected_field;
expected_field.fill_with_trade_object(exchanged_task); expected_field.fill_with_trade_object(exchanged_task);
if (stolen_task->traded_field_.compare_exchange_strong(expected_field, empty_field)) { if (stolen_task->external_trading_deque_cas_.compare_exchange_strong(expected_field, empty_field)) {
// All good, nothing more to do // All good, nothing more to do
} else { } else {
// The last other active thread took it as its spare resource... // The last other active thread took it as its spare resource...
......
...@@ -36,8 +36,8 @@ bool task_manager::steal_task(task_manager &stealing_task_manager) { ...@@ -36,8 +36,8 @@ bool task_manager::steal_task(task_manager &stealing_task_manager) {
PLS_ASSERT(stealing_task_manager.active_task_->depth_ == 0, "Must only steal with clean task chain."); PLS_ASSERT(stealing_task_manager.active_task_->depth_ == 0, "Must only steal with clean task chain.");
auto peek = deque_.peek_top(); auto peek = deque_.peek_top();
auto optional_target_task = std::get<0>(peek); auto optional_target_task = peek.top_task_;
auto target_top = std::get<1>(peek); auto target_top = peek.top_pointer_;
if (optional_target_task) { if (optional_target_task) {
PLS_ASSERT(stealing_task_manager.check_task_chain(), "We are stealing, must not have a bad chain here!"); PLS_ASSERT(stealing_task_manager.check_task_chain(), "We are stealing, must not have a bad chain here!");
...@@ -56,7 +56,10 @@ bool task_manager::steal_task(task_manager &stealing_task_manager) { ...@@ -56,7 +56,10 @@ bool task_manager::steal_task(task_manager &stealing_task_manager) {
auto optional_result_task = deque_.pop_top(traded_task, target_top); auto optional_result_task = deque_.pop_top(traded_task, target_top);
if (optional_result_task) { if (optional_result_task) {
PLS_ASSERT(*optional_result_task == target_task, "We must only steal the task that we peeked at!"); PLS_ASSERT(target_task->thread_id_ != traded_task->thread_id_,
"It is impossible to steal an task we already own!");
PLS_ASSERT(*optional_result_task == target_task,
"We must only steal the task that we peeked at!");
// the steal was a success, link the chain so we own the stolen part // the steal was a success, link the chain so we own the stolen part
target_task->next_ = next_own_task; target_task->next_ = next_own_task;
next_own_task->prev_ = target_task; next_own_task->prev_ = target_task;
...@@ -164,16 +167,13 @@ bool task_manager::try_clean_return(context_switcher::continuation &result_cont) ...@@ -164,16 +167,13 @@ bool task_manager::try_clean_return(context_switcher::continuation &result_cont)
task *clean_chain = pop_resource_from_task(last_task); task *clean_chain = pop_resource_from_task(last_task);
if (clean_chain == nullptr) { if (clean_chain == nullptr) {
// double-check if we are really last one or we only have unlucky timing // double-check if we are really last one or we only have unlucky timing
auto cas_field = last_task->traded_field_.load(); auto optional_cas_task = external_trading_deque::get_trade_object(last_task);
if (cas_field.is_filled_with_object()) { if (optional_cas_task) {
traded_cas_field empty_target; clean_chain = *optional_cas_task;
if (last_task->traded_field_.compare_exchange_strong(cas_field, empty_target)) {
clean_chain = cas_field.get_trade_object();
} else { } else {
clean_chain = pop_resource_from_task(last_task); clean_chain = pop_resource_from_task(last_task);
} }
} }
}
if (clean_chain != nullptr) { if (clean_chain != nullptr) {
// We got a clean chain to continue working on. // We got a clean chain to continue working on.
......
#include <catch.hpp> #include <catch.hpp>
#include <atomic> #include <vector>
#include <thread>
#include <mutex>
#include "pls/internal/scheduling/scheduler.h" #include "pls/internal/scheduling/traded_cas_field.h"
#include "pls/internal/scheduling/cont.h" #include "pls/internal/scheduling/task.h"
#include "pls/internal/scheduling/cont_manager.h" #include "pls/internal/scheduling/external_trading_deque.h"
#include "pls/internal/scheduling/scheduler_memory.h"
#include "pls/internal/scheduling/parallel_result.h"
using namespace pls::internal::scheduling; using namespace pls::internal::scheduling;
// TODO: Introduce actual tests once multiple threads work... TEST_CASE("traded cas field bitmaps correctly", "[internal/scheduling/traded_cas_field.h]") {
TEST_CASE("continuation stealing", "[internal/scheduling/cont_manager.h]") { traded_cas_field empty_field;
const int NUM_THREADS = 2; REQUIRE(empty_field.is_empty());
const int NUM_TASKS = 8; REQUIRE(!empty_field.is_filled_with_stamp());
const int MAX_TASK_STACK_SIZE = 8; REQUIRE(!empty_field.is_filled_with_object());
const int NUM_CONTS = 8;
const int MAX_CONT_SIZE = 256; const int stamp = 42;
const int ID = 10;
static_scheduler_memory<NUM_THREADS, traded_cas_field tag_field;
NUM_TASKS, tag_field.fill_with_stamp(stamp, ID);
MAX_TASK_STACK_SIZE, REQUIRE(tag_field.is_filled_with_stamp());
NUM_CONTS, REQUIRE(!tag_field.is_empty());
MAX_CONT_SIZE> static_scheduler_memory; REQUIRE(!tag_field.is_filled_with_object());
REQUIRE(tag_field.get_stamp() == stamp);
scheduler scheduler{static_scheduler_memory, NUM_THREADS}; REQUIRE(tag_field.get_deque_id() == ID);
// Coordinate progress to match OUR order alignas(64) task obj;
std::atomic<int> progress{0}; traded_cas_field obj_field;
obj_field.fill_with_trade_object(&obj);
// Order: REQUIRE(obj_field.is_filled_with_object());
// 0) work on first task on main thread REQUIRE(!obj_field.is_empty());
// 1) second thread stole right task REQUIRE(!obj_field.is_filled_with_stamp());
}
scheduler.perform_work([&]() {
return scheduler::par([&]() { TEST_CASE("external trading deque", "[internal/scheduling/external_trading_deque]") {
while (progress.load() != 1); static_external_trading_deque<16> static_external_trading_deque_1;
return parallel_result<int>{0}; external_trading_deque &deque_1 = static_external_trading_deque_1.get_deque();
}, [&]() { deque_1.set_thread_id(1);
progress.store(1);
return parallel_result<int>{0}; static_external_trading_deque<16> static_external_trading_deque_2;
}).then([&](int, int) { external_trading_deque &deque_2 = static_external_trading_deque_2.get_deque();
deque_2.set_thread_id(2);
return parallel_result<int>{0};
}); std::vector<task> tasks(16);
});
SECTION("basic operations") {
// Must start empty
REQUIRE(!deque_1.pop_bot());
REQUIRE(!deque_2.pop_bot());
// Local push/pop
deque_1.push_bot(&tasks[0]);
REQUIRE(*deque_1.pop_bot() == &tasks[0]);
REQUIRE(!deque_1.pop_bot());
// Local push, external pop
deque_1.push_bot(&tasks[0]);
auto peek = deque_1.peek_top();
REQUIRE(*deque_1.pop_top(&tasks[1], peek.top_pointer_) == &tasks[0]);
REQUIRE(*external_trading_deque::get_trade_object(&tasks[0]) == &tasks[1]);
REQUIRE(!deque_1.pop_top(&tasks[1], peek.top_pointer_));
REQUIRE(!deque_1.pop_bot());
// Keeps push/pop order
deque_1.push_bot(&tasks[0]);
deque_1.push_bot(&tasks[1]);
REQUIRE(*deque_1.pop_bot() == &tasks[1]);
REQUIRE(*deque_1.pop_bot() == &tasks[0]);
REQUIRE(!deque_1.pop_bot());
deque_1.push_bot(&tasks[0]);
deque_1.push_bot(&tasks[1]);
auto peek1 = deque_1.peek_top();
REQUIRE(*deque_1.pop_top(&tasks[2], peek1.top_pointer_) == &tasks[0]);
auto peek2 = deque_1.peek_top();
REQUIRE(*deque_1.pop_top(&tasks[3], peek2.top_pointer_) == &tasks[1]);
}
SECTION("Interwined execution #1") {
// Two top poppers
deque_1.push_bot(&tasks[0]);
auto peek1 = deque_1.peek_top();
auto peek2 = deque_1.peek_top();
REQUIRE(*deque_1.pop_top(&tasks[1], peek1.top_pointer_) == &tasks[0]);
REQUIRE(!deque_1.pop_top(&tasks[2], peek2.top_pointer_));
}
SECTION("Interwined execution #2") {
// Top and bottom access
deque_1.push_bot(&tasks[0]);
auto peek1 = deque_1.peek_top();
REQUIRE(*deque_1.pop_bot() == &tasks[0]);
REQUIRE(!deque_1.pop_top(&tasks[2], peek1.top_pointer_));
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment