Commit fba551da by FritzFlorian

Add first test version of high level task stealing.

This means that high level tasks can be stolen and lays the groundwork for implementing different tasks like classic work stealing.
parent e6565ef0
Pipeline #1105 passed with stages
in 2 minutes 9 seconds
...@@ -10,7 +10,8 @@ add_library(pls STATIC ...@@ -10,7 +10,8 @@ add_library(pls STATIC
src/internal/base/barrier.cpp include/pls/internal/base/barrier.h src/internal/base/barrier.cpp include/pls/internal/base/barrier.h
src/internal/scheduling/root_master_task.cpp include/pls/internal/scheduling/root_master_task.h src/internal/scheduling/root_master_task.cpp include/pls/internal/scheduling/root_master_task.h
src/internal/base/aligned_stack.cpp include/pls/internal/base/aligned_stack.h src/internal/base/aligned_stack.cpp include/pls/internal/base/aligned_stack.h
include/pls/internal/base/system_details.h) include/pls/internal/base/system_details.h
src/internal/scheduling/run_on_n_threads_task.cpp include/pls/internal/scheduling/run_on_n_threads_task.h)
# Add everything in `./include` to be in the include path of this project # Add everything in `./include` to be in the include path of this project
target_include_directories(pls target_include_directories(pls
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#define PLS_SPINLOCK_H #define PLS_SPINLOCK_H
#include <atomic> #include <atomic>
#include <iostream>
#include "pls/internal/base/thread.h" #include "pls/internal/base/thread.h"
...@@ -16,6 +17,9 @@ namespace pls { ...@@ -16,6 +17,9 @@ namespace pls {
public: public:
spin_lock(): flag_{ATOMIC_FLAG_INIT}, yield_at_tries_{1024} {}; spin_lock(): flag_{ATOMIC_FLAG_INIT}, yield_at_tries_{1024} {};
spin_lock(const spin_lock& other): flag_{ATOMIC_FLAG_INIT}, yield_at_tries_{other.yield_at_tries_} {
std::cout << "Spinlock Moved!" << std::endl;
}
void lock(); void lock();
void unlock(); void unlock();
......
...@@ -13,33 +13,24 @@ namespace pls { ...@@ -13,33 +13,24 @@ namespace pls {
int depth_; int depth_;
int unique_id_; int unique_id_;
abstract_task* child_task_; abstract_task* child_task_;
base::spin_lock spin_lock_;
public: public:
explicit abstract_task(int depth, int unique_id): abstract_task(int depth, int unique_id):
depth_{depth}, depth_{depth},
unique_id_{unique_id}, unique_id_{unique_id},
child_task_{nullptr}, child_task_{nullptr} {}
spin_lock_{} {};
virtual void execute() = 0; virtual void execute() = 0;
const base::spin_lock& spin_lock() { return spin_lock_; } void set_child(abstract_task* child_task) { child_task_ = child_task; }
abstract_task* child() { return child_task_; }
void set_depth(int depth) { depth_ = depth; }
int depth() { return depth_; }
protected: protected:
virtual bool my_stealing(abstract_task *other_task) = 0; virtual bool internal_stealing(abstract_task* other_task) = 0;
virtual bool split_task() = 0;
bool steal_work() {
// get scheduler bool steal_work();
// select victim
// try steal
// |-- see if same depth is available
// |-- see if equals depth + id
// |-- try user steal if matches (will return itself if it could steal)
// |-- try internal steal if deeper tasks are available
// |-- if internal steal worked, execute it
// return if the user steal was a success
return false;
};
}; };
} }
} }
......
...@@ -36,7 +36,11 @@ namespace pls { ...@@ -36,7 +36,11 @@ namespace pls {
} }
} }
bool my_stealing(abstract_task* /*other_task*/) override { bool internal_stealing(abstract_task* /*other_task*/) override {
return false;
}
bool split_task() override {
return false; return false;
} }
}; };
......
...@@ -22,7 +22,11 @@ namespace pls { ...@@ -22,7 +22,11 @@ namespace pls {
} while (!master_task_->finished()); } while (!master_task_->finished());
} }
bool my_stealing(abstract_task* /*other_task*/) override { bool internal_stealing(abstract_task* /*other_task*/) override {
return false;
}
bool split_task() override {
return false; return false;
} }
}; };
......
#ifndef PLS_RUN_ON_N_THREADS_TASK_H
#define PLS_RUN_ON_N_THREADS_TASK_H
#include <iostream>
#include <mutex>

#include "pls/internal/base/spin_lock.h"
#include "pls/internal/base/thread.h"

#include "abstract_task.h"
#include "thread_state.h"
#include "scheduler.h"
namespace pls {
namespace internal {
namespace scheduling {
template<typename Function>
class run_on_n_threads_task : public abstract_task {
template<typename F>
friend class run_on_n_threads_task_worker;
Function function_;
// Improvement: Remove lock and replace by atomic variable (performance)
int counter;
base::spin_lock counter_lock_;
int decrement_counter() {
std::lock_guard<base::spin_lock> lock{counter_lock_};
counter--;
return counter;
}
int get_counter() {
std::lock_guard<base::spin_lock> lock{counter_lock_};
return counter;
}
public:
run_on_n_threads_task(Function function, int num_threads):
abstract_task{PLS_UNIQUE_ID, 0},
function_{function},
counter{num_threads - 1} {}
void execute() override {
// Execute our function ONCE
function_();
// Steal until we are finished (other threads executed)
do {
steal_work();
} while (get_counter() > 0);
std::cout << "Finished Master!" << std::endl;
}
bool internal_stealing(abstract_task* /*other_task*/) override {
return false;
}
bool split_task() override;
};
// Companion task spawned on stealing threads by run_on_n_threads_task.
// Executes the user function once if an execution slot is still available,
// otherwise abandons without running it.
template<typename Function>
class run_on_n_threads_task_worker : public abstract_task {
  Function function_;
  run_on_n_threads_task<Function>* root_;
 public:
  run_on_n_threads_task_worker(Function function, run_on_n_threads_task<Function>* root):
      abstract_task{PLS_UNIQUE_ID, 0},
      function_{function},
      root_{root} {}

  void execute() override {
    // Claim a slot; a negative result means all executions were already taken.
    const bool got_slot = root_->decrement_counter() >= 0;
    if (!got_slot) {
      std::cout << "Abandoned Worker!" << std::endl;
      return;
    }
    function_();
    std::cout << "Finished Worker!" << std::endl;
  }

  bool internal_stealing(abstract_task* /*other_task*/) override {
    return false;
  }

  bool split_task() override {
    return false;
  }
};
// Called on a stealing thread: if executions are still pending, run a
// worker task there (at the same depth as this task) and report success.
template<typename Function>
bool run_on_n_threads_task<Function>::split_task() {
  // All execution slots already claimed — nothing left to hand out.
  if (get_counter() <= 0) {
    return false;
  }

  auto my_scheduler = base::this_thread::state<thread_state>()->scheduler_;
  auto worker_task = run_on_n_threads_task_worker<Function>{function_, this};
  my_scheduler->execute_task(worker_task, depth());
  return true;
}
// Convenience factory for a run_on_n_threads_task.
// Returned as a prvalue so C++17 guaranteed copy elision constructs the
// task directly at the call site (no copy of the embedded spin lock).
template<typename Function>
run_on_n_threads_task<Function> create_run_on_n_threads_task(Function function, int num_threads) {
return run_on_n_threads_task<Function>{function, num_threads};
}
}
}
}
#endif //PLS_RUN_ON_N_THREADS_TASK_H
...@@ -8,8 +8,8 @@ ...@@ -8,8 +8,8 @@
#include "pls/internal/base/aligned_stack.h" #include "pls/internal/base/aligned_stack.h"
#include "pls/internal/base/thread.h" #include "pls/internal/base/thread.h"
#include "pls/internal/base/barrier.h" #include "pls/internal/base/barrier.h"
#include "pls/internal/scheduling/thread_state.h"
#include "thread_state.h"
#include "root_master_task.h" #include "root_master_task.h"
#include "root_worker_task.h" #include "root_worker_task.h"
...@@ -54,7 +54,7 @@ namespace pls { ...@@ -54,7 +54,7 @@ namespace pls {
class scheduler { class scheduler {
friend void worker_routine(); friend void worker_routine();
unsigned int num_threads_; const unsigned int num_threads_;
scheduler_memory* memory_; scheduler_memory* memory_;
base::barrier sync_barrier_; base::barrier sync_barrier_;
...@@ -69,38 +69,48 @@ namespace pls { ...@@ -69,38 +69,48 @@ namespace pls {
root_worker_task<Function> worker{&master}; root_worker_task<Function> worker{&master};
// Push root task on stacks // Push root task on stacks
memory_->thread_state_for(0)->root_task_ = memory_->task_stack_for(0)->push(master); memory_->thread_state_for(0)->root_task_ = &master;
memory_->thread_state_for(0)->current_task_ = &master;
for (unsigned int i = 1; i < num_threads_; i++) { for (unsigned int i = 1; i < num_threads_; i++) {
memory_->thread_state_for(i)->root_task_ = memory_->task_stack_for(i)->push(worker); memory_->thread_state_for(i)->root_task_ = &worker;
memory_->thread_state_for(i)->current_task_ = &worker;
} }
// Perform and wait for work // Perform and wait for work
sync_barrier_.wait(); // Trigger threads to wake up sync_barrier_.wait(); // Trigger threads to wake up
sync_barrier_.wait(); // Wait for threads to finish sync_barrier_.wait(); // Wait for threads to finish
// Remove root task from stacks
memory_->task_stack_for(0)->pop<typeof(master)>();
for (unsigned int i = 1; i < num_threads_; i++) {
memory_->task_stack_for(i)->pop<typeof(worker)>();
}
} }
// TODO: See if we should place this differently (only for performance reasons) // TODO: See if we should place this differently (only for performance reasons)
template<typename Task> template<typename Task>
void execute_task(Task& task) { static void execute_task(Task task, int depth=-1) {
static_assert(std::is_base_of<abstract_task, Task>::value, "Only pass abstract_task subclasses!"); static_assert(std::is_base_of<abstract_task, Task>::value, "Only pass abstract_task subclasses!");
auto my_state = base::this_thread::state<thread_state>(); auto my_state = base::this_thread::state<thread_state>();
auto task_stack = my_state->task_stack_; auto current_task = my_state->current_task_;
// Init Task
{
std::lock_guard<base::spin_lock> lock{my_state->lock_};
task.set_depth(depth >= 0 ? depth : current_task->depth() + 1);
my_state->current_task_ = &task;
current_task->set_child(&task);
}
// Run Task
task.execute();
// TODO: Assert if 'top level' task even have to go somewhere or if // Teardown state back to before the task was executed
// we can simply keep the on the call stack. {
auto my_task = task_stack->push(task); std::lock_guard<base::spin_lock> lock{my_state->lock_};
my_task.execute(); current_task->set_child(nullptr);
task_stack->pop<Task>(); my_state->current_task_ = current_task;
}
} }
void terminate(bool wait_for_workers=true); void terminate(bool wait_for_workers=true);
unsigned int num_threads() const { return num_threads_; }
thread_state* thread_state_for(size_t id) { return memory_->thread_state_for(id); }
}; };
} }
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#define PLS_THREAD_STATE_H #define PLS_THREAD_STATE_H
#include "abstract_task.h" #include "abstract_task.h"
#include "pls/internal/base/aligned_stack.h"
namespace pls { namespace pls {
namespace internal { namespace internal {
...@@ -11,15 +12,43 @@ namespace pls { ...@@ -11,15 +12,43 @@ namespace pls {
class scheduler; class scheduler;
struct thread_state { struct thread_state {
thread_state(): scheduler_{nullptr}, root_task_{nullptr}, task_stack_{nullptr} {};
explicit thread_state(scheduler* scheduler, base::aligned_stack* task_stack):
scheduler_{scheduler},
root_task_{nullptr},
task_stack_{task_stack} {}
scheduler* scheduler_; scheduler* scheduler_;
abstract_task* root_task_; abstract_task* root_task_;
abstract_task* current_task_;
base::aligned_stack* task_stack_; base::aligned_stack* task_stack_;
unsigned int id_;
base::spin_lock lock_;
thread_state():
scheduler_{nullptr},
root_task_{nullptr},
current_task_{nullptr},
task_stack_{nullptr},
id_{0} {};
thread_state(scheduler* scheduler, base::aligned_stack* task_stack, unsigned int id):
scheduler_{scheduler},
root_task_{nullptr},
current_task_{nullptr},
task_stack_{task_stack},
id_{id} {}
thread_state(const thread_state& other):
scheduler_{other.scheduler_},
root_task_{other.root_task_},
current_task_{other.current_task_},
task_stack_{other.task_stack_},
id_{other.id_} {}
thread_state& operator=(const thread_state& other) {
scheduler_ = other.scheduler_;
root_task_ = other.root_task_;
current_task_ = other.current_task_;
task_stack_ = other.task_stack_;
id_ = other.id_;
return *this;
}
}; };
} }
} }
......
#include "pls/internal/scheduling/thread_state.h"
#include "pls/internal/scheduling/abstract_task.h" #include "pls/internal/scheduling/abstract_task.h"
#include "pls/internal/scheduling/scheduler.h"
namespace pls { namespace pls {
namespace internal { namespace internal {
namespace scheduling { namespace scheduling {
// Attempts to steal work from the other scheduler threads.
// Walks every other thread (starting with our right neighbour), digs down
// its task chain to our own depth, and tries an internal steal from a task
// of matching type/depth; failing that, asks deeper tasks to split
// ('top level task stealing').
// Returns true only if an *internal* steal succeeded; a successful top
// level split still returns false so the caller keeps its own control flow.
bool abstract_task::steal_work() {
  auto my_state = base::this_thread::state<thread_state>();
  auto my_scheduler = my_state->scheduler_;
  // Hoisted out of the loop; also keep all index math in size_t.
  const size_t num_threads = my_scheduler->num_threads();
  const size_t my_id = my_state->id_;

  for (size_t i = 1; i < num_threads; i++) {
    size_t target = (my_id + i) % num_threads;
    auto target_state = my_scheduler->thread_state_for(target);
    // Hold the victim's lock while we inspect its task chain.
    std::lock_guard<base::spin_lock> lock{target_state->lock_};

    // Dig down to our level
    abstract_task* current_task = target_state->root_task_;
    while (current_task != nullptr && current_task->depth_ < depth_) {
      current_task = current_task->child_task_;
    }

    if (current_task != nullptr) {
      // See if it equals our type and depth of task
      if (current_task->unique_id_ == unique_id_ &&
          current_task->depth_ == depth_) {
        if (internal_stealing(current_task)) {
          // internal steal was a success, hand it back to the internal scheduler
          return true;
        }

        // No success, we need to steal work from a deeper level using 'top level task stealing'
        current_task = current_task->child_task_;
      }
    }

    // Execute 'top level task steal' if possible
    // (only try deeper tasks to keep depth restricted stealing)
    while (current_task != nullptr) {
      if (current_task->split_task()) {
        // internal steal was no success (we did a top level task steal)
        return false;
      }
      current_task = current_task->child_task_;
    }
  }

  // internal steal was no success
  return false;
}
} }
} }
} }
#include "pls/internal/scheduling/run_on_n_threads_task.h"

namespace pls {
namespace internal {
namespace scheduling {
// Intentionally empty: run_on_n_threads_task is a header-only template.
// This translation unit exists so the header is compiled on its own and
// to provide a home for future non-template definitions.
}
}
}
...@@ -13,7 +13,7 @@ namespace pls { ...@@ -13,7 +13,7 @@ namespace pls {
} }
for (unsigned int i = 0; i < num_threads; i++) { for (unsigned int i = 0; i < num_threads; i++) {
*memory_->thread_state_for(i) = thread_state{this, memory_->task_stack_for(i)}; *memory_->thread_state_for(i) = thread_state{this, memory_->task_stack_for(i), i};
*memory_->thread_for(i) = base::start_thread(&worker_routine, memory_->thread_state_for(i)); *memory_->thread_for(i) = base::start_thread(&worker_routine, memory_->thread_state_for(i));
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment