Commit c6dd2fc0 by FritzFlorian

WIP: Partly functional version. Stealing and continuation tarding works 'most' of the time.

The main issue seems to still be the fact that we have a lock free protocol where a steal can be pending. We plan to remove this fact next by introducing a protocol that works on a single atomic update.
parent adf05e9a
Pipeline #1340 failed with stages
in 40 seconds
...@@ -8,8 +8,6 @@ using namespace pls::internal::scheduling; ...@@ -8,8 +8,6 @@ using namespace pls::internal::scheduling;
#include <vector> #include <vector>
#include <atomic> #include <atomic>
std::atomic<unsigned long> count;
static constexpr int CUTOFF = 16; static constexpr int CUTOFF = 16;
static constexpr int INPUT_SIZE = 8192; static constexpr int INPUT_SIZE = 8192;
typedef std::vector<std::complex<double>> complex_vector; typedef std::vector<std::complex<double>> complex_vector;
...@@ -44,7 +42,6 @@ void combine(complex_vector::iterator data, int n) { ...@@ -44,7 +42,6 @@ void combine(complex_vector::iterator data, int n) {
void fft_normal(complex_vector::iterator data, int n) { void fft_normal(complex_vector::iterator data, int n) {
if (n < 2) { if (n < 2) {
// count++;
return; return;
} }
...@@ -56,14 +53,15 @@ void fft_normal(complex_vector::iterator data, int n) { ...@@ -56,14 +53,15 @@ void fft_normal(complex_vector::iterator data, int n) {
parallel_result<short> fft(complex_vector::iterator data, int n) { parallel_result<short> fft(complex_vector::iterator data, int n) {
if (n < 2) { if (n < 2) {
return 0; return parallel_result<short>{0};
} }
divide(data, n); divide(data, n);
if (n <= CUTOFF) { if (n <= CUTOFF) {
fft_normal(data, n / 2); fft_normal(data, n / 2);
fft_normal(data + n / 2, n / 2); fft_normal(data + n / 2, n / 2);
return 0; combine(data, n);
return parallel_result<short>{0};
} else { } else {
return scheduler::par([=]() { return scheduler::par([=]() {
return fft(data, n / 2); return fft(data, n / 2);
...@@ -71,7 +69,7 @@ parallel_result<short> fft(complex_vector::iterator data, int n) { ...@@ -71,7 +69,7 @@ parallel_result<short> fft(complex_vector::iterator data, int n) {
return fft(data + n / 2, n / 2); return fft(data + n / 2, n / 2);
}).then([=](int, int) { }).then([=](int, int) {
combine(data, n); combine(data, n);
return 0; return parallel_result<short>{0};
}); });
} }
} }
...@@ -93,13 +91,13 @@ complex_vector prepare_input(int input_size) { ...@@ -93,13 +91,13 @@ complex_vector prepare_input(int input_size) {
} }
static constexpr int NUM_ITERATIONS = 1000; static constexpr int NUM_ITERATIONS = 1000;
constexpr size_t NUM_THREADS = 1; constexpr size_t NUM_THREADS = 2;
constexpr size_t NUM_TASKS = 64; constexpr size_t NUM_TASKS = 128;
constexpr size_t MAX_TASK_STACK_SIZE = 0; constexpr size_t MAX_TASK_STACK_SIZE = 0;
constexpr size_t NUM_CONTS = 64; constexpr size_t NUM_CONTS = 128;
constexpr size_t MAX_CONT_SIZE = 192; constexpr size_t MAX_CONT_SIZE = 512;
int main() { int main() {
complex_vector initial_input = prepare_input(INPUT_SIZE); complex_vector initial_input = prepare_input(INPUT_SIZE);
...@@ -112,7 +110,6 @@ int main() { ...@@ -112,7 +110,6 @@ int main() {
scheduler scheduler{static_scheduler_memory, NUM_THREADS}; scheduler scheduler{static_scheduler_memory, NUM_THREADS};
count.store(0);
auto start = std::chrono::steady_clock::now(); auto start = std::chrono::steady_clock::now();
for (int i = 0; i < NUM_ITERATIONS; i++) { for (int i = 0; i < NUM_ITERATIONS; i++) {
complex_vector input_2(initial_input); complex_vector input_2(initial_input);
...@@ -122,23 +119,20 @@ int main() { ...@@ -122,23 +119,20 @@ int main() {
}, []() { }, []() {
return parallel_result<int>{0}; return parallel_result<int>{0};
}).then([](int, int) { }).then([](int, int) {
return 0; return parallel_result<int>{0};
}); });
}); });
} }
auto end = std::chrono::steady_clock::now(); auto end = std::chrono::steady_clock::now();
std::cout << "Count: " << count.load() << std::endl;
std::cout << "Framework: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() std::cout << "Framework: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()
<< std::endl; << std::endl;
count.store(0);
start = std::chrono::steady_clock::now(); start = std::chrono::steady_clock::now();
for (int i = 0; i < NUM_ITERATIONS; i++) { for (int i = 0; i < NUM_ITERATIONS; i++) {
complex_vector input_1(initial_input); complex_vector input_1(initial_input);
fft_normal(input_1.begin(), INPUT_SIZE); fft_normal(input_1.begin(), INPUT_SIZE);
} }
end = std::chrono::steady_clock::now(); end = std::chrono::steady_clock::now();
std::cout << "Count: " << count.load() << std::endl;
std::cout << "Normal: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() std::cout << "Normal: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()
<< std::endl; << std::endl;
......
...@@ -8,18 +8,15 @@ ...@@ -8,18 +8,15 @@
using namespace pls::internal; using namespace pls::internal;
constexpr size_t NUM_THREADS = 2; constexpr size_t NUM_THREADS = 1;
constexpr size_t NUM_TASKS = 64; constexpr size_t NUM_TASKS = 128;
constexpr size_t MAX_TASK_STACK_SIZE = 0; constexpr size_t MAX_TASK_STACK_SIZE = 0;
constexpr size_t NUM_CONTS = 64; constexpr size_t NUM_CONTS = 128;
constexpr size_t MAX_CONT_SIZE = 256; constexpr size_t MAX_CONT_SIZE = 256;
std::atomic<int> count{0}; int fib_normal(int n) {
scheduling::parallel_result<int> fib(int n) {
base::this_thread::sleep(100);
// std::cout << "Fib(" << n << "): " << count++ << ", " << scheduling::thread_state::get().get_id() << std::endl;
if (n == 0) { if (n == 0) {
return 0; return 0;
} }
...@@ -27,20 +24,11 @@ scheduling::parallel_result<int> fib(int n) { ...@@ -27,20 +24,11 @@ scheduling::parallel_result<int> fib(int n) {
return 1; return 1;
} }
return scheduling::scheduler::par([=]() { int result = fib_normal(n - 1) + fib_normal(n - 2);
return fib(n - 1); return result;
}, [=]() {
return fib(n - 2);
}).then([=](int a, int b) {
scheduling::parallel_result<int> result{a + b};
base::this_thread::sleep(100);
// std::cout << "Done Fib(" << n << "): " << (a + b) << ", " << scheduling::thread_state::get().get_id() << std::endl;
return result;
});
} }
int fib_normal(int n) { scheduling::parallel_result<int> fib(int n) {
// std::cout << "Fib(" << n << "): " << count++ << std::endl;
if (n == 0) { if (n == 0) {
return 0; return 0;
} }
...@@ -48,9 +36,13 @@ int fib_normal(int n) { ...@@ -48,9 +36,13 @@ int fib_normal(int n) {
return 1; return 1;
} }
int result = fib_normal(n - 1) + fib_normal(n - 2); return scheduling::scheduler::par([=]() {
// std::cout << "Done Fib(" << n << "): " << result << std::endl; return fib(n - 1);
return result; }, [=]() {
return fib(n - 2);
}).then([=](int a, int b) {
return scheduling::parallel_result<int>{a + b};
});
} }
int main() { int main() {
...@@ -63,27 +55,26 @@ int main() { ...@@ -63,27 +55,26 @@ int main() {
scheduling::scheduler scheduler{static_scheduler_memory, NUM_THREADS}; scheduling::scheduler scheduler{static_scheduler_memory, NUM_THREADS};
auto start = std::chrono::steady_clock::now(); auto start = std::chrono::steady_clock::now();
// std::cout << "fib = " << fib_normal(10) << std::endl; std::cout << "fib = " << fib_normal(30) << std::endl;
auto end = std::chrono::steady_clock::now(); auto end = std::chrono::steady_clock::now();
std::cout << "Normal: " << std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() std::cout << "Normal: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()
<< std::endl; << std::endl;
start = std::chrono::steady_clock::now(); start = std::chrono::steady_clock::now();
scheduler.perform_work([]() { scheduler.perform_work([]() {
// return scheduling::scheduler::par([]() { return scheduling::scheduler::par([]() {
// return scheduling::parallel_result<int>(0); return scheduling::parallel_result<int>(0);
// }, []() { }, []() {
// return fib(16); return fib(30);
// }).then([](int, int b) { }).then([](int, int b) {
// std::cout << "fib = " << (b) << std::endl; std::cout << "fib = " << b << std::endl;
// return scheduling::parallel_result<int>{0}; return scheduling::parallel_result<int>{0};
// }); });
return fib(10);
}); });
end = std::chrono::steady_clock::now(); end = std::chrono::steady_clock::now();
std::cout << "Framework: " << std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() << std::endl; std::cout << "Framework: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << std::endl;
return 0; return 0;
} }
...@@ -77,7 +77,7 @@ class cont : public base_cont { ...@@ -77,7 +77,7 @@ class cont : public base_cont {
using BASE_RES_TYPE = typename std::remove_cv<typename std::remove_reference<RES_TYPE>::type>::type; using BASE_RES_TYPE = typename std::remove_cv<typename std::remove_reference<RES_TYPE>::type>::type;
static void execute(cont &cont) { static void execute(cont &cont) {
parallel_result<BASE_RES_TYPE> result{cont.function_((*cont.result_1_).value(), (*cont.result_2_).value())}; parallel_result<BASE_RES_TYPE> result{cont.function_((*cont.left_result_).value(), (*cont.right_result_).value())};
if (result.fast_path() && cont.parent_ != nullptr) { if (result.fast_path() && cont.parent_ != nullptr) {
if (cont.is_right_child()) { if (cont.is_right_child()) {
cont.parent_->store_right_result(std::move(result)); cont.parent_->store_right_result(std::move(result));
...@@ -90,7 +90,7 @@ class cont : public base_cont { ...@@ -90,7 +90,7 @@ class cont : public base_cont {
template<typename INNER_TYPE> template<typename INNER_TYPE>
struct result_runner<parallel_result<INNER_TYPE>> { struct result_runner<parallel_result<INNER_TYPE>> {
static void execute(cont &cont) { static void execute(cont &cont) {
auto result = cont.function_((*cont.result_1_).value(), (*cont.result_2_).value()); auto result = cont.function_((*cont.left_result_).value(), (*cont.right_result_).value());
if (result.fast_path() && cont.parent_) { if (result.fast_path() && cont.parent_) {
if (cont.is_right_child()) { if (cont.is_right_child()) {
cont.parent_->store_right_result(std::move(result)); cont.parent_->store_right_result(std::move(result));
...@@ -113,7 +113,7 @@ class cont : public base_cont { ...@@ -113,7 +113,7 @@ class cont : public base_cont {
task_{std::forward<T2ARGS>(task_2_args)..., this} {}; task_{std::forward<T2ARGS>(task_2_args)..., this} {};
void execute() override { void execute() override {
using result_type = decltype(function_((*result_1_).value(), (*result_2_).value())); using result_type = decltype(function_((*left_result_).value(), (*right_result_).value()));
result_runner<result_type>::execute(*this); result_runner<result_type>::execute(*this);
this->get_memory_block()->free_buffer(); this->get_memory_block()->free_buffer();
this->get_memory_block()->reset_state(); this->get_memory_block()->reset_state();
...@@ -124,12 +124,11 @@ class cont : public base_cont { ...@@ -124,12 +124,11 @@ class cont : public base_cont {
task_.execute(); task_.execute();
} }
void *get_right_result_pointer() override {
return &result_1_;
}
void *get_left_result_pointer() override { void *get_left_result_pointer() override {
return &result_2_; return &left_result_;
}
void *get_right_result_pointer() override {
return &right_result_;
} }
T2 *get_task() { T2 *get_task() {
...@@ -143,8 +142,8 @@ class cont : public base_cont { ...@@ -143,8 +142,8 @@ class cont : public base_cont {
// Some fields/actual values stay uninitialized (save time on the fast path if we don not need them). // Some fields/actual values stay uninitialized (save time on the fast path if we don not need them).
// More fields untouched on the fast path is good, but for ease of an implementation we only keep some for now. // More fields untouched on the fast path is good, but for ease of an implementation we only keep some for now.
delayed_init<R1> result_1_; delayed_init<R1> left_result_;
delayed_init<R2> result_2_; delayed_init<R2> right_result_;
}; };
} }
......
...@@ -25,12 +25,12 @@ class cont_manager { ...@@ -25,12 +25,12 @@ class cont_manager {
: max_cont_size_{MAX_CONT_SIZE}, num_conts_{NUM_CONTS} { : max_cont_size_{MAX_CONT_SIZE}, num_conts_{NUM_CONTS} {
// First node is currently active and our local start // First node is currently active and our local start
start_node_ = active_node_ = init_memory_block<MAX_CONT_SIZE>(cont_storage, nullptr, nullptr, 0); active_node_ = init_memory_block<MAX_CONT_SIZE>(cont_storage, nullptr, 0);
// Build up chain after it // Build up chain after it
memory_block *current_node = start_node_; memory_block *current_node = active_node_;
for (size_t i = 1; i < NUM_CONTS; i++) { for (size_t i = 1; i < NUM_CONTS; i++) {
memory_block *next_node = init_memory_block<MAX_CONT_SIZE>(cont_storage, start_node_, current_node, i); memory_block *next_node = init_memory_block<MAX_CONT_SIZE>(cont_storage, current_node, i);
current_node->set_next(next_node); current_node->set_next(next_node);
current_node = next_node; current_node = next_node;
} }
...@@ -38,16 +38,30 @@ class cont_manager { ...@@ -38,16 +38,30 @@ class cont_manager {
// Aquire and release memory blocks... // Aquire and release memory blocks...
memory_block *get_next_memory_block() { memory_block *get_next_memory_block() {
active_node_->set_start(start_node_);
auto result = active_node_; auto result = active_node_;
active_node_ = active_node_->get_next(); active_node_ = active_node_->get_next();
return result; return result;
} }
void return_memory_block() { void return_memory_block() {
active_node_ = active_node_->get_prev(); active_node_ = active_node_->get_prev();
} }
void move_active_node(int depth) {
if (depth < 0) {
for (int i = 0; i < (depth * -1); i++) {
active_node_ = active_node_->get_prev();
}
} else {
for (int i = 0; i < depth; i++) {
active_node_ = active_node_->get_next();
}
}
}
void move_active_node_to_start() {
move_active_node(-1 * active_node_->get_depth());
}
memory_block *get_active_node() {
return active_node_;
}
// Manage the fall through behaviour/slow path behaviour // Manage the fall through behaviour/slow path behaviour
bool falling_through() const { bool falling_through() const {
...@@ -60,44 +74,11 @@ class cont_manager { ...@@ -60,44 +74,11 @@ class cont_manager {
} }
void aquire_memory_chain(memory_block *target_chain) { void aquire_memory_chain(memory_block *target_chain) {
auto *our_next_node = get_node(target_chain->get_depth() + 1); PLS_ASSERT(active_node_->get_depth() == target_chain->get_depth() + 1,
"Can only steal aquire chain parts with correct depth.");
our_next_node->set_prev(target_chain);
target_chain->set_next(our_next_node);
start_node_ = target_chain->get_start();
active_node_ = target_chain;
}
memory_block *get_node(unsigned int depth) {
// TODO: Remove this O(n) factor to avoid the
// T_1/P + (T_lim + D) stealing time bound.
memory_block *current = start_node_;
for (unsigned int i = 0; i < depth; i++) {
current->set_start(start_node_);
current = current->get_next();
}
return current;
}
void check_clean_chain() {
memory_block *current = start_node_;
for (unsigned int i = 0; i < num_conts_; i++) {
bool buffer_used = current->is_buffer_used();
auto state_value = current->get_state().load().value;
if (state_value != memory_block::initialized || buffer_used || current->get_depth() != i) {
PLS_ASSERT(false,
"Must always steal with a clean chain!");
}
current->set_start(start_node_);
current = current->get_next();
}
}
void set_active_depth(unsigned int depth) { active_node_->set_prev(target_chain);
active_node_ = get_node(depth); target_chain->set_next(active_node_);
} }
void execute_fall_through_code() { void execute_fall_through_code() {
...@@ -109,8 +90,8 @@ class cont_manager { ...@@ -109,8 +90,8 @@ class cont_manager {
auto *notified_cont = fall_through_cont_; auto *notified_cont = fall_through_cont_;
bool notifier_is_right_child = fall_through_child_right; bool notifier_is_right_child = fall_through_child_right;
std::cout << "Notifying Cont on core " << my_state.get_id() << " and depth " // std::cout << "Notifying Cont on core " << my_state.get_id() << " and depth "
<< notified_cont->get_memory_block()->get_depth() << std::endl; // << notified_cont->get_memory_block()->get_depth() << std::endl;
fall_through_cont_ = nullptr; fall_through_cont_ = nullptr;
fall_through_ = false; fall_through_ = false;
...@@ -158,17 +139,18 @@ class cont_manager { ...@@ -158,17 +139,18 @@ class cont_manager {
// ... we finished the continuation. // ... we finished the continuation.
// We are now in charge continuing to execute the above continuation chain. // We are now in charge continuing to execute the above continuation chain.
if (get_node(notified_cont->get_memory_block()->get_depth()) != notified_cont->get_memory_block()) { PLS_ASSERT(active_node_->get_prev()->get_depth() == notified_cont->get_memory_block()->get_depth(),
my_state.cont_manager_.check_clean_chain(); "We must hold the system invariant to be in the correct depth.")
if (active_node_->get_prev() != notified_cont->get_memory_block()) {
// We do not own the thing we will execute. // We do not own the thing we will execute.
// Own it by swapping the chain belonging to it in. // Own it by swapping the chain belonging to it in.
aquire_memory_chain(notified_cont->get_memory_block()); aquire_memory_chain(notified_cont->get_memory_block());
std::cout << "Now in charge of memory chain on core " << my_state.get_id() << std::endl; // std::cout << "Now in charge of memory chain on core " << my_state.get_id() << std::endl;
} }
my_state.parent_cont_ = notified_cont->get_parent(); my_state.parent_cont_ = notified_cont->get_parent();
my_state.right_spawn_ = notified_cont->is_right_child(); my_state.right_spawn_ = notified_cont->is_right_child();
active_node_ = notified_cont->get_memory_block(); active_node_ = notified_cont->get_memory_block();
std::cout << "Execute cont on core " << my_state.get_id() << std::endl; // std::cout << "Execute cont on core " << my_state.get_id() << std::endl;
notified_cont->execute(); notified_cont->execute();
if (!falling_through() && notified_cont->get_parent() != nullptr) { if (!falling_through() && notified_cont->get_parent() != nullptr) {
fall_through_and_notify_cont(notified_cont->get_parent(), notified_cont->is_right_child()); fall_through_and_notify_cont(notified_cont->get_parent(), notified_cont->is_right_child());
...@@ -178,14 +160,16 @@ class cont_manager { ...@@ -178,14 +160,16 @@ class cont_manager {
// ... we did not finish the last continuation. // ... we did not finish the last continuation.
// We are no longer in charge of executing the above continuation chain. // We are no longer in charge of executing the above continuation chain.
if (get_node(notified_cont->get_memory_block()->get_depth()) == notified_cont->get_memory_block()) { PLS_ASSERT(active_node_->get_prev()->get_depth() == notified_cont->get_memory_block()->get_depth(),
"We must hold the system invariant to be in the correct depth.")
if (active_node_->get_prev() == notified_cont->get_memory_block()) {
// We own the thing we are not allowed to execute. // We own the thing we are not allowed to execute.
// Get rid of the ownership by using the offered chain. // Get rid of the ownership by using the offered chain.
aquire_memory_chain(target_chain); aquire_memory_chain(target_chain);
std::cout << "No longer in charge of chain above on core " << my_state.get_id() << std::endl; // std::cout << "No longer in charge of chain above on core " << my_state.get_id() << std::endl;
} }
my_state.cont_manager_.check_clean_chain(); move_active_node_to_start();
// We are done here...nothing more to execute // We are done here...nothing more to execute
return; return;
} }
...@@ -194,7 +178,6 @@ class cont_manager { ...@@ -194,7 +178,6 @@ class cont_manager {
private: private:
template<size_t MAX_CONT_SIZE> template<size_t MAX_CONT_SIZE>
static memory_block *init_memory_block(data_structures::aligned_stack &cont_storage, static memory_block *init_memory_block(data_structures::aligned_stack &cont_storage,
memory_block *memory_chain_start,
memory_block *prev, memory_block *prev,
unsigned long depth) { unsigned long depth) {
// Represents one cont_node and its corresponding memory buffer (as one continuous block of memory). // Represents one cont_node and its corresponding memory buffer (as one continuous block of memory).
...@@ -202,7 +185,7 @@ class cont_manager { ...@@ -202,7 +185,7 @@ class cont_manager {
char *memory_block_ptr = cont_storage.push_bytes<memory_block>(); char *memory_block_ptr = cont_storage.push_bytes<memory_block>();
char *memory_block_buffer_ptr = cont_storage.push_bytes(buffer_size); char *memory_block_buffer_ptr = cont_storage.push_bytes(buffer_size);
return new(memory_block_ptr) memory_block(memory_block_buffer_ptr, buffer_size, memory_chain_start, prev, depth); return new(memory_block_ptr) memory_block(memory_block_buffer_ptr, buffer_size, prev, depth);
} }
private: private:
...@@ -212,7 +195,6 @@ class cont_manager { ...@@ -212,7 +195,6 @@ class cont_manager {
/** /**
* Managing the continuation chain. * Managing the continuation chain.
*/ */
memory_block *start_node_;
memory_block *active_node_; memory_block *active_node_;
/** /**
......
...@@ -18,11 +18,9 @@ class memory_block { ...@@ -18,11 +18,9 @@ class memory_block {
public: public:
memory_block(char *memory_buffer, memory_block(char *memory_buffer,
size_t memory_buffer_size, size_t memory_buffer_size,
memory_block *memory_chain_start,
memory_block *prev, memory_block *prev,
unsigned int depth) unsigned int depth)
: memory_chain_start_{memory_chain_start}, : prev_{prev},
prev_{prev},
next_{nullptr}, next_{nullptr},
offered_chain_{nullptr}, offered_chain_{nullptr},
state_{{initialized}}, state_{{initialized}},
...@@ -70,12 +68,6 @@ class memory_block { ...@@ -70,12 +68,6 @@ class memory_block {
void set_next(memory_block *next) { void set_next(memory_block *next) {
next_ = next; next_ = next;
} }
memory_block *get_start() {
return memory_chain_start_;
}
void set_start(memory_block *start) {
memory_chain_start_ = start;
}
enum state { initialized, execute_local, stealing, stolen, invalid }; enum state { initialized, execute_local, stealing, stolen, invalid };
using stamped_state = data_structures::stamped_integer; using stamped_state = data_structures::stamped_integer;
...@@ -107,7 +99,6 @@ class memory_block { ...@@ -107,7 +99,6 @@ class memory_block {
// Linked list property of memory blocks (a complete list represents a threads currently owned memory). // Linked list property of memory blocks (a complete list represents a threads currently owned memory).
// Each block knows its chain start to allow stealing a whole chain in O(1) // Each block knows its chain start to allow stealing a whole chain in O(1)
// without the need to traverse back to the chain start. // without the need to traverse back to the chain start.
memory_block *memory_chain_start_;
memory_block *prev_, *next_; memory_block *prev_, *next_;
// When blocked on this chain element, we need to know what other chain of memory we // When blocked on this chain element, we need to know what other chain of memory we
......
...@@ -76,8 +76,7 @@ struct scheduler::starter { ...@@ -76,8 +76,7 @@ struct scheduler::starter {
current_cont->store_left_result(std::move(result_1)); current_cont->store_left_result(std::move(result_1));
auto old_state = current_cont->get_memory_block()->get_state().load(); auto old_state = current_cont->get_memory_block()->get_state().load();
current_cont->get_memory_block()->get_state().store({old_state.stamp + 1, memory_block::invalid}); current_cont->get_memory_block()->get_state().store({old_state.stamp + 1, memory_block::invalid});
PLS_ASSERT(current_cont->get_memory_block()->get_results_missing().fetch_add(-1) != 1, current_cont->get_memory_block()->get_results_missing().fetch_add(-1);
"We fall through, meaning we 'block' an cont above, thus this can never happen!");
// Unwind stack... // Unwind stack...
return result_type{}; return result_type{};
} }
...@@ -132,7 +131,6 @@ class scheduler::init_function_impl : public init_function { ...@@ -132,7 +131,6 @@ class scheduler::init_function_impl : public init_function {
explicit init_function_impl(F &function) : function_{function} {} explicit init_function_impl(F &function) : function_{function} {}
void run() override { void run() override {
scheduler::par([]() { scheduler::par([]() {
std::cout << "Dummy Strain, " << scheduling::thread_state::get().get_id() << std::endl;
return parallel_result<int>{0}; return parallel_result<int>{0};
}, [=]() { }, [=]() {
return function_(); return function_();
......
...@@ -59,11 +59,17 @@ class task_manager { ...@@ -59,11 +59,17 @@ class task_manager {
auto stolen_task_handle = task_deque_.pop_top(); auto stolen_task_handle = task_deque_.pop_top();
if (stolen_task_handle) { if (stolen_task_handle) {
base_task *stolen_task = (*stolen_task_handle).task_; base_task *stolen_task = (*stolen_task_handle).task_;
std::cout << "Nearly stole on core " << thread_state::get().get_id() << " task with depth " memory_block *stolen_task_memory = (*stolen_task_handle).task_memory_block_;
<< stolen_task->get_cont()->get_memory_block()->get_depth() << std::endl; auto stolen_task_depth = stolen_task_memory->get_depth();
auto &atomic_state = (*stolen_task_handle).task_memory_block_->get_state(); auto &atomic_state = stolen_task_memory->get_state();
auto &atomic_offered_chain = (*stolen_task_handle).task_memory_block_->get_offered_chain(); auto &atomic_offered_chain = stolen_task_memory->get_offered_chain();
auto offered_chain = stealing_cont_manager.get_node((*stolen_task_handle).task_memory_block_->get_depth());
// std::cout << "Nearly stole on core " << thread_state::get().get_id() << " task with depth "
// << stolen_task_depth << std::endl;
// Move our chain forward for stealing...
stealing_cont_manager.move_active_node(stolen_task_depth);
auto offered_chain = stealing_cont_manager.get_active_node();
if (offered_chain == (*stolen_task_handle).task_memory_block_) { if (offered_chain == (*stolen_task_handle).task_memory_block_) {
PLS_ASSERT(false, "How would we offer our own chain? We only offer when stealing!"); PLS_ASSERT(false, "How would we offer our own chain? We only offer when stealing!");
...@@ -71,6 +77,7 @@ class task_manager { ...@@ -71,6 +77,7 @@ class task_manager {
auto last_state = atomic_state.load(); auto last_state = atomic_state.load();
if (last_state.value != memory_block::initialized) { if (last_state.value != memory_block::initialized) {
stealing_cont_manager.move_active_node(-stolen_task_depth);
return nullptr; return nullptr;
} }
...@@ -86,12 +93,14 @@ class task_manager { ...@@ -86,12 +93,14 @@ class task_manager {
last_offered_chain = atomic_offered_chain.load(); last_offered_chain = atomic_offered_chain.load();
last_state = atomic_state.load(); last_state = atomic_state.load();
if (last_state != loop_state) { if (last_state != loop_state) {
stealing_cont_manager.move_active_node(-stolen_task_depth);
return nullptr; return nullptr;
} }
} }
if (atomic_state.compare_exchange_strong(loop_state, {loop_state.stamp + 1, memory_block::stolen})) { if (atomic_state.compare_exchange_strong(loop_state, {loop_state.stamp + 1, memory_block::stolen})) {
std::cout << "Steal!" << std::endl; // std::cout << "Steal!" << std::endl;
stealing_cont_manager.move_active_node(1);
return stolen_task; return stolen_task;
} else { } else {
return nullptr; return nullptr;
......
...@@ -69,7 +69,6 @@ void scheduler::work_thread_work_section() { ...@@ -69,7 +69,6 @@ void scheduler::work_thread_work_section() {
while (my_cont_manager.falling_through()) { while (my_cont_manager.falling_through()) {
my_cont_manager.execute_fall_through_code(); my_cont_manager.execute_fall_through_code();
} }
my_cont_manager.check_clean_chain();
// Steal Routine (will be continuously executed when there are no more fall through's). // Steal Routine (will be continuously executed when there are no more fall through's).
// TODO: move into separate function // TODO: move into separate function
...@@ -83,13 +82,11 @@ void scheduler::work_thread_work_section() { ...@@ -83,13 +82,11 @@ void scheduler::work_thread_work_section() {
auto &target_state = my_state.get_scheduler().thread_state_for(target); auto &target_state = my_state.get_scheduler().thread_state_for(target);
my_cont_manager.check_clean_chain(); PLS_ASSERT(my_cont_manager.get_active_node()->get_depth() == 0, "Only steal with clean chain!");
auto *stolen_task = target_state.get_task_manager().steal_remote_task(my_cont_manager); auto *stolen_task = target_state.get_task_manager().steal_remote_task(my_cont_manager);
if (stolen_task != nullptr) { if (stolen_task != nullptr) {
my_state.parent_cont_ = stolen_task->get_cont(); my_state.parent_cont_ = stolen_task->get_cont();
my_state.right_spawn_ = true; my_state.right_spawn_ = true;
my_cont_manager.set_active_depth(stolen_task->get_cont()->get_memory_block()->get_depth() + 1);
my_cont_manager.check_clean_chain();
stolen_task->execute(); stolen_task->execute();
if (my_cont_manager.falling_through()) { if (my_cont_manager.falling_through()) {
break; break;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment