Commit 779978e2 by FritzFlorian

Fix bug in scan (elements were skipped).

parent 73550b12
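
Why elements were skipped: the old `scan` computed its chunk size as `items_per_chunk = std::max(1l, size / chunks)`. Integer division rounds down, so whenever `size` is not a multiple of `chunks`, up to `chunks - 1` trailing elements fall into no chunk and are never scanned. The patch rounds up instead (`items_per_chunk_ = size_ / chunks_ + 1`) and relies on the `std::min(in_end_, ...)` clamp to shorten the last chunk. A small standalone check of the two sizing rules (the helper names here are illustrative, not from the library):

```cpp
#include <algorithm>
#include <iostream>

// Count how many of `size` elements fall into `chunks` chunks of
// `items_per_chunk` elements each (chunk bounds clamped to `size`).
long covered(long size, long chunks, long items_per_chunk) {
  long total = 0;
  for (long i = 0; i < chunks; i++) {
    long begin = std::min(size, items_per_chunk * i);
    long end = std::min(size, begin + items_per_chunk);
    total += end - begin;
  }
  return total;
}

int main() {
  long size = 10, chunks = 4;
  // Old sizing rounds down: 4 chunks of 2 cover only 8 of the 10 elements.
  std::cout << covered(size, chunks, std::max(1l, size / chunks)) << "\n"; // prints 8
  // Fixed sizing rounds up: 4 chunks of 3 (last one clamped) cover all 10.
  std::cout << covered(size, chunks, size / chunks + 1) << "\n";          // prints 10
  return 0;
}
```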
@@ -15,7 +15,7 @@ namespace internal {
 template<typename RandomIt, typename Function>
 void for_each(RandomIt first, RandomIt last, const Function &function) {
   using namespace ::pls::internal::scheduling;
-  constexpr long min_elements = 4;
+  constexpr long min_elements = 1; // TODO: tune this value/allow for execution strategies
   long num_elements = std::distance(first, last);
   if (num_elements <= min_elements) {
...
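
The hunk above lowers the static cutoff at which `for_each` stops splitting its range and falls back to a serial loop, presumably so that the handful of per-chunk iterations issued by `scan`'s `for_each_range` calls can still run in parallel on few threads. The surrounding implementation is not shown in this hunk, so the following is only a rough sketch of the usual fork-join pattern behind such a cutoff:

```cpp
#include <iterator>

// Rough sketch (hypothetical, not the pls sources) of the usual cutoff
// pattern behind such a for_each: plain recursion stands in for the
// scheduler's task spawning.
template<typename RandomIt, typename Function>
void for_each_sketch(RandomIt first, RandomIt last, const Function &function) {
  constexpr long min_elements = 1; // below this, splitting costs more than it saves
  long num_elements = std::distance(first, last);
  if (num_elements <= min_elements) {
    // Cutoff reached: process the remaining range serially.
    for (; first != last; ++first) function(*first);
  } else {
    // Split in half; the real implementation would run both halves as tasks.
    RandomIt middle = first + num_elements / 2;
    for_each_sketch(first, middle, function); // would be spawned as a child task
    for_each_sketch(middle, last, function);  // executed by the current task
  }
}
```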
@@ -7,10 +7,14 @@
 #include "pls/pls.h"
 #include "pls/internal/scheduling/thread_state.h"
+#include "pls/internal/scheduling/task.h"
 
 namespace pls {
 namespace algorithm {
 namespace internal {
+using namespace pls::internal::scheduling;
+
 template<typename InIter, typename OutIter, typename BinaryOp, typename Type>
 void serial_scan(InIter input_start, const InIter input_end, OutIter output, BinaryOp op, Type neutral_element) {
   auto current_input = input_start;
@@ -25,52 +29,74 @@ void serial_scan(InIter input_start, const InIter input_end, OutIter output, Bin
   }
 }
-}
 
 template<typename InIter, typename OutIter, typename BinaryOp, typename Type>
-void scan(InIter in_start, const InIter in_end, OutIter out, BinaryOp op, Type neutral_elem) {
-  constexpr auto chunks_per_thread = 4;
-  using namespace pls::internal::scheduling;
-
-  // TODO: This must be dynamic to make sense, as it has a far bigger influence than any other cutoff.
-  //       The current strategy is static partitioning, and suboptimal in imbalanced workloads.
-  auto size = std::distance(in_start, in_end);
-  auto num_threads = thread_state::get()->scheduler_->num_threads();
-  auto chunks = num_threads * chunks_per_thread;
-  auto items_per_chunk = std::max(1l, size / chunks);
-
-  scheduler::allocate_on_stack(sizeof(Type) * (chunks), [&](void *memory) {
-    Type *chunk_sums = reinterpret_cast<Type *>(memory);
+class scan_task : public pls::internal::scheduling::task {
+  const InIter in_start_;
+  const InIter in_end_;
+  const OutIter out_;
+  const BinaryOp op_;
+  const Type neutral_elem_;
+
+  long size_, chunks_;
+  long items_per_chunk_;
+  Type *chunk_sums_;
+
+ public:
+  scan_task(const InIter in_start, const InIter in_end, const OutIter out, const BinaryOp op, const Type neutral_elem) :
+      in_start_{in_start},
+      in_end_{in_end},
+      out_{out},
+      op_{op},
+      neutral_elem_{neutral_elem} {
+    constexpr auto chunks_per_thread = 1;
+
+    size_ = std::distance(in_start, in_end);
+    auto num_threads = thread_state::get()->scheduler_->num_threads();
+    chunks_ = num_threads * chunks_per_thread;
+    items_per_chunk_ = size_ / chunks_ + 1;
+    chunk_sums_ = reinterpret_cast<Type *>(allocate_memory(sizeof(Type) * chunks_));
+  }
 
+  void execute_internal() override {
     // First pass = calculate each chunk's individual prefix sum
-    for_each_range(0, chunks, [&](int i) {
-      auto chunk_start = in_start + items_per_chunk * i;
-      auto chunk_end = std::min(in_end, chunk_start + items_per_chunk);
-      auto chunk_output = out + items_per_chunk * i;
+    for_each_range(0, chunks_, [&](int i) {
+      auto chunk_start = in_start_ + items_per_chunk_ * i;
+      auto chunk_end = std::min(in_end_, chunk_start + items_per_chunk_);
+      auto chunk_size = std::distance(chunk_start, chunk_end);
+      auto chunk_output = out_ + items_per_chunk_ * i;
 
-      internal::serial_scan(chunk_start, chunk_end, chunk_output, op, neutral_elem);
-      chunk_sums[i] = *(out + std::distance(chunk_start, chunk_end) - 1);
+      internal::serial_scan(chunk_start, chunk_end, chunk_output, op_, neutral_elem_);
+      auto last_chunk_value = *(chunk_output + chunk_size - 1);
+      chunk_sums_[i] = last_chunk_value;
     });
 
     // Calculate prefix sums of each chunk's sum
     // (effectively the prefix sum at the end of each chunk, then used to correct the following chunk).
-    internal::serial_scan(chunk_sums, chunk_sums + chunks, chunk_sums, std::plus<int>(), 0);
+    internal::serial_scan(chunk_sums_, chunk_sums_ + chunks_, chunk_sums_, op_, neutral_elem_);
 
     // Second pass = use results from first pass to correct each chunk's sum
-    auto output_start = out;
-    auto output_end = out + size;
-    for_each_range(1, chunks, [&](int i) {
-      auto chunk_start = output_start + items_per_chunk * i;
-      auto chunk_end = std::min(output_end, chunk_start + items_per_chunk);
+    auto output_start = out_;
+    auto output_end = out_ + size_;
+    for_each_range(1, chunks_, [&](int i) {
+      auto chunk_start = output_start + items_per_chunk_ * i;
+      auto chunk_end = std::min(output_end, chunk_start + items_per_chunk_);
       for (; chunk_start != chunk_end; chunk_start++) {
-        *chunk_start = op(*chunk_start, chunk_sums[i - 1]);
+        *chunk_start = op_(*chunk_start, chunk_sums_[i - 1]);
       }
     });
-  });
-
-  // End this work section by cleaning up stack and tasks
-  scheduler::wait_for_all();
+  }
+};
+}
+
+template<typename InIter, typename OutIter, typename BinaryOp, typename Type>
+void scan(InIter in_start, const InIter in_end, OutIter out, BinaryOp op, Type neutral_elem) {
+  using namespace pls::internal::scheduling;
+
+  using scan_task_type = internal::scan_task<InIter, OutIter, BinaryOp, Type>;
+  scheduler::spawn_child_and_wait<scan_task_type>(in_start, in_end, out, op, neutral_elem);
 }
 }
...
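
For reference, the patched algorithm reads end to end as follows in a scheduler-free sketch: plain loops stand in for `for_each_range`, a `std::vector` for the task's `allocate_memory` call, and the chunk sizing matches the fixed `size / chunks + 1` rule. The result is checked against `std::partial_sum`:

```cpp
#include <algorithm>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Scheduler-free rendition of the patched two-pass scan: plain loops stand
// in for for_each_range, a std::vector for the task's allocate_memory call.
template<typename InIter, typename OutIter, typename BinaryOp, typename Type>
void chunked_scan(InIter in_start, InIter in_end, OutIter out, BinaryOp op, Type neutral_elem) {
  const long size = std::distance(in_start, in_end);
  const long chunks = 4;                          // stands in for num_threads * chunks_per_thread
  const long items_per_chunk = size / chunks + 1; // rounded up, so no element is skipped
  std::vector<Type> chunk_sums(chunks, neutral_elem);

  // First pass: scan each chunk independently (in parallel in the real code)
  // and record the last value of every chunk.
  for (long i = 0; i < chunks; i++) {
    const long begin = std::min(size, items_per_chunk * i);
    const long end = std::min(size, begin + items_per_chunk);
    Type running = neutral_elem;
    for (long j = begin; j < end; j++) {
      running = op(running, *(in_start + j));
      *(out + j) = running;
    }
    chunk_sums[i] = running; // stays neutral_elem for an empty chunk
  }

  // Scan the chunk totals: chunk_sums[i] becomes the value the full scan
  // reaches at the end of chunk i.
  for (long i = 1; i < chunks; i++) {
    chunk_sums[i] = op(chunk_sums[i - 1], chunk_sums[i]);
  }

  // Second pass: shift every chunk after the first by the combined total of
  // all preceding chunks (chunk 0 is already correct, hence the start at 1).
  for (long i = 1; i < chunks; i++) {
    const long begin = std::min(size, items_per_chunk * i);
    const long end = std::min(size, begin + items_per_chunk);
    for (long j = begin; j < end; j++) {
      *(out + j) = op(*(out + j), chunk_sums[i - 1]);
    }
  }
}

int main() {
  std::vector<int> in(10, 1), out(10), expected(10);
  chunked_scan(in.begin(), in.end(), out.begin(), std::plus<int>(), 0);
  std::partial_sum(in.begin(), in.end(), expected.begin());
  std::cout << (out == expected ? "matches std::partial_sum" : "mismatch") << "\n";
  return 0;
}
```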