/*
 * Copyright (c) 2014-2016, Siemens AG. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#ifndef EMBB_ALGORITHMS_INTERNAL_MERGE_SORT_INL_H_
#define EMBB_ALGORITHMS_INTERNAL_MERGE_SORT_INL_H_

#include <cassert>
#include <iterator>
#include <functional>

#include <embb/base/exceptions.h>
#include <embb/mtapi/mtapi.h>
#include <embb/algorithms/internal/partition.h>

namespace embb {
namespace algorithms {

namespace internal {

/**
 * Contains the merge sort MTAPI action function and data needed there.
 */
template <typename RAI, typename RAITemp, typename ComparisonFunction>
class MergeSortFunctor {
 public:
  typedef typename std::iterator_traits<RAI>::value_type value_type;

51 52
  MergeSortFunctor(size_t chunk_first, size_t chunk_last,
                   RAITemp temporary_first, ComparisonFunction comparison,
53
                   const embb::mtapi::ExecutionPolicy& policy,
54 55
                   const BlockSizePartitioner<RAI>& partitioner,
                   const RAI& global_first, int depth)
56 57 58 59
  : chunk_first_(chunk_first), chunk_last_(chunk_last),
    temp_first_(temporary_first),
    comparison_(comparison), policy_(policy), partitioner_(partitioner),
    global_first_(global_first), depth_(depth) {
60 61
  }

62
  void Action(embb::mtapi::TaskContext&) {
63 64 65 66
    size_t chunk_split_index = (chunk_first_ + chunk_last_) / 2;
    if (chunk_first_ == chunk_last_) {
      // Leaf case: recurse into a single chunk's elements:
      ChunkDescriptor<RAI> chunk = partitioner_[chunk_first_];
67
      MergeSortChunk(chunk.GetFirst(), chunk.GetLast(), depth_);
68
    } else {
69 70 71 72 73 74 75 76 77 78 79
      // Recurse further, split chunks:
      self_t functor_l(chunk_first_,
                       chunk_split_index,
                       temp_first_,
                       comparison_, policy_, partitioner_,
                       global_first_, depth_ + 1);
      self_t functor_r(chunk_split_index + 1,
                       chunk_last_,
                       temp_first_,
                       comparison_, policy_, partitioner_,
                       global_first_, depth_ + 1);
80 81 82 83 84 85 86
      embb::mtapi::Node& node = embb::mtapi::Node::GetInstance();
      embb::mtapi::Task task_l = node.Start(
        base::MakeFunction(functor_l, &self_t::Action),
        policy_);
      embb::mtapi::Task task_r = node.Start(
        base::MakeFunction(functor_r, &self_t::Action),
        policy_);
87 88 89
      task_l.Wait(MTAPI_INFINITE);
      task_r.Wait(MTAPI_INFINITE);

90 91 92
      ChunkDescriptor<RAI> ck_f = partitioner_[chunk_first_];
      ChunkDescriptor<RAI> ck_m = partitioner_[chunk_split_index + 1];
      ChunkDescriptor<RAI> ck_l = partitioner_[chunk_last_];
93 94
      if(CloneBackToInput(depth_)) {
        // Merge from temp into input:
95 96 97
        difference_type first = std::distance(global_first_, ck_f.GetFirst());
        difference_type mid   = std::distance(global_first_, ck_m.GetFirst());
        difference_type last  = std::distance(global_first_, ck_l.GetLast());
98
        SerialMerge(temp_first_ + first, temp_first_ + mid, temp_first_ + last,
99
                    ck_f.GetFirst(),
100 101 102
                    comparison_);
      } else {
        // Merge from input into temp:
103 104
        SerialMerge(ck_f.GetFirst(), ck_m.GetFirst(), ck_l.GetLast(),
                    temp_first_ + std::distance(global_first_, ck_f.GetFirst()),
105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
                    comparison_);
      }
    }
  }

  /**
   * Serial merge sort of elements within a single chunk.
   */
  void MergeSortChunk(RAI first,
                      RAI last,
                      int depth) {
    size_t distance = static_cast<size_t>(
      std::distance(first, last));
    if (distance <= 1) {
      // Leaf case:
      if (!CloneBackToInput(depth) && distance != 0) {
        RAITemp temp_first = temp_first_;
        std::advance(temp_first, std::distance(global_first_, first));
        *temp_first = *first;
      }
      return;
    }
    // Recurse further. Use binary split, ignoring chunk size as this
    // recursion is serial and has leaf size 1:
    ChunkPartitioner<RAI> partitioner(first, last, 2);
130 131
    ChunkDescriptor<RAI> ck_l = partitioner[0];
    ChunkDescriptor<RAI> ck_r = partitioner[1];
132
    MergeSortChunk(
133 134
      ck_l.GetFirst(),
      ck_l.GetLast(),
135 136
      depth + 1);
    MergeSortChunk(
137 138
      ck_r.GetFirst(),
      ck_r.GetLast(),
139 140 141
      depth + 1);
    if (CloneBackToInput(depth)) {
      // Merge from temp into input:
142 143 144
      difference_type d_first = std::distance(global_first_, ck_l.GetFirst());
      difference_type d_mid   = std::distance(global_first_, ck_r.GetFirst());
      difference_type d_last  = std::distance(global_first_, ck_r.GetLast());
145 146
      SerialMerge(
        temp_first_ + d_first, temp_first_ + d_mid, temp_first_ + d_last,
147
        ck_l.GetFirst(),
148
        comparison_);
149
    } else {
150 151
      // Merge from input into temp:
      SerialMerge(
152 153
        ck_l.GetFirst(), ck_r.GetFirst(), ck_r.GetLast(),
        temp_first_ + std::distance(global_first_, ck_l.GetFirst()),
154
        comparison_);
155 156 157 158 159 160 161 162 163
    }
  }

  /**
   * Determines the input and output arrays for one level in merge sort.
   *
   * \return \c true if the temporary data range is input and the array to be
   *         sorted is output. \c false, if the other way around.
   */
164 165
  bool CloneBackToInput(int depth) {
    return depth % 2 == 0 ? true : false;
166 167 168
  }

 private:
169
  typedef MergeSortFunctor<RAI, RAITemp, ComparisonFunction> self_t;
170 171
  typedef typename std::iterator_traits<RAI>::difference_type
    difference_type;
172 173 174 175

 private:
  size_t chunk_first_;
  size_t chunk_last_;
176 177
  RAITemp temp_first_;
  ComparisonFunction comparison_;
178
  const embb::mtapi::ExecutionPolicy& policy_;
179
  const BlockSizePartitioner<RAI>& partitioner_;
180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214
  const RAI& global_first_;
  int depth_;

  MergeSortFunctor(const MergeSortFunctor&);
  MergeSortFunctor& operator=(const MergeSortFunctor&);

  template<typename RAIIn, typename RAIOut>
  void SerialMerge(RAIIn first, RAIIn mid, RAIIn last, RAIOut out,
                   ComparisonFunction comparison) {
    RAIIn save_mid = mid;
    while ((first != save_mid) && (mid != last)) {
      if (comparison(*first, *mid)) {
        *out = *first;
        ++out;
        ++first;
      } else {
        *out = *mid;
        ++out;
        ++mid;
      }
    }
    while (first != save_mid) {
      *out = *first;
      ++out;
      ++first;
    }
    while(mid != last) {
      *out = *mid;
      ++out;
      ++mid;
    }
  }
};

template<typename RAI, typename RAITemp, typename ComparisonFunction>
215
void MergeSortIteratorCheck(
216 217 218 219
  RAI first,
  RAI last,
  RAITemp temporary_first,
  ComparisonFunction comparison,
220
  const embb::mtapi::ExecutionPolicy& policy,
221 222
  size_t block_size,
  std::random_access_iterator_tag
223 224
  ) {
  typedef typename std::iterator_traits<RAI>::difference_type difference_type;
225
  typedef MergeSortFunctor<RAI, RAITemp, ComparisonFunction>
226
    functor_t;
227 228
  difference_type distance = std::distance(first, last);
  if (distance == 0) {
229 230 231
    return;
  } else if (distance < 0) {
    EMBB_THROW(embb::base::ErrorException, "Negative range for MergeSort");
232
  }
233 234 235 236
  unsigned int num_cores = policy.GetCoreCount();
  if (num_cores == 0) {
    EMBB_THROW(embb::base::ErrorException, "No cores in execution policy");
  }
237
  // Determine actually used block size
238
  if (block_size == 0) {
239
    block_size = (static_cast<size_t>(distance) / num_cores);
240 241 242
    if (block_size == 0)
      block_size = 1;
  }
243 244
  // Check task number sufficiency
  if (((distance / block_size) * 2) + 1 > MTAPI_NODE_MAX_TASKS_DEFAULT) {
245
    EMBB_THROW(embb::base::ErrorException,
246
               "Not enough MTAPI tasks available to perform merge sort");
247 248
  }

249
  BlockSizePartitioner<RAI> partitioner(first, last, block_size);
250 251 252 253 254 255 256 257
  functor_t functor(0,
                    partitioner.Size() - 1,
                    temporary_first,
                    comparison,
                    policy,
                    partitioner,
                    first,
                    0);
258 259 260
  embb::mtapi::Task task = embb::mtapi::Node::GetInstance().Start(
    base::MakeFunction(functor, &functor_t::Action),
    policy);
261 262 263 264

  task.Wait(MTAPI_INFINITE);
}

}  // namespace internal

template<typename RAI, typename RAITemp, typename ComparisonFunction>
void MergeSort(RAI first, RAI last, RAITemp temporary_first,
269
  ComparisonFunction comparison, const embb::mtapi::ExecutionPolicy& policy,
270 271 272 273 274 275
  size_t block_size) {
  typedef typename std::iterator_traits<RAI>::iterator_category category;
  internal::MergeSortIteratorCheck(first, last, temporary_first, comparison,
    policy, block_size, category());
}

}  // namespace algorithms
}  // namespace embb

#endif  // EMBB_ALGORITHMS_INTERNAL_MERGE_SORT_INL_H_