diff --git a/lib/pls/include/pls/algorithms/for_each_impl.h b/lib/pls/include/pls/algorithms/for_each_impl.h
index 9d2bb8a..722cb69 100644
--- a/lib/pls/include/pls/algorithms/for_each_impl.h
+++ b/lib/pls/include/pls/algorithms/for_each_impl.h
@@ -32,13 +32,12 @@ static void for_each(const RandomIt first,
                          function,
                          min_elements);
     });
-    scheduler::spawn([first, middle_index, last, &function, min_elements] {
+    scheduler::spawn_and_sync([first, middle_index, last, &function, min_elements] {
       internal::for_each(first + middle_index,
                          last,
                          function,
                          min_elements);
     });
 
-    scheduler::sync();
   }
 }
diff --git a/lib/pls/include/pls/algorithms/reduce_impl.h b/lib/pls/include/pls/algorithms/reduce_impl.h
index 5022c9c..6155367 100644
--- a/lib/pls/include/pls/algorithms/reduce_impl.h
+++ b/lib/pls/include/pls/algorithms/reduce_impl.h
@@ -37,14 +37,13 @@ static Element reduce(const RandomIt first,
                                reducer,
                                min_elements);
     });
-    scheduler::spawn([first, middle_index, last, neutral, &reducer, min_elements, &right] {
+    scheduler::spawn_and_sync([first, middle_index, last, neutral, &reducer, min_elements, &right] {
       right = internal::reduce(first + middle_index,
                                last,
                                neutral,
                                reducer,
                                min_elements);
     });
-    scheduler::sync();
     return reducer(left, right);
   }
 }
diff --git a/lib/pls/include/pls/internal/scheduling/scheduler.h b/lib/pls/include/pls/internal/scheduling/scheduler.h
index 542d42c..f6219c5 100644
--- a/lib/pls/include/pls/internal/scheduling/scheduler.h
+++ b/lib/pls/include/pls/internal/scheduling/scheduler.h
@@ -86,7 +86,11 @@ class scheduler {
 #ifdef PLS_SERIAL_ELUSION
     lambda();
 #else
-    spawn_internal(std::forward<Function>(lambda));
+    if (thread_state::is_scheduler_active()) {
+      spawn_internal(std::forward<Function>(lambda));
+    } else {
+      lambda();
+    }
 #endif
   }
 
@@ -97,7 +101,29 @@ class scheduler {
 #ifdef PLS_SERIAL_ELUSION
     return;
 #else
-    sync_internal();
+    if (thread_state::is_scheduler_active()) {
+      sync_internal();
+    } else {
+      return;
+    }
+#endif
+  }
+
+  /**
+   * Equivalent to calling spawn(lambda); sync();
+   *
+   * Faster than the direct notation, as stealing the continuation right before the sync is not useful and only adds overhead.
+   */
+  template<typename Function>
+  static void spawn_and_sync(Function &&lambda) {
+#ifdef PLS_SERIAL_ELUSION
+    lambda();
+#else
+    if (thread_state::is_scheduler_active()) {
+      spawn_and_sync_internal(std::forward<Function>(lambda));
+    } else {
+      lambda();
+    }
 #endif
   }
 
@@ -110,7 +136,11 @@ class scheduler {
 #ifdef PLS_SERIAL_ELUSION
     lambda();
 #else
-    serial_internal(std::forward<Function>(lambda));
+    if (thread_state::is_scheduler_active()) {
+      serial_internal(std::forward<Function>(lambda));
+    } else {
+      lambda();
+    }
 #endif
   }
 
@@ -138,6 +168,8 @@ class scheduler {
  private:
   template<typename Function>
   static void spawn_internal(Function &&lambda);
+  template<typename Function>
+  static void spawn_and_sync_internal(Function &&lambda);
   static void sync_internal();
   template<typename Function>
   static void serial_internal(Function &&lambda);
diff --git a/lib/pls/include/pls/internal/scheduling/scheduler_impl.h b/lib/pls/include/pls/internal/scheduling/scheduler_impl.h
index a53b70e..3e1077b 100644
--- a/lib/pls/include/pls/internal/scheduling/scheduler_impl.h
+++ b/lib/pls/include/pls/internal/scheduling/scheduler_impl.h
@@ -33,7 +33,7 @@ scheduler::scheduler(unsigned int num_threads,
     stack_allocator_{std::make_shared(std::forward(stack_allocator))},
     serial_stack_size_{serial_stack_size}
 #if PLS_PROFILING_ENABLED
-    , profiler_{num_threads}
+, profiler_{num_threads}
 #endif
 {
 
@@ -290,25 +290,119 @@ void scheduler::spawn_internal(Function &&lambda) {
 }
 
 template<typename Function>
-void scheduler::serial_internal(Function &&lambda) {
-  if (thread_state::is_scheduler_active()) {
-    thread_state &spawning_state = thread_state::get();
-    base_task *active_task = spawning_state.get_active_task();
+void scheduler::spawn_and_sync_internal(Function &&lambda) {
+  thread_state &spawning_state = thread_state::get();
 
-    if (active_task->is_serial_section_) {
-      lambda();
+  base_task *last_task = spawning_state.get_active_task();
+  base_task *spawned_task = last_task->next_;
+
+  // We are at the end of our allocated tasks, go over to serial execution.
+  if (spawned_task == nullptr) {
+    scheduler::serial_internal(std::forward<Function>(lambda));
+    return;
+  }
+  // We are in a serial section.
+  // For now we simply stay serial. Later versions could add nested parallel sections.
+  if (last_task->is_serial_section_) {
+    lambda();
+    return;
+  }
+
+  // Carry over the resources allocated by parents
+  auto *attached_resources = last_task->attached_resources_.load(std::memory_order_relaxed);
+  spawned_task->attached_resources_.store(attached_resources, std::memory_order_relaxed);
+
+#if PLS_PROFILING_ENABLED
+  spawning_state.get_scheduler().profiler_.task_prepare_stack_measure(spawning_state.get_thread_id(),
+                                                                      spawned_task->stack_memory_,
+                                                                      spawned_task->stack_size_);
+  auto *child_dag_node = spawning_state.get_scheduler().profiler_.task_spawn_child(spawning_state.get_thread_id(),
+                                                                                   last_task->profiling_node_);
+  spawned_task->profiling_node_ = child_dag_node;
+#endif
+
+  auto continuation = spawned_task->run_as_task([last_task, spawned_task, lambda, &spawning_state](auto cont) {
+    // allow stealing threads to continue the last task.
+    last_task->continuation_ = std::move(cont);
+
+    // we are now executing the new task, allow others to steal the last task continuation.
+    spawned_task->is_synchronized_ = true;
+    spawning_state.set_active_task(spawned_task);
+
+    // execute the lambda itself, which could lead to a different thread returning.
+#if PLS_PROFILING_ENABLED
+    spawning_state.get_scheduler().profiler_.task_finish_stack_measure(spawning_state.get_thread_id(),
+                                                                       last_task->stack_memory_,
+                                                                       last_task->stack_size_,
+                                                                       last_task->profiling_node_);
+    spawning_state.get_scheduler().profiler_.task_stop_running(spawning_state.get_thread_id(),
+                                                               last_task->profiling_node_);
+    auto *next_dag_node =
+        spawning_state.get_scheduler().profiler_.task_sync(spawning_state.get_thread_id(), last_task->profiling_node_);
+    last_task->profiling_node_ = next_dag_node;
+    spawning_state.get_scheduler().profiler_.task_start_running(spawning_state.get_thread_id(),
+                                                                spawned_task->profiling_node_);
+#endif
+    lambda();
+
+    thread_state &syncing_state = thread_state::get();
+    PLS_ASSERT(syncing_state.get_active_task() == spawned_task,
+               "Task manager must always point its active task onto whats executing.");
+
+    if (&syncing_state == &spawning_state && last_task->is_synchronized_) {
+      // Fast path, simply continue execution where we left off before the spawn.
+      PLS_ASSERT(last_task->continuation_.valid(),
+                 "Fast path, no one can have continued working on the last task.");
+
+      syncing_state.set_active_task(last_task);
+#if PLS_PROFILING_ENABLED
+      syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
+                                                                        spawned_task->stack_memory_,
+                                                                        spawned_task->stack_size_,
+                                                                        spawned_task->profiling_node_);
+      syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
+                                                                spawned_task->profiling_node_);
+#endif
+      return std::move(last_task->continuation_);
     } else {
-      active_task->is_serial_section_ = true;
-      context_switcher::enter_context(spawning_state.get_serial_call_stack(),
-                                      spawning_state.get_serial_call_stack_size(),
-                                      [&](auto cont) {
-                                        lambda();
-                                        return cont;
-                                      });
-      active_task->is_serial_section_ = false;
+      // Slow path, the last task was stolen. This path is common to sync() events.
+#if PLS_PROFILING_ENABLED
+      syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
+                                                                        spawned_task->stack_memory_,
+                                                                        spawned_task->stack_size_,
+                                                                        spawned_task->profiling_node_);
+      syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
+                                                                spawned_task->profiling_node_);
+#endif
+      auto continuation = slow_return(syncing_state, false);
+      return continuation;
     }
-  } else {
+  });
+
+#if PLS_PROFILING_ENABLED
+  thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
+                                                                   last_task->profiling_node_);
+#endif
+
+  PLS_ASSERT(!continuation.valid(), "We must not jump to a not-published (synced) spawn.");
+}
+
+template<typename Function>
+void scheduler::serial_internal(Function &&lambda) {
+  thread_state &spawning_state = thread_state::get();
+  base_task *active_task = spawning_state.get_active_task();
+
+  if (active_task->is_serial_section_) {
     lambda();
+  } else {
+    active_task->is_serial_section_ = true;
+    context_switcher::enter_context(spawning_state.get_serial_call_stack(),
+                                    spawning_state.get_serial_call_stack_size(),
+                                    [&](auto cont) {
+                                      lambda();
+                                      return cont;
+                                    });
+    active_task->is_serial_section_ = false;
   }
 }
diff --git a/lib/pls/include/pls/pls.h b/lib/pls/include/pls/pls.h
index 15c1ebf..9755229 100644
--- a/lib/pls/include/pls/pls.h
+++ b/lib/pls/include/pls/pls.h
@@ -6,6 +6,7 @@
 #include "pls/algorithms/invoke.h"
 #include "pls/algorithms/for_each.h"
 #include "pls/algorithms/reduce.h"
+#include "pls/algorithms/loop_partition_strategy.h"
 
 #include "pls/internal/scheduling/scheduler.h"
 #include "pls/internal/scheduling/strain_local_resource.h"
@@ -21,6 +22,10 @@ template<typename Function>
 static void spawn(Function &&function) {
   scheduler::spawn(std::forward<Function>(function));
 }
+template<typename Function>
+static void spawn_and_sync(Function &&function) {
+  scheduler::spawn_and_sync(std::forward<Function>(function));
+}
 static void sync() {
   scheduler::sync();
 }
@@ -42,6 +47,9 @@ using algorithm::invoke;
 using algorithm::for_each;
 using algorithm::for_each_range;
 using algorithm::reduce;
+
+using algorithm::dynamic_strategy;
+using algorithm::fixed_strategy;
 
 }
 #endif
diff --git a/lib/pls/src/internal/scheduling/scheduler.cpp b/lib/pls/src/internal/scheduling/scheduler.cpp
index cb95032..f2f5986 100644
--- a/lib/pls/src/internal/scheduling/scheduler.cpp
+++ b/lib/pls/src/internal/scheduling/scheduler.cpp
@@ -155,50 +155,45 @@ void scheduler::work_thread_work_section() {
 }
 
 void scheduler::sync_internal() {
-  if (thread_state::is_scheduler_active()) {
-    thread_state &syncing_state = thread_state::get();
+  thread_state &syncing_state = thread_state::get();
 
-    base_task *active_task = syncing_state.get_active_task();
-    base_task *spawned_task = active_task->next_;
+  base_task *active_task = syncing_state.get_active_task();
+  base_task *spawned_task = active_task->next_;
 
 #if PLS_PROFILING_ENABLED
-    syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
-                                                                      active_task->stack_memory_,
-                                                                      active_task->stack_size_,
-                                                                      active_task->profiling_node_);
-    syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
-                                                              active_task->profiling_node_);
-    auto *next_dag_node =
-        syncing_state.get_scheduler().profiler_.task_sync(syncing_state.get_thread_id(), active_task->profiling_node_);
-    active_task->profiling_node_ = next_dag_node;
+  syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
+                                                                    active_task->stack_memory_,
+                                                                    active_task->stack_size_,
+                                                                    active_task->profiling_node_);
+  syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
+                                                            active_task->profiling_node_);
+  auto *next_dag_node =
+      syncing_state.get_scheduler().profiler_.task_sync(syncing_state.get_thread_id(), active_task->profiling_node_);
+  active_task->profiling_node_ = next_dag_node;
 #endif
 
-    if (active_task->is_synchronized_) {
+  if (active_task->is_synchronized_) {
 #if PLS_PROFILING_ENABLED
-      thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
-                                                                       thread_state::get().get_active_task()->profiling_node_);
+    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
+                                                                     thread_state::get().get_active_task()->profiling_node_);
 #endif
-      return; // We are already the sole owner of last_task
-    } else {
-      auto continuation =
-          spawned_task->run_as_task([active_task, spawned_task, &syncing_state](context_switcher::continuation cont) {
-            active_task->continuation_ = std::move(cont);
-            syncing_state.set_active_task(spawned_task);
-            return slow_return(syncing_state, true);
-          });
-
-      PLS_ASSERT(!continuation.valid(),
-                 "We only return to a sync point, never jump to it directly."
-                 "This must therefore never return an unfinished fiber/continuation.");
+    return; // We are already the sole owner of last_task
+  } else {
+    auto continuation =
+        spawned_task->run_as_task([active_task, spawned_task, &syncing_state](context_switcher::continuation cont) {
+          active_task->continuation_ = std::move(cont);
+          syncing_state.set_active_task(spawned_task);
+          return slow_return(syncing_state, true);
+        });
+
+    PLS_ASSERT(!continuation.valid(),
+               "We only return to a sync point, never jump to it directly."
+               "This must therefore never return an unfinished fiber/continuation.");
 #if PLS_PROFILING_ENABLED
-      thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
-                                                                       thread_state::get().get_active_task()->profiling_node_);
+    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
+                                                                     thread_state::get().get_active_task()->profiling_node_);
 #endif
-      return; // We cleanly synced to the last one finishing work on last_task
-    }
-  } else {
-    // Scheduler not active
-    return;
+    return; // We cleanly synced to the last one finishing work on last_task
   }
 }
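
Usage note (illustrative, not part of the patch): call sites that previously ended a fork-join block with spawn(...) followed by sync() can now hand the last child to spawn_and_sync(...), exactly as the for_each and reduce changes above do. Below is a minimal sketch against the public pls API; the sum() helper and its cutoff are hypothetical, and it assumes a scheduler sized for the recursion depth. Per the is_scheduler_active() fallbacks added in this patch, the same code simply runs serially when no scheduler work section is active.

#include "pls/pls.h"

// Hypothetical divide-and-conquer sum using the new spawn_and_sync API.
static long sum(const long *data, long size) {
  if (size <= 1024) {  // small chunks run serially
    long total = 0;
    for (long i = 0; i < size; i++) total += data[i];
    return total;
  }
  long left, right;
  pls::spawn([&] { left = sum(data, size / 2); });
  // Last child: spawn_and_sync replaces spawn(...); sync(); and avoids
  // publishing a continuation that would be synced immediately anyway.
  pls::spawn_and_sync([&] { right = sum(data + size / 2, size - size / 2); });
  return left + right;
}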