diff --git a/lib/pls/include/pls/algorithms/for_each_impl.h b/lib/pls/include/pls/algorithms/for_each_impl.h
index 3508c7c..68c827e 100644
--- a/lib/pls/include/pls/algorithms/for_each_impl.h
+++ b/lib/pls/include/pls/algorithms/for_each_impl.h
@@ -15,7 +15,7 @@ namespace internal {
 template<typename RandomIt, typename Function>
 void for_each(RandomIt first, RandomIt last, const Function &function) {
   using namespace ::pls::internal::scheduling;
-  constexpr long min_elements = 4;
+  constexpr long min_elements = 1; // TODO: tune this value/allow for execution strategies
 
   long num_elements = std::distance(first, last);
   if (num_elements <= min_elements) {
diff --git a/lib/pls/include/pls/algorithms/scan_impl.h b/lib/pls/include/pls/algorithms/scan_impl.h
index 007198a..9523de3 100644
--- a/lib/pls/include/pls/algorithms/scan_impl.h
+++ b/lib/pls/include/pls/algorithms/scan_impl.h
@@ -7,10 +7,14 @@
 
 #include "pls/pls.h"
 #include "pls/internal/scheduling/thread_state.h"
+#include "pls/internal/scheduling/task.h"
 
 namespace pls {
 namespace algorithm {
 namespace internal {
+
+using namespace pls::internal::scheduling;
+
 template<typename InIter, typename OutIter, typename BinaryOp, typename Type>
 void serial_scan(InIter input_start, const InIter input_end, OutIter output, BinaryOp op, Type neutral_element) {
   auto current_input = input_start;
@@ -25,52 +29,97 @@ void serial_scan(InIter input_start, const InIter input_end, OutIter output, Bin
   }
 }
 
-}
-
+/**
+ * Task computing a prefix scan of [in_start, in_end) into out using op.
+ *
+ * The range is split into chunks_ pieces (one per scheduler thread). Every piece is
+ * scanned on its own, the last values of all pieces are scanned serially, and a
+ * second pass combines each later piece with the result of the pieces before it.
+ */
 template<typename InIter, typename OutIter, typename BinaryOp, typename Type>
-void scan(InIter in_start, const InIter in_end, OutIter out, BinaryOp op, Type neutral_elem) {
-  constexpr auto chunks_per_thread = 4;
-  using namespace pls::internal::scheduling;
-
-  // TODO: This must be dynamic to make sense, as it has a far bigger influence than any other cutoff.
-  // The current strategy is static partitioning, and suboptimal in inballanced workloads.
-  auto size = std::distance(in_start, in_end);
-  auto num_threads = thread_state::get()->scheduler_->num_threads();
-  auto chunks = num_threads * chunks_per_thread;
-  auto items_per_chunk = std::max(1l, size / chunks);
-
-  scheduler::allocate_on_stack(sizeof(Type) * (chunks), [&](void *memory) {
-    Type *chunk_sums = reinterpret_cast<Type *>(memory);
-
+class scan_task : public pls::internal::scheduling::task {
+  const InIter in_start_;
+  const InIter in_end_;
+  const OutIter out_;
+  const BinaryOp op_;
+  const Type neutral_elem_;
+
+  long size_, chunks_;
+  long items_per_chunk_;
+  Type *chunk_sums_;
+
+ public:
+  scan_task(const InIter in_start, const InIter in_end, const OutIter out, const BinaryOp op, const Type neutral_elem) :
+      in_start_{in_start},
+      in_end_{in_end},
+      out_{out},
+      op_{op},
+      neutral_elem_{neutral_elem} {
+    constexpr auto chunks_per_thread = 1;
+
+    size_ = std::distance(in_start, in_end);
+    auto num_threads = thread_state::get()->scheduler_->num_threads();
+    chunks_ = num_threads * chunks_per_thread;
+    // Round up so that chunks_ * items_per_chunk_ always covers the whole input.
+    items_per_chunk_ = size_ / chunks_ + 1;
+
+    chunk_sums_ = reinterpret_cast<Type *>(allocate_memory(sizeof(Type) * chunks_));
+  }
+
+  void execute_internal() override {
     // First Pass = calculate each chunks individual prefix sum
-    for_each_range(0, chunks, [&](int i) {
-      auto chunk_start = in_start + items_per_chunk * i;
-      auto chunk_end = std::min(in_end, chunk_start + items_per_chunk);
-      auto chunk_output = out + items_per_chunk * i;
-
-      internal::serial_scan(chunk_start, chunk_end, chunk_output, op, neutral_elem);
-      chunk_sums[i] = *(out + std::distance(chunk_start, chunk_end) - 1);
+    for_each_range(0, chunks_, [&](int i) {
+      // Clamp both bounds to the input size: trailing chunks can start at or past the end.
+      const long first_index = std::min(size_, items_per_chunk_ * i);
+      const long last_index = std::min(size_, first_index + items_per_chunk_);
+      const long chunk_size = last_index - first_index;
+      if (chunk_size == 0) {
+        // Nothing to scan; an empty chunk contributes the neutral element.
+        chunk_sums_[i] = neutral_elem_;
+        return;
+      }
+
+      auto chunk_start = in_start_ + first_index;
+      auto chunk_end = in_start_ + last_index;
+      auto chunk_output = out_ + first_index;
+
+      internal::serial_scan(chunk_start, chunk_end, chunk_output, op_, neutral_elem_);
+      // The last value of the chunk's output is recorded as the chunk's result.
+      chunk_sums_[i] = *(chunk_output + chunk_size - 1);
     });
 
     // Calculate prefix sums of each chunks sum
     // (effectively the prefix sum at the end of each chunk, then used to correct the following chunk).
-    internal::serial_scan(chunk_sums, chunk_sums + chunks, chunk_sums, std::plus<Type>(), 0);
+    internal::serial_scan(chunk_sums_, chunk_sums_ + chunks_, chunk_sums_, op_, neutral_elem_);
 
     // Second Pass = Use results from first pass to correct each chunks sum
-    auto output_start = out;
-    auto output_end = out + size;
-    for_each_range(1, chunks, [&](int i) {
-      auto chunk_start = output_start + items_per_chunk * i;
-      auto chunk_end = std::min(output_end, chunk_start + items_per_chunk);
+    for_each_range(1, chunks_, [&](int i) {
+      const long first_index = std::min(size_, items_per_chunk_ * i);
+      const long last_index = std::min(size_, first_index + items_per_chunk_);
+      auto chunk_start = out_ + first_index;
+      auto chunk_end = out_ + last_index;
 
       for (; chunk_start != chunk_end; chunk_start++) {
-        *chunk_start = op(*chunk_start, chunk_sums[i - 1]);
+        // NOTE(review): the preceding chunks' result is passed as the right operand;
+        // confirm this matches serial_scan's operand order if op may be non-commutative.
+        *chunk_start = op_(*chunk_start, chunk_sums_[i - 1]);
       }
     });
-  });
+  }
+};
+
+}
+
+/**
+ * Computes the prefix scan of [in_start, in_end) into out by spawning a scan_task
+ * as a child task and waiting for it.
+ */
+template<typename InIter, typename OutIter, typename BinaryOp, typename Type>
+void scan(InIter in_start, const InIter in_end, OutIter out, BinaryOp op, Type neutral_elem) {
+  using namespace pls::internal::scheduling;
 
-  // End this work section by cleaning up stack and tasks
-  scheduler::wait_for_all();
+  using scan_task_type = internal::scan_task<InIter, OutIter, BinaryOp, Type>;
+  scheduler::spawn_child_and_wait<scan_task_type>(in_start, in_end, out, op, neutral_elem);
 }
 
 }