diff --git a/algorithms_cpp/include/embb/algorithms/internal/merge_sort-inl.h b/algorithms_cpp/include/embb/algorithms/internal/merge_sort-inl.h index db0e7fe..e86e1c9 100644 --- a/algorithms_cpp/include/embb/algorithms/internal/merge_sort-inl.h +++ b/algorithms_cpp/include/embb/algorithms/internal/merge_sort-inl.h @@ -48,57 +48,69 @@ class MergeSortFunctor { public: typedef typename std::iterator_traits::value_type value_type; - MergeSortFunctor(RAI first, RAI last, RAITemp temporary_first, - ComparisonFunction comparison, const embb::mtapi::ExecutionPolicy& policy, - size_t block_size, const RAI& global_first, int depth) - : first_(first), last_(last), temp_first_(temporary_first), - comparison_(comparison), policy_(policy), block_size_(block_size), + MergeSortFunctor(size_t chunk_first, size_t chunk_last, + RAITemp temporary_first, ComparisonFunction comparison, + const embb::mtapi::ExecutionPolicy& policy, + const BlockSizePartitioner& partitioner, + const RAI& global_first, int depth) + : chunk_first_(chunk_first), chunk_last_(chunk_last), + temp_first_(temporary_first), + comparison_(comparison), policy_(policy), partitioner_(partitioner), global_first_(global_first), depth_(depth) { } - void Action(mtapi::TaskContext& context) { - typedef typename std::iterator_traits::difference_type difference_type; - size_t distance = static_cast(std::distance(first_, last_)); - if (distance <= 1) { - if(!CloneBackToInput() && distance != 0) { - RAITemp temp_first = temp_first_; - temp_first += std::distance(global_first_, first_); - *temp_first = *first_; - } + void Action(mtapi::TaskContext&) { + typedef typename std::iterator_traits::difference_type + difference_type; + size_t chunk_split_index = (chunk_first_ + chunk_last_) / 2; + if (chunk_first_ == chunk_last_) { + // Leaf case: recurse into a single chunk's elements: + ChunkDescriptor chunk = partitioner_[chunk_first_]; + MergeSortChunkFunctor functor(chunk.GetFirst(), + chunk.GetLast(), + temp_first_, + global_first_, + depth_); + functor.Action(); return; - } - internal::ChunkPartitioner partitioner(first_, last_, 2); - MergeSortFunctor functorL( - partitioner[0].GetFirst(), partitioner[0].GetLast(), temp_first_, - comparison_, policy_, block_size_, global_first_, depth_ + 1); - MergeSortFunctor functorR( - partitioner[1].GetFirst(), partitioner[1].GetLast(), temp_first_, - comparison_, policy_, block_size_, global_first_, depth_ + 1); - - if (distance <= block_size_) { - functorL.Action(context); - functorR.Action(context); - } else { - mtapi::Node& node = mtapi::Node::GetInstance(); - mtapi::Task taskL = node.Spawn(mtapi::Action(base::MakeFunction(functorL, - &MergeSortFunctor::Action), - policy_)); - mtapi::Task taskR = node.Spawn(mtapi::Action(base::MakeFunction(functorR, - &MergeSortFunctor::Action), - policy_)); - taskL.Wait(MTAPI_INFINITE); - taskR.Wait(MTAPI_INFINITE); - } + } + // Recurse further: + // Split chunks into left / right branches: + self_t functor_l(chunk_first_, + chunk_split_index, + temp_first_, + comparison_, policy_, partitioner_, + global_first_, depth_ + 1); + self_t functor_r(chunk_split_index + 1, + chunk_last_, + temp_first_, + comparison_, policy_, partitioner_, + global_first_, depth_ + 1); + mtapi::Node& node = mtapi::Node::GetInstance(); + mtapi::Task task_l = node.Spawn( + mtapi::Action( + base::MakeFunction(functor_l, &self_t::Action), + policy_)); + mtapi::Task task_r = node.Spawn( + mtapi::Action( + base::MakeFunction(functor_r, &self_t::Action), + policy_)); + task_l.Wait(MTAPI_INFINITE); + task_r.Wait(MTAPI_INFINITE); + ChunkDescriptor chunk_f = partitioner_[chunk_first_]; + ChunkDescriptor chunk_m = partitioner_[chunk_split_index + 1]; + ChunkDescriptor chunk_l = partitioner_[chunk_last_]; if(CloneBackToInput()) { - difference_type first = std::distance(global_first_, functorL.first_); - difference_type mid = std::distance(global_first_, functorR.first_); - difference_type last = std::distance(global_first_, functorR.last_); - SerialMerge(temp_first_ + first, temp_first_ + mid, - temp_first_ + last, functorL.first_, comparison_); + difference_type first = std::distance(global_first_, chunk_f.GetFirst()); + difference_type mid = std::distance(global_first_, chunk_m.GetFirst()); + difference_type last = std::distance(global_first_, chunk_l.GetLast()); + SerialMerge(temp_first_ + first, temp_first_ + mid, temp_first_ + last, + chunk_f.GetFirst(), + comparison_); } else { - SerialMerge(functorL.first_, functorR.first_, functorR.last_, - temp_first_ + std::distance(global_first_, functorL.first_), + SerialMerge(chunk_f.GetFirst(), chunk_m.GetFirst(), chunk_l.GetLast(), + temp_first_ + std::distance(global_first_, chunk_f.GetFirst()), comparison_); } } @@ -114,12 +126,77 @@ class MergeSortFunctor { } private: - RAI first_; - RAI last_; + typedef MergeSortFunctor self_t; + + private: + /** + * Non-parallelized part of merge sort on elements within a single chunk. + */ + class MergeSortChunkFunctor { + public: + MergeSortChunkFunctor(RAI first, RAI last, + RAITemp temp_first, + const RAI & global_first, + int depth) + : first_(first), last_(last), + temp_first_(temp_first), global_first_(global_first), + depth_(depth) { + } + + void Action() { + size_t distance = static_cast( + std::distance(first_, last_)); + if (distance <= 1) { + // Leaf case: + if(!CloneBackToInput() && distance != 0) { + RAITemp temp_first = temp_first_; + std::advance(temp_first, std::distance(global_first_, first_)); + *temp_first = *first_; + } + return; + } + // Recurse further. Use binary split, ignoring chunk size as this + // recursion is serial: + ChunkPartitioner partitioner(first_, last_, 2); + ChunkDescriptor chunk_l = partitioner[0]; + ChunkDescriptor chunk_r = partitioner[1]; + MergeSortChunkFunctor functor_l( + chunk_l.GetFirst(), + chunk_l.GetLast(), + temp_first_, global_first_, depth_ + 1); + MergeSortChunkFunctor functor_r( + chunk_r.GetFirst(), + chunk_r.GetLast(), + temp_first_, global_first_, depth_ + 1); + functor_l.Action(); + functor_r.Action(); + } + + private: + /** + * Determines the input and output arrays for one level in merge sort. + * + * \return \c true if the temporary data range is input and the array to be + * sorted is output. \c false, if the other way around. + */ + bool CloneBackToInput() { + return depth_ % 2 == 0 ? true : false; + } + + RAI first_; + RAI last_; + RAITemp temp_first_; + RAI global_first_; + int depth_; + }; + + private: + size_t chunk_first_; + size_t chunk_last_; RAITemp temp_first_; ComparisonFunction comparison_; const embb::mtapi::ExecutionPolicy& policy_; - size_t block_size_; + const BlockSizePartitioner& partitioner_; const RAI& global_first_; int depth_; @@ -180,8 +257,16 @@ void MergeSort( "Not enough MTAPI tasks available to perform the merge sort"); } + internal::BlockSizePartitioner partitioner(first, last, block_size); + internal::MergeSortFunctor functor( - first, last, temporary_first, comparison, policy, block_size, first, 0); + 0, partitioner.Size() - 1, + temporary_first, + comparison, + policy, + partitioner, + first, + 0); mtapi::Task task = node.Spawn(mtapi::Action(base::MakeFunction(functor, &internal::MergeSortFunctor::Action), policy));