/*
 * Copyright (c) 2014-2016, Siemens AG. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#ifndef EMBB_ALGORITHMS_INTERNAL_MERGE_SORT_INL_H_
#define EMBB_ALGORITHMS_INTERNAL_MERGE_SORT_INL_H_

#include <algorithm>
#include <cassert>
#include <functional>
#include <iterator>

#include <embb/base/exceptions.h>
#include <embb/tasks/tasks.h>
#include <embb/algorithms/internal/partition.h>

namespace embb {
namespace algorithms {

namespace internal {

/**
 * Contains the merge sort MTAPI action function and data needed there.
 */
template <typename RAI, typename RAITemp, typename ComparisonFunction>
class MergeSortFunctor {
 public:
  typedef typename std::iterator_traits<RAI>::value_type value_type;

51 52
  MergeSortFunctor(size_t chunk_first, size_t chunk_last,
                   RAITemp temporary_first, ComparisonFunction comparison,
53
                   const embb::tasks::ExecutionPolicy& policy,
54 55
                   const BlockSizePartitioner<RAI>& partitioner,
                   const RAI& global_first, int depth)
56 57 58 59
  : chunk_first_(chunk_first), chunk_last_(chunk_last),
    temp_first_(temporary_first),
    comparison_(comparison), policy_(policy), partitioner_(partitioner),
    global_first_(global_first), depth_(depth) {
60 61
  }

62
  void Action(embb::tasks::TaskContext&) {
63 64 65 66
    size_t chunk_split_index = (chunk_first_ + chunk_last_) / 2;
    if (chunk_first_ == chunk_last_) {
      // Leaf case: recurse into a single chunk's elements:
      ChunkDescriptor<RAI> chunk = partitioner_[chunk_first_];
67
      MergeSortChunk(chunk.GetFirst(), chunk.GetLast(), depth_);
68
    } else {
69 70 71 72 73 74 75 76 77 78 79
      // Recurse further, split chunks:
      self_t functor_l(chunk_first_,
                       chunk_split_index,
                       temp_first_,
                       comparison_, policy_, partitioner_,
                       global_first_, depth_ + 1);
      self_t functor_r(chunk_split_index + 1,
                       chunk_last_,
                       temp_first_,
                       comparison_, policy_, partitioner_,
                       global_first_, depth_ + 1);
80 81 82
      embb::tasks::Node& node = embb::tasks::Node::GetInstance();
      embb::tasks::Task task_l = node.Spawn(
        embb::tasks::Action(
83 84
          base::MakeFunction(functor_l, &self_t::Action),
          policy_));
85 86
      embb::tasks::Task task_r = node.Spawn(
        embb::tasks::Action(
87 88 89 90 91
          base::MakeFunction(functor_r, &self_t::Action),
          policy_));
      task_l.Wait(MTAPI_INFINITE);
      task_r.Wait(MTAPI_INFINITE);

92 93 94
      ChunkDescriptor<RAI> ck_f = partitioner_[chunk_first_];
      ChunkDescriptor<RAI> ck_m = partitioner_[chunk_split_index + 1];
      ChunkDescriptor<RAI> ck_l = partitioner_[chunk_last_];
95 96
      if(CloneBackToInput(depth_)) {
        // Merge from temp into input:
97 98 99
        difference_type first = std::distance(global_first_, ck_f.GetFirst());
        difference_type mid   = std::distance(global_first_, ck_m.GetFirst());
        difference_type last  = std::distance(global_first_, ck_l.GetLast());
100
        SerialMerge(temp_first_ + first, temp_first_ + mid, temp_first_ + last,
101
                    ck_f.GetFirst(),
102 103 104
                    comparison_);
      } else {
        // Merge from input into temp:
105 106
        SerialMerge(ck_f.GetFirst(), ck_m.GetFirst(), ck_l.GetLast(),
                    temp_first_ + std::distance(global_first_, ck_f.GetFirst()),
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
                    comparison_);
      }
    }
  }

  /**
   * Serial merge sort of elements within a single chunk.
   */
  void MergeSortChunk(RAI first,
                      RAI last,
                      int depth) {
    size_t distance = static_cast<size_t>(
      std::distance(first, last));
    if (distance <= 1) {
      // Leaf case:
      if (!CloneBackToInput(depth) && distance != 0) {
        RAITemp temp_first = temp_first_;
        std::advance(temp_first, std::distance(global_first_, first));
        *temp_first = *first;
      }
      return;
    }
    // Recurse further. Use binary split, ignoring chunk size as this
    // recursion is serial and has leaf size 1:
    ChunkPartitioner<RAI> partitioner(first, last, 2);
132 133
    ChunkDescriptor<RAI> ck_l = partitioner[0];
    ChunkDescriptor<RAI> ck_r = partitioner[1];
134
    MergeSortChunk(
135 136
      ck_l.GetFirst(),
      ck_l.GetLast(),
137 138
      depth + 1);
    MergeSortChunk(
139 140
      ck_r.GetFirst(),
      ck_r.GetLast(),
141 142 143
      depth + 1);
    if (CloneBackToInput(depth)) {
      // Merge from temp into input:
144 145 146
      difference_type d_first = std::distance(global_first_, ck_l.GetFirst());
      difference_type d_mid   = std::distance(global_first_, ck_r.GetFirst());
      difference_type d_last  = std::distance(global_first_, ck_r.GetLast());
147 148
      SerialMerge(
        temp_first_ + d_first, temp_first_ + d_mid, temp_first_ + d_last,
149
        ck_l.GetFirst(),
150
        comparison_);
151
    } else {
152 153
      // Merge from input into temp:
      SerialMerge(
154 155
        ck_l.GetFirst(), ck_r.GetFirst(), ck_r.GetLast(),
        temp_first_ + std::distance(global_first_, ck_l.GetFirst()),
156
        comparison_);
157 158 159 160 161 162 163 164 165
    }
  }

  /**
   * Determines the input and output arrays for one level in merge sort.
   *
   * \return \c true if the temporary data range is input and the array to be
   *         sorted is output. \c false, if the other way around.
   */
166 167
  bool CloneBackToInput(int depth) {
    return depth % 2 == 0 ? true : false;
168 169 170
  }

 private:
171
  typedef MergeSortFunctor<RAI, RAITemp, ComparisonFunction> self_t;
172 173
  typedef typename std::iterator_traits<RAI>::difference_type
    difference_type;
174 175 176 177

 private:
  size_t chunk_first_;
  size_t chunk_last_;
178 179
  RAITemp temp_first_;
  ComparisonFunction comparison_;
180
  const embb::tasks::ExecutionPolicy& policy_;
181
  const BlockSizePartitioner<RAI>& partitioner_;
182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216
  const RAI& global_first_;
  int depth_;

  MergeSortFunctor(const MergeSortFunctor&);
  MergeSortFunctor& operator=(const MergeSortFunctor&);

  template<typename RAIIn, typename RAIOut>
  void SerialMerge(RAIIn first, RAIIn mid, RAIIn last, RAIOut out,
                   ComparisonFunction comparison) {
    RAIIn save_mid = mid;
    while ((first != save_mid) && (mid != last)) {
      if (comparison(*first, *mid)) {
        *out = *first;
        ++out;
        ++first;
      } else {
        *out = *mid;
        ++out;
        ++mid;
      }
    }
    while (first != save_mid) {
      *out = *first;
      ++out;
      ++first;
    }
    while(mid != last) {
      *out = *mid;
      ++out;
      ++mid;
    }
  }
};

template<typename RAI, typename RAITemp, typename ComparisonFunction>
217
void MergeSortIteratorCheck(
218 219 220 221
  RAI first,
  RAI last,
  RAITemp temporary_first,
  ComparisonFunction comparison,
222
  const embb::tasks::ExecutionPolicy& policy,
223 224
  size_t block_size,
  std::random_access_iterator_tag
225 226
  ) {
  typedef typename std::iterator_traits<RAI>::difference_type difference_type;
227
  typedef MergeSortFunctor<RAI, RAITemp, ComparisonFunction>
228
    functor_t;
229 230
  difference_type distance = std::distance(first, last);
  if (distance == 0) {
231 232 233
    return;
  } else if (distance < 0) {
    EMBB_THROW(embb::base::ErrorException, "Negative range for MergeSort");
234
  }
235 236 237 238
  unsigned int num_cores = policy.GetCoreCount();
  if (num_cores == 0) {
    EMBB_THROW(embb::base::ErrorException, "No cores in execution policy");
  }
239
  // Determine actually used block size
240
  if (block_size == 0) {
241
    block_size = (static_cast<size_t>(distance) / num_cores);
242 243 244
    if (block_size == 0)
      block_size = 1;
  }
245 246
  // Check task number sufficiency
  if (((distance / block_size) * 2) + 1 > MTAPI_NODE_MAX_TASKS_DEFAULT) {
247
    EMBB_THROW(embb::base::ErrorException,
248
               "Not enough MTAPI tasks available to perform merge sort");
249 250
  }

251
  BlockSizePartitioner<RAI> partitioner(first, last, block_size);
252 253 254 255 256 257 258 259
  functor_t functor(0,
                    partitioner.Size() - 1,
                    temporary_first,
                    comparison,
                    policy,
                    partitioner,
                    first,
                    0);
260 261
  embb::tasks::Task task = embb::tasks::Node::GetInstance().Spawn(
    embb::tasks::Action(
262 263
      base::MakeFunction(functor, &functor_t::Action),
      policy));
264 265 266 267

  task.Wait(MTAPI_INFINITE);
}

268 269 270 271
}  // namespace internal

template<typename RAI, typename RAITemp, typename ComparisonFunction>
void MergeSort(RAI first, RAI last, RAITemp temporary_first,
272
  ComparisonFunction comparison, const embb::tasks::ExecutionPolicy& policy,
273 274 275 276 277 278
  size_t block_size) {
  typedef typename std::iterator_traits<RAI>::iterator_category category;
  internal::MergeSortIteratorCheck(first, last, temporary_first, comparison,
    policy, block_size, category());
}

279 280 281 282
}  // namespace algorithms
}  // namespace embb

#endif  // EMBB_ALGORITHMS_INTERNAL_MERGE_SORT_INL_H_