From 722ddf41acb1c63b8ad2d7977410ce3f4273a6c4 Mon Sep 17 00:00:00 2001 From: FritzFlorian Date: Mon, 30 Sep 2019 14:52:14 +0200 Subject: [PATCH] Remove un-needed iteration in scan algorithm. --- lib/pls/include/pls/algorithms/scan_impl.h | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/lib/pls/include/pls/algorithms/scan_impl.h b/lib/pls/include/pls/algorithms/scan_impl.h index 183a6e8..5abe558 100644 --- a/lib/pls/include/pls/algorithms/scan_impl.h +++ b/lib/pls/include/pls/algorithms/scan_impl.h @@ -17,14 +17,13 @@ using namespace pls::internal::scheduling; template void serial_scan(InIter input_start, const InIter input_end, OutIter output, BinaryOp op, Type neutral_element) { - auto current_input = input_start; + auto current_output = output; auto last_value = neutral_element; - while (current_input != input_end) { + + for (auto current_input = input_start; current_input != input_end; current_input++) { last_value = op(last_value, *current_input); *current_output = last_value; - - current_input++; current_output++; } } @@ -52,15 +51,15 @@ class scan_task : public pls::internal::scheduling::task { size_ = std::distance(in_start, in_end); auto num_threads = thread_state::get()->scheduler_->num_threads(); - chunks_ = num_threads * chunks_per_thread; + chunks_ = num_threads * chunks_per_thread + 1; items_per_chunk_ = size_ / chunks_ + 1; - chunk_sums_ = reinterpret_cast(allocate_memory(sizeof(Type) * chunks_)); + chunk_sums_ = reinterpret_cast(allocate_memory(sizeof(Type) * chunks_ - 1)); }; void execute_internal() override { // First Pass = calculate each chunks individual prefix sum - for_each_range(0, chunks_, [&](int i) { + for_each_range(0, chunks_ - 1, [&](int i) { auto chunk_start = in_start_ + items_per_chunk_ * i; auto chunk_end = std::min(in_end_, chunk_start + items_per_chunk_); auto chunk_size = std::distance(chunk_start, chunk_end); @@ -73,17 +72,25 @@ class scan_task : public pls::internal::scheduling::task { // Calculate prefix sums of each chunks sum // (effectively the prefix sum at the end of each chunk, then used to correct the following chunk). - internal::serial_scan(chunk_sums_, chunk_sums_ + chunks_, chunk_sums_, op_, neutral_elem_); + internal::serial_scan(chunk_sums_, chunk_sums_ + chunks_ - 1, chunk_sums_, op_, neutral_elem_); // Second Pass = Use results from first pass to correct each chunks sum auto output_start = out_; auto output_end = out_ + size_; for_each_range(1, chunks_, [&](int i) { - auto chunk_start = output_start + items_per_chunk_ * i; - auto chunk_end = std::min(output_end, chunk_start + items_per_chunk_); - - for (; chunk_start != chunk_end; chunk_start++) { - *chunk_start = op_(*chunk_start, chunk_sums_[i - 1]); + if (i == chunks_ - 1) { + auto chunk_start = in_start_ + items_per_chunk_ * i; + auto chunk_end = std::min(in_end_, chunk_start + items_per_chunk_); + auto chunk_output = output_start + items_per_chunk_ * i; + + *chunk_start += chunk_sums_[i - 1]; + internal::serial_scan(chunk_start, chunk_end, chunk_output, op_, neutral_elem_); + } else { + auto chunk_start = output_start + items_per_chunk_ * i; + auto chunk_end = std::min(output_end, chunk_start + items_per_chunk_); + for (; chunk_start != chunk_end; chunk_start++) { + *chunk_start = op_(*chunk_start, chunk_sums_[i - 1]); + } } }, fixed_strategy{1}); } -- libgit2 0.26.0