Commit 722ddf41 by FritzFlorian

Remove unneeded iteration in scan algorithm.

parent ef19ea1b
Pipeline #1322 passed with stages
in 4 minutes 45 seconds
......@@ -17,14 +17,13 @@ using namespace pls::internal::scheduling;
/**
 * Sequential inclusive scan over the range [input_start, input_end).
 *
 * A running value is initialised with `neutral_element`. For every input
 * element, the running value is replaced by `op(running_value, element)` and
 * the new running value is written to the matching output position.
 *
 * Each input element is read before the corresponding output position is
 * written, so `output` may refer to the same position as `input_start`
 * (in-place scan).
 *
 * @param input_start     Iterator to the first input element.
 * @param input_end       Iterator one past the last input element.
 * @param output          Iterator to the first output position; one result is
 *                        written per input element.
 * @param op              Binary operation, invoked as op(running_value, element).
 * @param neutral_element Initial running value.
 */
template<typename InIter, typename OutIter, typename BinaryOp, typename Type>
void serial_scan(InIter input_start, const InIter input_end, OutIter output, BinaryOp op, Type neutral_element) {
  auto current_output = output;
  auto last_value = neutral_element;
  // Single pass: the loop header owns the input iterator, so it is advanced
  // exactly once per element.
  for (auto current_input = input_start; current_input != input_end; ++current_input) {
    last_value = op(last_value, *current_input);
    *current_output = last_value;
    ++current_output;
  }
}
......@@ -52,15 +51,15 @@ class scan_task : public pls::internal::scheduling::task {
size_ = std::distance(in_start, in_end);
auto num_threads = thread_state::get()->scheduler_->num_threads();
chunks_ = num_threads * chunks_per_thread;
chunks_ = num_threads * chunks_per_thread + 1;
items_per_chunk_ = size_ / chunks_ + 1;
chunk_sums_ = reinterpret_cast<Type *>(allocate_memory(sizeof(Type) * chunks_));
chunk_sums_ = reinterpret_cast<Type *>(allocate_memory(sizeof(Type) * chunks_ - 1));
};
void execute_internal() override {
// First Pass = calculate each chunk's individual prefix sum
for_each_range(0, chunks_, [&](int i) {
for_each_range(0, chunks_ - 1, [&](int i) {
auto chunk_start = in_start_ + items_per_chunk_ * i;
auto chunk_end = std::min(in_end_, chunk_start + items_per_chunk_);
auto chunk_size = std::distance(chunk_start, chunk_end);
......@@ -73,17 +72,25 @@ class scan_task : public pls::internal::scheduling::task {
// Calculate prefix sums of each chunk's sum
// (effectively the prefix sum at the end of each chunk, then used to correct the following chunk).
internal::serial_scan(chunk_sums_, chunk_sums_ + chunks_, chunk_sums_, op_, neutral_elem_);
internal::serial_scan(chunk_sums_, chunk_sums_ + chunks_ - 1, chunk_sums_, op_, neutral_elem_);
// Second Pass = Use results from first pass to correct each chunk's sum
auto output_start = out_;
auto output_end = out_ + size_;
for_each_range(1, chunks_, [&](int i) {
auto chunk_start = output_start + items_per_chunk_ * i;
auto chunk_end = std::min(output_end, chunk_start + items_per_chunk_);
for (; chunk_start != chunk_end; chunk_start++) {
*chunk_start = op_(*chunk_start, chunk_sums_[i - 1]);
if (i == chunks_ - 1) {
auto chunk_start = in_start_ + items_per_chunk_ * i;
auto chunk_end = std::min(in_end_, chunk_start + items_per_chunk_);
auto chunk_output = output_start + items_per_chunk_ * i;
*chunk_start += chunk_sums_[i - 1];
internal::serial_scan(chunk_start, chunk_end, chunk_output, op_, neutral_elem_);
} else {
auto chunk_start = output_start + items_per_chunk_ * i;
auto chunk_end = std::min(output_end, chunk_start + items_per_chunk_);
for (; chunk_start != chunk_end; chunk_start++) {
*chunk_start = op_(*chunk_start, chunk_sums_[i - 1]);
}
}
}, fixed_strategy{1});
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment