// scan_impl.h

#ifndef PLS_PARALLEL_SCAN_IMPL_H_
#define PLS_PARALLEL_SCAN_IMPL_H_

#include <memory>
#include <functional>

#include "pls/pls.h"
#include "pls/internal/scheduling/thread_state.h"
#include "pls/internal/scheduling/task.h"
namespace pls {
namespace algorithm {
namespace internal {
using namespace pls::internal::scheduling;

/**
 * Sequential inclusive scan (running fold) over [input_start, input_end).
 *
 * For every input element the running value is updated as
 * op(running, element) - starting from neutral_element - and the updated
 * value is written to the matching output position.
 *
 * Each input element is read before the corresponding output position is
 * written, so the output may alias the input when both advance in lock-step.
 *
 * @param input_start     first element of the input range
 * @param input_end       one past the last element of the input range
 * @param output          start of the destination; must have room for one
 *                        value per input element
 * @param op              binary operation, called as op(accumulated, element)
 * @param neutral_element initial accumulator value
 */
template<typename InIter, typename OutIter, typename BinaryOp, typename Type>
void serial_scan(InIter input_start, const InIter input_end, OutIter output, BinaryOp op, Type neutral_element) {
  Type running = neutral_element;
  for (; input_start != input_end; ++input_start) {
    running = op(running, *input_start);
    *output = running;
    ++output;
  }
}
template<typename InIter, typename OutIter, typename BinaryOp, typename Type>
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
class scan_task : public pls::internal::scheduling::task {
  const InIter in_start_;
  const InIter in_end_;
  const OutIter out_;
  const BinaryOp op_;
  const Type neutral_elem_;

  long size_, chunks_;
  long items_per_chunk_;
  Type *chunk_sums_;

 public:
  scan_task(const InIter in_start, const InIter in_end, const OutIter out, const BinaryOp op, const Type neutral_elem) :
      in_start_{in_start},
      in_end_{in_end},
      out_{out},
      op_{op},
      neutral_elem_{neutral_elem} {
    constexpr auto chunks_per_thread = 1;

    size_ = std::distance(in_start, in_end);
    auto num_threads = thread_state::get()->scheduler_->num_threads();
    chunks_ = num_threads * chunks_per_thread;
    items_per_chunk_ = size_ / chunks_ + 1;

    chunk_sums_ = reinterpret_cast<Type *>(allocate_memory(sizeof(Type) * chunks_));
  };

  void execute_internal() override {
62
    // First Pass = calculate each chunks individual prefix sum
63 64 65 66 67 68 69 70 71
    for_each_range(0, chunks_, [&](int i) {
      auto chunk_start = in_start_ + items_per_chunk_ * i;
      auto chunk_end = std::min(in_end_, chunk_start + items_per_chunk_);
      auto chunk_size = std::distance(chunk_start, chunk_end);
      auto chunk_output = out_ + items_per_chunk_ * i;

      internal::serial_scan(chunk_start, chunk_end, chunk_output, op_, neutral_elem_);
      auto last_chunk_value = *(chunk_output + chunk_size - 1);
      chunk_sums_[i] = last_chunk_value;
72
    }, fixed_strategy{1});
73 74 75

    // Calculate prefix sums of each chunks sum
    // (effectively the prefix sum at the end of each chunk, then used to correct the following chunk).
76
    internal::serial_scan(chunk_sums_, chunk_sums_ + chunks_, chunk_sums_, op_, neutral_elem_);
77 78

    // Second Pass = Use results from first pass to correct each chunks sum
79 80 81 82 83
    auto output_start = out_;
    auto output_end = out_ + size_;
    for_each_range(1, chunks_, [&](int i) {
      auto chunk_start = output_start + items_per_chunk_ * i;
      auto chunk_end = std::min(output_end, chunk_start + items_per_chunk_);
84 85

      for (; chunk_start != chunk_end; chunk_start++) {
86
        *chunk_start = op_(*chunk_start, chunk_sums_[i - 1]);
87
      }
88
    }, fixed_strategy{1});
89 90 91 92 93 94 95 96
  }
};

}

template<typename InIter, typename OutIter, typename BinaryOp, typename Type>
void scan(InIter in_start, const InIter in_end, OutIter out, BinaryOp op, Type neutral_elem) {
  using namespace pls::internal::scheduling;
97

98 99
  using scan_task_type = internal::scan_task<InIter, OutIter, BinaryOp, Type>;
  scheduler::spawn_child_and_wait<scan_task_type>(in_start, in_end, out, op, neutral_elem);
100 101 102 103 104 105
}

}
}

#endif //PLS_PARALLEL_SCAN_IMPL_H_