Commit 7d090b3c by FritzFlorian

Add support for a spawn-and-sync statement that avoids wasting steals.

parent 1d1b5185
Pipeline #1518 passed with stages in 4 minutes 36 seconds
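
The new call fuses the final spawn of a fork-join block with the sync that follows it, so the spawning thread never publishes a continuation that it would immediately reclaim itself. As a caller-side illustration only, here is a minimal sketch (not part of this commit; the fib() function, the "pls/pls.h" include path, and the top-level pls namespace are assumptions based on the headers touched below):

#include "pls/pls.h"

// Hypothetical fork-join example. The last child is spawned with
// spawn_and_sync(), replacing the former spawn(...); sync(); pair.
// Run it from within a task on an active scheduler for real parallelism;
// with no scheduler active, the new guards fall back to inline execution.
long fib(long n) {
  if (n < 2) return n;

  long left = 0, right = 0;
  pls::spawn([n, &left] { left = fib(n - 1); });              // stealable by other workers
  pls::spawn_and_sync([n, &right] { right = fib(n - 2); });   // fused spawn + sync
  return left + right;
}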
@@ -32,13 +32,12 @@ static void for_each(const RandomIt first,
                        function,
                        min_elements);
   });
-  scheduler::spawn([first, middle_index, last, &function, min_elements] {
+  scheduler::spawn_and_sync([first, middle_index, last, &function, min_elements] {
     internal::for_each(first + middle_index,
                        last,
                        function,
                        min_elements);
   });
-  scheduler::sync();
  }
 }
......
@@ -37,14 +37,13 @@ static Element reduce(const RandomIt first,
                                                          reducer,
                                                          min_elements);
   });
-  scheduler::spawn([first, middle_index, last, neutral, &reducer, min_elements, &right] {
+  scheduler::spawn_and_sync([first, middle_index, last, neutral, &reducer, min_elements, &right] {
     right = internal::reduce<RandomIt, Function, Element>(first + middle_index,
                                                           last,
                                                           neutral,
                                                           reducer,
                                                           min_elements);
   });
-  scheduler::sync();
   return reducer(left, right);
  }
 }
......
@@ -86,7 +86,11 @@ class scheduler {
 #ifdef PLS_SERIAL_ELUSION
     lambda();
 #else
-    spawn_internal(std::forward<Function>(lambda));
+    if (thread_state::is_scheduler_active()) {
+      spawn_internal(std::forward<Function>(lambda));
+    } else {
+      lambda();
+    }
 #endif
   }
@@ -97,7 +101,29 @@ class scheduler {
 #ifdef PLS_SERIAL_ELUSION
     return;
 #else
-    sync_internal();
+    if (thread_state::is_scheduler_active()) {
+      sync_internal();
+    } else {
+      return;
+    }
+#endif
+  }
+
+  /**
+   * Equivalent to calling spawn(lambda); sync();
+   *
+   * Faster than the direct notation, as stealing the continuation right before the sync is not useful and only adds overhead.
+   */
+  template<typename Function>
+  static void spawn_and_sync(Function &&lambda) {
+#ifdef PLS_SERIAL_ELUSION
+    lambda();
+#else
+    if (thread_state::is_scheduler_active()) {
+      spawn_and_sync_internal(std::forward<Function>(lambda));
+    } else {
+      lambda();
+    }
 #endif
   }
@@ -110,7 +136,11 @@ class scheduler {
 #ifdef PLS_SERIAL_ELUSION
     lambda();
 #else
-    serial_internal(std::forward<Function>(lambda));
+    if (thread_state::is_scheduler_active()) {
+      serial_internal(std::forward<Function>(lambda));
+    } else {
+      lambda();
+    }
 #endif
   }
@@ -138,6 +168,8 @@ class scheduler {
  private:
   template<typename Function>
   static void spawn_internal(Function &&lambda);
+  template<typename Function>
+  static void spawn_and_sync_internal(Function &&lambda);
   static void sync_internal();
   template<typename Function>
   static void serial_internal(Function &&lambda);
......
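
A practical consequence of the is_scheduler_active() guards added above: the public spawn(), sync(), spawn_and_sync() and serial() entry points now degrade to plain inline execution when no scheduler is active on the calling thread. A minimal sketch of that behavior (hypothetical example; the "pls/pls.h" include path and the pls namespace are assumptions based on the pls.h hunks further down):

#include "pls/pls.h"

int main() {
  int value = 0;

  // No scheduler has been started on this thread, so each call below should
  // simply run its lambda inline, and sync() should be a no-op.
  pls::spawn([&value] { value += 1; });
  pls::spawn_and_sync([&value] { value += 1; });
  pls::sync();

  return value == 2 ? 0 : 1;
}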
@@ -33,7 +33,7 @@ scheduler::scheduler(unsigned int num_threads,
       stack_allocator_{std::make_shared<ALLOC>(std::forward<ALLOC>(stack_allocator))},
       serial_stack_size_{serial_stack_size}
 #if PLS_PROFILING_ENABLED
     , profiler_{num_threads}
 #endif
 {
@@ -290,25 +290,119 @@ void scheduler::spawn_internal(Function &&lambda) {
 }
 
 template<typename Function>
-void scheduler::serial_internal(Function &&lambda) {
-  if (thread_state::is_scheduler_active()) {
-    thread_state &spawning_state = thread_state::get();
-    base_task *active_task = spawning_state.get_active_task();
-
-    if (active_task->is_serial_section_) {
-      lambda();
-    } else {
-      active_task->is_serial_section_ = true;
-      context_switcher::enter_context(spawning_state.get_serial_call_stack(),
-                                      spawning_state.get_serial_call_stack_size(),
-                                      [&](auto cont) {
-                                        lambda();
-                                        return cont;
-                                      });
-      active_task->is_serial_section_ = false;
-    }
-  } else {
-    lambda();
-  }
-}
+void scheduler::spawn_and_sync_internal(Function &&lambda) {
+  thread_state &spawning_state = thread_state::get();
+
+  base_task *last_task = spawning_state.get_active_task();
+  base_task *spawned_task = last_task->next_;
+
+  // We are at the end of our allocated tasks, go over to serial execution.
+  if (spawned_task == nullptr) {
+    scheduler::serial_internal(std::forward<Function>(lambda));
+    return;
+  }
+  // We are in a serial section.
+  // For now we simply stay serial. Later versions could add nested parallel sections.
+  if (last_task->is_serial_section_) {
+    lambda();
+    return;
+  }
+
+  // Carry over the resources allocated by parents.
+  auto *attached_resources = last_task->attached_resources_.load(std::memory_order_relaxed);
+  spawned_task->attached_resources_.store(attached_resources, std::memory_order_relaxed);
+
+#if PLS_PROFILING_ENABLED
+  spawning_state.get_scheduler().profiler_.task_prepare_stack_measure(spawning_state.get_thread_id(),
+                                                                      spawned_task->stack_memory_,
+                                                                      spawned_task->stack_size_);
+  auto *child_dag_node = spawning_state.get_scheduler().profiler_.task_spawn_child(spawning_state.get_thread_id(),
+                                                                                   last_task->profiling_node_);
+  spawned_task->profiling_node_ = child_dag_node;
+#endif
+
+  auto continuation = spawned_task->run_as_task([last_task, spawned_task, lambda, &spawning_state](auto cont) {
+    // Allow stealing threads to continue the last task.
+    last_task->continuation_ = std::move(cont);
+
+    // We are now executing the new task, allow others to steal the last task continuation.
+    spawned_task->is_synchronized_ = true;
+    spawning_state.set_active_task(spawned_task);
+
+    // Execute the lambda itself, which could lead to a different thread returning.
+#if PLS_PROFILING_ENABLED
+    spawning_state.get_scheduler().profiler_.task_finish_stack_measure(spawning_state.get_thread_id(),
+                                                                       last_task->stack_memory_,
+                                                                       last_task->stack_size_,
+                                                                       last_task->profiling_node_);
+    spawning_state.get_scheduler().profiler_.task_stop_running(spawning_state.get_thread_id(),
+                                                               last_task->profiling_node_);
+    auto *next_dag_node =
+        spawning_state.get_scheduler().profiler_.task_sync(spawning_state.get_thread_id(), last_task->profiling_node_);
+    last_task->profiling_node_ = next_dag_node;
+    spawning_state.get_scheduler().profiler_.task_start_running(spawning_state.get_thread_id(),
+                                                                spawned_task->profiling_node_);
+#endif
+    lambda();
+
+    thread_state &syncing_state = thread_state::get();
+    PLS_ASSERT(syncing_state.get_active_task() == spawned_task,
+               "Task manager must always point its active task onto what is executing.");
+
+    if (&syncing_state == &spawning_state && last_task->is_synchronized_) {
+      // Fast path, simply continue execution where we left off before the spawn.
+      PLS_ASSERT(last_task->continuation_.valid(),
+                 "Fast path, no one can have continued working on the last task.");
+
+      syncing_state.set_active_task(last_task);
+#if PLS_PROFILING_ENABLED
+      syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
+                                                                        spawned_task->stack_memory_,
+                                                                        spawned_task->stack_size_,
+                                                                        spawned_task->profiling_node_);
+      syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
+                                                                spawned_task->profiling_node_);
+#endif
+      return std::move(last_task->continuation_);
+    } else {
+      // Slow path, the last task was stolen. This path is common to sync() events.
+#if PLS_PROFILING_ENABLED
+      syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
+                                                                        spawned_task->stack_memory_,
+                                                                        spawned_task->stack_size_,
+                                                                        spawned_task->profiling_node_);
+      syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
+                                                                spawned_task->profiling_node_);
+#endif
+      auto continuation = slow_return(syncing_state, false);
+      return continuation;
+    }
+  });
+
+#if PLS_PROFILING_ENABLED
+  thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
+                                                                   last_task->profiling_node_);
+#endif
+  PLS_ASSERT(!continuation.valid(), "We must not jump to a not-published (synced) spawn.");
+}
+
+template<typename Function>
+void scheduler::serial_internal(Function &&lambda) {
+  thread_state &spawning_state = thread_state::get();
+  base_task *active_task = spawning_state.get_active_task();
+
+  if (active_task->is_serial_section_) {
+    lambda();
+  } else {
+    active_task->is_serial_section_ = true;
+    context_switcher::enter_context(spawning_state.get_serial_call_stack(),
                                    spawning_state.get_serial_call_stack_size(),
+                                    [&](auto cont) {
+                                      lambda();
+                                      return cont;
+                                    });
+    active_task->is_serial_section_ = false;
+  }
+}
......
@@ -6,6 +6,7 @@
 #include "pls/algorithms/invoke.h"
 #include "pls/algorithms/for_each.h"
 #include "pls/algorithms/reduce.h"
+#include "pls/algorithms/loop_partition_strategy.h"
 #include "pls/internal/scheduling/scheduler.h"
 #include "pls/internal/scheduling/strain_local_resource.h"
@@ -21,6 +22,10 @@ template<typename Function>
 static void spawn(Function &&function) {
   scheduler::spawn(std::forward<Function>(function));
 }
+template<typename Function>
+static void spawn_and_sync(Function &&function) {
+  scheduler::spawn_and_sync(std::forward<Function>(function));
+}
 static void sync() {
   scheduler::sync();
 }
@@ -42,6 +47,9 @@ using algorithm::invoke;
 using algorithm::for_each;
 using algorithm::for_each_range;
 using algorithm::reduce;
+
+using algorithm::dynamic_strategy;
+using algorithm::fixed_strategy;
 }
 #endif
@@ -155,50 +155,45 @@ void scheduler::work_thread_work_section() {
 }
 
 void scheduler::sync_internal() {
-  if (thread_state::is_scheduler_active()) {
-    thread_state &syncing_state = thread_state::get();
+  thread_state &syncing_state = thread_state::get();
   base_task *active_task = syncing_state.get_active_task();
   base_task *spawned_task = active_task->next_;
 
 #if PLS_PROFILING_ENABLED
   syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                     active_task->stack_memory_,
                                                                     active_task->stack_size_,
                                                                     active_task->profiling_node_);
   syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                             active_task->profiling_node_);
   auto *next_dag_node =
       syncing_state.get_scheduler().profiler_.task_sync(syncing_state.get_thread_id(), active_task->profiling_node_);
   active_task->profiling_node_ = next_dag_node;
 #endif
 
   if (active_task->is_synchronized_) {
 #if PLS_PROFILING_ENABLED
     thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                      thread_state::get().get_active_task()->profiling_node_);
 #endif
     return; // We are already the sole owner of last_task
   } else {
     auto continuation =
         spawned_task->run_as_task([active_task, spawned_task, &syncing_state](context_switcher::continuation cont) {
           active_task->continuation_ = std::move(cont);
           syncing_state.set_active_task(spawned_task);
           return slow_return(syncing_state, true);
         });
 
     PLS_ASSERT(!continuation.valid(),
                "We only return to a sync point, never jump to it directly."
                "This must therefore never return an unfinished fiber/continuation.");
 
 #if PLS_PROFILING_ENABLED
     thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                      thread_state::get().get_active_task()->profiling_node_);
 #endif
     return; // We cleanly synced to the last one finishing work on last_task
   }
-  } else {
-    // Scheduler not active
-    return;
-  }
 }
......