Commit 10ca31dc by FritzFlorian

First working version of pure fork-join based scheduler.

parent 374153ce
Pipeline #1244 failed with stages
in 29 seconds
...@@ -19,15 +19,14 @@ int count_child_nodes(uts::node &node) { ...@@ -19,15 +19,14 @@ int count_child_nodes(uts::node &node) {
return child_count; return child_count;
} }
auto current_task = pls::task::current();
std::vector<int> results(children.size()); std::vector<int> results(children.size());
for (size_t i = 0; i < children.size(); i++) { for (size_t i = 0; i < children.size(); i++) {
size_t index = i; size_t index = i;
auto lambda = [&, index] { results[index] = count_child_nodes(children[index]); }; auto lambda = [&, index] { results[index] = count_child_nodes(children[index]); };
pls::lambda_task_by_value<typeof(lambda)> sub_task(lambda); pls::lambda_task_by_value<typeof(lambda)> sub_task(lambda);
current_task->spawn_child(sub_task); pls::scheduler::spawn_child(sub_task);
} }
current_task->wait_for_all(); pls::scheduler::wait_for_all();
for (auto result : results) { for (auto result : results) {
child_count += result; child_count += result;
} }
...@@ -36,43 +35,41 @@ int count_child_nodes(uts::node &node) { ...@@ -36,43 +35,41 @@ int count_child_nodes(uts::node &node) {
} }
int unbalanced_tree_search(int seed, int root_children, double q, int normal_children) { int unbalanced_tree_search(int seed, int root_children, double q, int normal_children) {
static auto id = pls::unique_id::create(42);
int result; int result;
auto lambda = [&] { auto lambda = [&] {
uts::node root(seed, root_children, q, normal_children); uts::node root(seed, root_children, q, normal_children);
result = count_child_nodes(root); result = count_child_nodes(root);
}; };
pls::lambda_task_by_reference<typeof(lambda)> task(lambda);
pls::lambda_task_by_reference<typeof(lambda)> sub_task(lambda); pls::lambda_task_by_reference<typeof(lambda)> sub_task(lambda);
pls::task root_task{&sub_task, id}; pls::scheduler::spawn_child(sub_task);
pls::scheduler::execute_task(root_task); pls::scheduler::wait_for_all();
return result; return result;
} }
//
//int main() {
// PROFILE_ENABLE
// pls::internal::helpers::run_mini_benchmark([&] {
// unbalanced_tree_search(SEED, ROOT_CHILDREN, Q, NORMAL_CHILDREN);
// }, 8, 4000);
//
// PROFILE_SAVE("test_profile.prof")
//}
int main() { int main() {
PROFILE_ENABLE PROFILE_ENABLE
pls::malloc_scheduler_memory my_scheduler_memory{8, 2u << 18}; pls::internal::helpers::run_mini_benchmark([&] {
pls::scheduler scheduler{&my_scheduler_memory, 8}; unbalanced_tree_search(SEED, ROOT_CHILDREN, Q, NORMAL_CHILDREN);
}, 8, 2000);
scheduler.perform_work([&] {
PROFILE_MAIN_THREAD
for (int i = 0; i < 50; i++) {
PROFILE_WORK_BLOCK("Top Level")
int result = unbalanced_tree_search(SEED, ROOT_CHILDREN, Q, NORMAL_CHILDREN);
std::cout << result << std::endl;
}
});
PROFILE_SAVE("test_profile.prof") PROFILE_SAVE("test_profile.prof")
} }
//int main() {
// PROFILE_ENABLE
// pls::malloc_scheduler_memory my_scheduler_memory{8, 2u << 18};
// pls::scheduler scheduler{&my_scheduler_memory, 8};
//
// scheduler.perform_work([&] {
// PROFILE_MAIN_THREAD
// for (int i = 0; i < 50; i++) {
// PROFILE_WORK_BLOCK("Top Level")
// int result = unbalanced_tree_search(SEED, ROOT_CHILDREN, Q, NORMAL_CHILDREN);
// std::cout << result << std::endl;
// }
// });
//
// PROFILE_SAVE("test_profile.prof")
//}
...@@ -91,7 +91,7 @@ int main() { ...@@ -91,7 +91,7 @@ int main() {
PROFILE_MAIN_THREAD PROFILE_MAIN_THREAD
// Call looks just the same, only requirement is // Call looks just the same, only requirement is
// the enclosure in the perform_work lambda. // the enclosure in the perform_work lambda.
for (int i = 0; i < 1000; i++) { for (int i = 0; i < 10; i++) {
PROFILE_WORK_BLOCK("Top Level FFT") PROFILE_WORK_BLOCK("Top Level FFT")
complex_vector input = initial_input; complex_vector input = initial_input;
fft(input.begin(), input.size()); fft(input.begin(), input.size());
......
#ifndef PLS_LAMBDA_TASK_H_
#define PLS_LAMBDA_TASK_H_
#include "pls/internal/scheduling/task.h"
namespace pls {
namespace internal {
namespace scheduling {
/**
 * Task that runs a callable which it only refers to.
 *
 * Nothing is copied: the task keeps a const reference to the callable handed
 * to the constructor. The callable therefore has to stay alive for as long as
 * the task can still be executed, and it must be invocable without arguments
 * through a const reference.
 *
 * @tparam Function Type of the wrapped callable.
 */
template<typename Function>
class lambda_task_by_reference : public task {
 public:
  /**
   * @param function Callable to run; only a reference to it is stored.
   */
  explicit lambda_task_by_reference(const Function &function) : task{}, callable_{function} {}

 protected:
  // Runs the referenced callable.
  void execute_internal() override {
    callable_();
  }

 private:
  const Function &callable_;
};
/**
 * Task that runs its own copy of a callable.
 *
 * The callable passed to the constructor is copied into the task, so the
 * original object is not referenced after construction. The copy is stored as
 * a const member, which means the callable must be invocable without
 * arguments on a const object.
 *
 * @tparam Function Type of the wrapped callable; must be copy constructible.
 */
template<typename Function>
class lambda_task_by_value : public task {
 public:
  /**
   * @param function Callable to copy into the task.
   */
  explicit lambda_task_by_value(const Function &function) : task{}, callable_{function} {}

 protected:
  // Runs the stored copy of the callable.
  void execute_internal() override {
    callable_();
  }

 private:
  const Function callable_;
};
}
}
}
#endif //PLS_LAMBDA_TASK_H_
...@@ -22,12 +22,23 @@ namespace scheduling { ...@@ -22,12 +22,23 @@ namespace scheduling {
using scheduler_thread = base::thread<decltype(&worker_routine), thread_state>; using scheduler_thread = base::thread<decltype(&worker_routine), thread_state>;
/**
* The scheduler is the central part of the dispatching-framework.
* It manages a pool of worker threads (creates them, puts them to sleep/wakes them up, destroys them)
* and allows parallel sections to be executed.
*
* It works in close relation with the 'task' class for scheduling.
*/
class scheduler { class scheduler {
friend class task; friend class task;
const unsigned int num_threads_; const unsigned int num_threads_;
scheduler_memory *memory_; scheduler_memory *memory_;
base::barrier sync_barrier_; base::barrier sync_barrier_;
task *main_thread_root_task_;
bool work_section_done_;
bool terminated_; bool terminated_;
public: public:
/** /**
...@@ -85,6 +96,9 @@ class scheduler { ...@@ -85,6 +96,9 @@ class scheduler {
task *get_local_task(); task *get_local_task();
task *steal_task(); task *steal_task();
bool try_execute_local();
bool try_execute_stolen();
}; };
} }
......
...@@ -2,35 +2,30 @@ ...@@ -2,35 +2,30 @@
#ifndef PLS_SCHEDULER_IMPL_H #ifndef PLS_SCHEDULER_IMPL_H
#define PLS_SCHEDULER_IMPL_H #define PLS_SCHEDULER_IMPL_H
#include "pls/internal/scheduling/lambda_task.h"
namespace pls { namespace pls {
namespace internal { namespace internal {
namespace scheduling { namespace scheduling {
// TODO: generally look into the performance implications of using many thread_state::get() calls
template<typename Function> template<typename Function>
void scheduler::perform_work(Function work_section) { void scheduler::perform_work(Function work_section) {
PROFILE_WORK_BLOCK("scheduler::perform_work") PROFILE_WORK_BLOCK("scheduler::perform_work")
// root_task<Function> master{work_section};
// // if (execute_main_thread) {
// // Push root task on stacks // work_section();
// auto new_master = memory_->task_stack_for(0)->push(master);
// memory_->thread_state_for(0)->root_task_ = new_master;
// memory_->thread_state_for(0)->current_task_ = new_master;
// for (unsigned int i = 1; i < num_threads_; i++) {
// root_worker_task<Function> worker{new_master};
// auto new_worker = memory_->task_stack_for(0)->push(worker);
// memory_->thread_state_for(i)->root_task_ = new_worker;
// memory_->thread_state_for(i)->current_task_ = new_worker;
// }
// //
// // Perform and wait for work
// sync_barrier_.wait(); // Trigger threads to wake up // sync_barrier_.wait(); // Trigger threads to wake up
// sync_barrier_.wait(); // Wait for threads to finish // sync_barrier_.wait(); // Wait for threads to finish
// // } else {
// // Clean up stack lambda_task_by_reference<Function> root_task{work_section};
// memory_->task_stack_for(0)->pop<typeof(master)>(); main_thread_root_task_ = &root_task;
// for (unsigned int i = 1; i < num_threads_; i++) { work_section_done_ = false;
// root_worker_task<Function> worker{new_master};
// memory_->task_stack_for(0)->pop<typeof(worker)>(); sync_barrier_.wait(); // Trigger threads to wake up
sync_barrier_.wait(); // Wait for threads to finish
// } // }
} }
...@@ -39,12 +34,6 @@ void scheduler::spawn_child(T &sub_task) { ...@@ -39,12 +34,6 @@ void scheduler::spawn_child(T &sub_task) {
thread_state::get()->current_task_->spawn_child(sub_task); thread_state::get()->current_task_->spawn_child(sub_task);
} }
void scheduler::wait_for_all() {
thread_state::get()->current_task_->wait_for_all();
}
thread_state *scheduler::thread_state_for(size_t id) { return memory_->thread_state_for(id); }
} }
} }
} }
......
...@@ -39,14 +39,12 @@ class task { ...@@ -39,14 +39,12 @@ class task {
private: private:
void execute(); void execute();
bool try_execute_local();
bool try_execute_stolen();
}; };
template<typename T> template<typename T>
void task::spawn_child(T &sub_task) { void task::spawn_child(T &sub_task) {
PROFILE_FORK_JOIN_STEALING("spawn_child") PROFILE_FORK_JOIN_STEALING("spawn_child")
static_assert(std::is_base_of<T, task>::value, "Only pass task subclasses!"); static_assert(std::is_base_of<task, T>::value, "Only pass task subclasses!");
// Keep our refcount up to date // Keep our refcount up to date
ref_count_++; ref_count_++;
......
...@@ -19,7 +19,6 @@ class task; ...@@ -19,7 +19,6 @@ class task;
struct thread_state { struct thread_state {
alignas(base::system_details::CACHE_LINE_SIZE) scheduler *scheduler_; alignas(base::system_details::CACHE_LINE_SIZE) scheduler *scheduler_;
alignas(base::system_details::CACHE_LINE_SIZE) task *root_task_;
alignas(base::system_details::CACHE_LINE_SIZE) task *current_task_; alignas(base::system_details::CACHE_LINE_SIZE) task *current_task_;
alignas(base::system_details::CACHE_LINE_SIZE) data_structures::aligned_stack *task_stack_; alignas(base::system_details::CACHE_LINE_SIZE) data_structures::aligned_stack *task_stack_;
alignas(base::system_details::CACHE_LINE_SIZE) data_structures::work_stealing_deque<task> deque_; alignas(base::system_details::CACHE_LINE_SIZE) data_structures::work_stealing_deque<task> deque_;
...@@ -28,7 +27,6 @@ struct thread_state { ...@@ -28,7 +27,6 @@ struct thread_state {
thread_state() : thread_state() :
scheduler_{nullptr}, scheduler_{nullptr},
root_task_{nullptr},
current_task_{nullptr}, current_task_{nullptr},
task_stack_{nullptr}, task_stack_{nullptr},
deque_{task_stack_}, deque_{task_stack_},
...@@ -37,7 +35,6 @@ struct thread_state { ...@@ -37,7 +35,6 @@ struct thread_state {
thread_state(scheduler *scheduler, data_structures::aligned_stack *task_stack, unsigned int id) : thread_state(scheduler *scheduler, data_structures::aligned_stack *task_stack, unsigned int id) :
scheduler_{scheduler}, scheduler_{scheduler},
root_task_{nullptr},
current_task_{nullptr}, current_task_{nullptr},
task_stack_{task_stack}, task_stack_{task_stack},
deque_{task_stack_}, deque_{task_stack_},
......
...@@ -30,19 +30,37 @@ scheduler::~scheduler() { ...@@ -30,19 +30,37 @@ scheduler::~scheduler() {
} }
void scheduler::worker_routine() { void scheduler::worker_routine() {
auto my_state = base::this_thread::state<thread_state>(); auto my_state = thread_state::get();
auto scheduler = my_state->scheduler_;
while (true) { while (true) {
my_state->scheduler_->sync_barrier_.wait(); // Wait to be triggered
if (my_state->scheduler_->terminated_) { scheduler->sync_barrier_.wait();
// Check for shutdown
if (scheduler->terminated_) {
return; return;
} }
// The root task must only return when all work is done, // Execute work
// because of this a simple call is enough to ensure the if (my_state->id_ == 0) {
// fork-join-section is done (logically joined back into our main thread). // Main Thread
my_state->root_task_->execute(); auto root_task = scheduler->main_thread_root_task_;
root_task->parent_ = nullptr;
root_task->deque_state_ = my_state->deque_.save_state();
root_task->execute();
scheduler->work_section_done_ = true;
} else {
// Worker Threads
while (!scheduler->work_section_done_) {
if (!scheduler->try_execute_local()) {
scheduler->try_execute_stolen();
}
}
}
// Sync back with main thread
my_state->scheduler_->sync_barrier_.wait(); my_state->scheduler_->sync_barrier_.wait();
} }
} }
...@@ -100,6 +118,33 @@ task *scheduler::steal_task() { ...@@ -100,6 +118,33 @@ task *scheduler::steal_task() {
return nullptr; return nullptr;
} }
bool scheduler::try_execute_local() {
task *local_task = get_local_task();
if (local_task != nullptr) {
local_task->execute();
return true;
} else {
return false;
}
}
bool scheduler::try_execute_stolen() {
task *stolen_task = steal_task();
if (stolen_task != nullptr) {
stolen_task->deque_state_ = thread_state::get()->deque_.save_state();
stolen_task->execute();
return true;
}
return false;
}
void scheduler::wait_for_all() {
thread_state::get()->current_task_->wait_for_all();
}
/**
 * Looks up the thread_state for the given id in the scheduler's memory.
 *
 * @param id Index passed through unchanged to memory_->thread_state_for().
 * @return Whatever memory_->thread_state_for(id) returns.
 */
thread_state *scheduler::thread_state_for(size_t id) {
  return memory_->thread_state_for(id);
}
} }
} }
} }
...@@ -36,31 +36,12 @@ void task::execute() { ...@@ -36,31 +36,12 @@ void task::execute() {
} }
} }
bool task::try_execute_local() {
task *local_task = thread_state::get()->scheduler_->get_local_task();
if (local_task != nullptr) {
local_task->execute();
return true;
} else {
return false;
}
}
bool task::try_execute_stolen() {
task *stolen_task = thread_state::get()->scheduler_->steal_task();
if (stolen_task != nullptr) {
stolen_task->deque_state_ = thread_state::get()->deque_.save_state();
stolen_task->execute();
return true;
}
return false;
}
void task::wait_for_all() { void task::wait_for_all() {
auto scheduler = thread_state::get()->scheduler_;
while (ref_count_ > 0) { while (ref_count_ > 0) {
if (!try_execute_local()) { if (!scheduler->try_execute_local()) {
try_execute_stolen(); scheduler->try_execute_stolen();
} }
} }
thread_state::get()->deque_.release_memory_until(deque_state_); thread_state::get()->deque_.release_memory_until(deque_state_);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment