Review notes (text before the first "diff --git" header is skipped by patch tools).

Purpose of the patch: for_each / for_each_range gain an ExecutionStrategy argument whose
calculate_min_elements(n) supplies the sequential cut-off that used to be the hard-coded
constant 1. Two strategies are added (dynamic_strategy, fixed_strategy); scan_impl.h passes
fixed_strategy{1} to keep its previous behaviour.

Changes made to the added lines during review:
  * internal::for_each: ranges of at most one element always take the sequential path.
    With a cut-off below 1 a one-element range was "split" at index 0, which re-spawned
    the very same range without end.
  * dynamic_strategy::calculate_min_elements: result is clamped to >= 1 and a zero divisor
    (tasks_per_thread == 0) is no longer evaluated.
  * for_each.h: the strategy overload of for_each_range is declared by value, matching its
    definition. The by-reference declaration introduced a second, never-defined overload
    that made calls with an lvalue strategy ambiguous.
  * internal::for_each takes the functor by const reference again (no copy per recursion level).
  * Removed stray ';' after constructor bodies; 'unsigned const int' -> 'const unsigned int'.

NOTE(review): line breaks, leading whitespace and all template argument lists were missing
from the reviewed text and have been reconstructed; confirm them against the working tree
before applying. The post-image blob ids on the "index" lines of for_each.h and
for_each_impl.h predate these review changes.

diff --git a/lib/pls/include/pls/algorithms/for_each.h b/lib/pls/include/pls/algorithms/for_each.h
index 9b38427..0cc11b1 100644
--- a/lib/pls/include/pls/algorithms/for_each.h
+++ b/lib/pls/include/pls/algorithms/for_each.h
@@ -5,9 +5,21 @@
 namespace pls {
 namespace algorithm {
 
+class fixed_strategy;
+class dynamic_strategy;
+
+template<typename Function, typename ExecutionStrategy>
+void for_each_range(unsigned long first,
+                    unsigned long last,
+                    const Function &function,
+                    ExecutionStrategy execution_strategy);
+
 template<typename Function>
 void for_each_range(unsigned long first, unsigned long last, const Function &function);
 
+template<typename RandomIt, typename Function, typename ExecutionStrategy>
+void for_each(RandomIt first, RandomIt last, const Function &function, ExecutionStrategy execution_strategy);
+
 template<typename RandomIt, typename Function>
 void for_each(RandomIt first, RandomIt last, const Function &function);
 
diff --git a/lib/pls/include/pls/algorithms/for_each_impl.h b/lib/pls/include/pls/algorithms/for_each_impl.h
index 68c827e..058ae9f 100644
--- a/lib/pls/include/pls/algorithms/for_each_impl.h
+++ b/lib/pls/include/pls/algorithms/for_each_impl.h
@@ -4,6 +4,7 @@
 
 #include "pls/internal/scheduling/task.h"
 #include "pls/internal/scheduling/scheduler.h"
+#include "pls/internal/scheduling/thread_state.h"
 #include "pls/internal/helpers/unique_id.h"
 #include "pls/internal/helpers/range.h"
 
@@ -13,11 +14,15 @@ namespace algorithm {
 namespace internal {
 
+// Recursively halves [first, last), running each half as a child task, until a range holds at
+// most min_elements elements (or cannot be halved any further); such a range is processed by
+// a plain loop on the calling task.
 template<typename RandomIt, typename Function>
-void for_each(RandomIt first, RandomIt last, const Function &function) {
+void for_each(const RandomIt first, const RandomIt last, const Function &function, const long min_elements) {
   using namespace ::pls::internal::scheduling;
 
-  constexpr long min_elements = 1; // TODO: tune this value/allow for execution strategies
-  long num_elements = std::distance(first, last);
-  if (num_elements <= min_elements) {
+  const long num_elements = std::distance(first, last);
+  // Ranges of at most one element cannot be halved any further, so they always take the
+  // sequential path; otherwise a cut-off below 1 would make the recursion below endless.
+  if (num_elements <= min_elements || num_elements <= 1) {
     // calculate last elements in loop to avoid overhead
     for (auto current = first; current != last; current++) {
@@ -25,15 +30,25 @@ void for_each(RandomIt first, RandomIt last, const Function &function) {
     }
   } else {
     // Cut in half recursively
-    long middle_index = num_elements / 2;
+    const long middle_index = num_elements / 2;
 
     auto second_half_body =
-        [first, middle_index, last, &function] { internal::for_each(first + middle_index, last, function); };
+        [first, middle_index, last, &function, min_elements] {
+          internal::for_each(first + middle_index,
+                             last,
+                             function,
+                             min_elements);
+        };
     using second_half_t = lambda_task_by_reference<decltype(second_half_body)>;
     scheduler::spawn_child<second_half_t>(std::move(second_half_body));
 
     auto first_half_body =
-        [first, middle_index, last, &function] { internal::for_each(first, first + middle_index, function); };
+        [first, middle_index, last, &function, min_elements] {
+          internal::for_each(first,
+                             first + middle_index,
+                             function,
+                             min_elements);
+        };
     using first_half_t = lambda_task_by_reference<decltype(first_half_body)>;
     scheduler::spawn_child_and_wait<first_half_t>(std::move(first_half_body));
   }
@@ -41,15 +56,62 @@ void for_each(RandomIt first, RandomIt last, const Function &function) {
 
 }
 
-template<typename Function>
-void for_each_range(unsigned long first, unsigned long last, const Function &function) {
-  auto range = pls::internal::helpers::range(first, last);
-  internal::for_each(range.begin(), range.end(), function);
+// Execution strategy whose cut-off depends on the input size: the range is divided by
+// (scheduler threads * tasks_per_thread), so larger inputs yield larger sequential chunks.
+class dynamic_strategy {
+ public:
+  explicit dynamic_strategy(const unsigned int tasks_per_thread = 4) : tasks_per_thread_{tasks_per_thread} {}
+
+  long calculate_min_elements(long num_elements) const {
+    const long num_threads = pls::internal::scheduling::thread_state::get()->scheduler_->num_threads();
+    const long num_tasks = num_threads * static_cast<long>(tasks_per_thread_);
+    // Clamp to 1: a zero divisor would be undefined behaviour, and a cut-off of 0 is
+    // meaningless for inputs that hold fewer elements than the targeted number of tasks.
+    return (num_tasks > 0 && num_elements > num_tasks) ? num_elements / num_tasks : 1;
+  }
+ private:
+  const unsigned int tasks_per_thread_;
+};
+
+// Execution strategy with a caller-chosen cut-off that ignores the input size.
+class fixed_strategy {
+ public:
+  explicit fixed_strategy(const long min_elements_per_task) : min_elements_per_task_{min_elements_per_task} {}
+
+  long calculate_min_elements(long /*num_elements*/) const {
+    return min_elements_per_task_;
+  }
+ private:
+  const long min_elements_per_task_;
+};
+
+// Asks execution_strategy for the cut-off matching the size of [first, last) and hands the
+// range to the recursive implementation.
+template<typename RandomIt, typename Function, typename ExecutionStrategy>
+void for_each(RandomIt first, RandomIt last, const Function &function, ExecutionStrategy execution_strategy) {
+  const long num_elements = std::distance(first, last);
+  internal::for_each(first, last, function, execution_strategy.calculate_min_elements(num_elements));
 }
 
 template<typename RandomIt, typename Function>
 void for_each(RandomIt first, RandomIt last, const Function &function) {
-  internal::for_each(first, last, function);
+  for_each(first, last, function, dynamic_strategy{4});
+}
+
+// Same as for_each, applied to the iterators of helpers::range(first, last).
+template<typename Function, typename ExecutionStrategy>
+void for_each_range(unsigned long first,
+                    unsigned long last,
+                    const Function &function,
+                    ExecutionStrategy execution_strategy) {
+  auto range = pls::internal::helpers::range(first, last);
+  for_each(range.begin(), range.end(), function, execution_strategy);
+}
+
+template<typename Function>
+void for_each_range(unsigned long first, unsigned long last, const Function &function) {
+  auto range = pls::internal::helpers::range(first, last);
+  for_each(range.begin(), range.end(), function);
 }
 
 }
diff --git a/lib/pls/include/pls/algorithms/scan_impl.h b/lib/pls/include/pls/algorithms/scan_impl.h
index 9523de3..183a6e8 100644
--- a/lib/pls/include/pls/algorithms/scan_impl.h
+++ b/lib/pls/include/pls/algorithms/scan_impl.h
@@ -69,7 +69,7 @@ class scan_task : public pls::internal::scheduling::task {
       internal::serial_scan(chunk_start, chunk_end, chunk_output, op_, neutral_elem_);
       auto last_chunk_value = *(chunk_output + chunk_size - 1);
       chunk_sums_[i] = last_chunk_value;
-    });
+    }, fixed_strategy{1});
 
     // Calculate prefix sums of each chunks sum
     // (effectively the prefix sum at the end of each chunk, then used to correct the following chunk).
@@ -85,7 +85,7 @@ class scan_task : public pls::internal::scheduling::task {
       for (; chunk_start != chunk_end; chunk_start++) {
         *chunk_start = op_(*chunk_start, chunk_sums_[i - 1]);
       }
-    });
+    }, fixed_strategy{1});
   }
 };
 