#ifndef PLS_SCHEDULER_IMPL_H
#define PLS_SCHEDULER_IMPL_H

#include <atomic>
#include <memory>
#include <thread>
#include <utility>

#include "context_switcher/context_switcher.h"
#include "context_switcher/continuation.h"

#include "pls/internal/scheduling/task_manager.h"
#include "pls/internal/scheduling/base_task.h"

#include "pls/internal/profiling/dag_node.h"

namespace pls::internal::scheduling {

template<typename ALLOC>
scheduler::scheduler(unsigned int num_threads,
                     size_t computation_depth,
                     size_t stack_size,
                     bool reuse_thread,
                     size_t serial_stack_size,
                     ALLOC &&stack_allocator) :
    num_threads_{num_threads},
    reuse_thread_{reuse_thread},
    sync_barrier_{num_threads + 1 - reuse_thread},
    worker_threads_{},
    thread_states_{},
    main_thread_starter_function_{nullptr},
    work_section_done_{false},
    terminated_{false},
    stack_allocator_{std::make_shared<ALLOC>(std::forward<ALLOC>(stack_allocator))},
    serial_stack_size_{serial_stack_size}
#if PLS_PROFILING_ENABLED
    , profiler_{num_threads}
#endif
{

  worker_threads_.reserve(num_threads);
  task_managers_.reserve(num_threads);
  thread_states_.reserve(num_threads);
  std::atomic<unsigned> num_spawned{0};
  for (unsigned int i = 0; i < num_threads_; i++) {
    auto &this_task_manager =
        task_managers_.emplace_back(std::make_unique<task_manager>(i,
                                                                   computation_depth,
                                                                   stack_size,
                                                                   stack_allocator_));
    auto &this_thread_state = thread_states_.emplace_back(std::make_unique<thread_state>(*this,
                                                                                         i,
                                                                                         *this_task_manager,
                                                                                         stack_allocator_,
                                                                                         serial_stack_size));

    if (reuse_thread && i == 0) {
      worker_threads_.emplace_back();
      num_spawned++;
      continue; // Skip over the first/main thread when re-using the user's thread, as the user's thread takes its place.
    }

    auto *this_thread_state_pointer = this_thread_state.get();

    worker_threads_.emplace_back([this_thread_state_pointer, &num_spawned] {
      thread_state::set(this_thread_state_pointer);
      num_spawned++;
      work_thread_main_loop();
    });
  }

  while (num_spawned < num_threads) {
    std::this_thread::yield();
  }
}
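
// Example construction (an illustrative sketch, not code from this project):
// `my_stack_allocator` is a hypothetical stand-in for whatever stack allocator
// type your build provides for the ALLOC parameter.
//
//   my_stack_allocator allocator{};
//   scheduler sched{8,                     // num_threads
//                   128,                   // computation_depth: pre-allocated task chain per thread
//                   4096,                  // stack_size of each task's coroutine stack
//                   true,                  // reuse_thread: the calling thread acts as worker 0
//                   4096,                  // serial_stack_size used for serial sections
//                   std::move(allocator)};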

class scheduler::init_function {
 public:
  virtual void run() = 0;
};
template<typename F>
class scheduler::init_function_impl : public init_function {
 public:
  explicit init_function_impl(F &function) : function_{function} {}
  void run() override {
    base_task *root_task = thread_state::get().get_active_task();
    root_task->attached_resources_.store(nullptr, std::memory_order_relaxed);

#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     root_task->profiling_node_);
    thread_state::get().get_scheduler().profiler_.task_prepare_stack_measure(thread_state::get().get_thread_id(),
                                                                             root_task->stack_memory_,
                                                                             root_task->stack_size_);
#endif

    root_task->run_as_task([root_task, this](auto cont) {
      root_task->is_synchronized_ = true;
      thread_state::get().main_continuation() = std::move(cont);
      function_();

      thread_state::get().get_scheduler().work_section_done_.store(true);
      PLS_ASSERT(thread_state::get().main_continuation().valid(), "Must return valid continuation from main task.");

#if PLS_SLEEP_WORKERS_ON_EMPTY
      thread_state::get().get_scheduler().empty_queue_reset_and_wake_all();
#endif

#if PLS_PROFILING_ENABLED
      thread_state::get().get_scheduler().profiler_.task_finish_stack_measure(thread_state::get().get_thread_id(),
                                                                              root_task->stack_memory_,
                                                                              root_task->stack_size_,
                                                                              root_task->profiling_node_);
      thread_state::get().get_scheduler().profiler_.task_stop_running(thread_state::get().get_thread_id(),
                                                                      root_task->profiling_node_);
#endif

      return std::move(thread_state::get().main_continuation());
    });
  }
 private:
  F &function_;
};

template<typename Function>
void scheduler::perform_work(Function work_section) {
  // Prepare main root task
  init_function_impl<Function> starter_function{work_section};
  main_thread_starter_function_ = &starter_function;

#if PLS_PROFILING_ENABLED
  auto *root_task = thread_state_for(0).get_active_task();
  auto *root_node = profiler_.start_profiler_run();
  root_task->profiling_node_ = root_node;
#endif
#if PLS_SLEEP_WORKERS_ON_EMPTY
  empty_queue_counter_.store(0, std::memory_order_relaxed);
  for (auto &thread_state : thread_states_) {
    thread_state->get_queue_empty_flag().store({EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY}, std::memory_order_relaxed);
  }
#endif

  work_section_done_ = false;
  if (reuse_thread_) {
    auto &my_state = thread_state_for(0);
    thread_state::set(&my_state); // Make THIS THREAD become the main worker

    sync_barrier_.wait(); // Trigger threads to wake up
    work_thread_work_section(); // Simply also perform the work section on the calling thread
    sync_barrier_.wait(); // Wait for threads to finish
  } else {
    // Simply trigger the others to do the work, this thread will sleep/wait for the time being
    sync_barrier_.wait(); // Trigger threads to wake up
    sync_barrier_.wait(); // Wait for threads to finish
  }
#if PLS_PROFILING_ENABLED
  profiler_.stop_profiler_run();
#endif
}
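
// Illustrative use of perform_work() (a sketch, not code from this project):
// the call publishes the work section to the workers via the sync barrier and
// blocks until the section and all tasks spawned inside it have finished.
// See the fork-join sketch after spawn_internal() below for what a work
// section typically contains.
//
//   sched.perform_work([] {
//     // fork-join work goes here, e.g. the root call of a parallel algorithm
//   });
//   // here every worker has passed the second barrier again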

template<typename Function>
void scheduler::spawn_internal(Function &&lambda) {
  if (thread_state::is_scheduler_active()) {
    thread_state &spawning_state = thread_state::get();

    base_task *last_task = spawning_state.get_active_task();
    base_task *spawned_task = last_task->next_;

    // We are at the end of our allocated tasks, go over to serial execution.
    if (spawned_task == nullptr) {
      scheduler::serial_internal(std::forward<Function>(lambda));
      return;
    }
    // We are in a serial section.
    // For now we simply stay serial. Later versions could add nested parallel sections.
    if (last_task->is_serial_section_) {
      lambda();
      return;
    }

    // Carry over the resources allocated by parents
    auto *attached_resources = last_task->attached_resources_.load(std::memory_order_relaxed);
    spawned_task->attached_resources_.store(attached_resources, std::memory_order_relaxed);

#if PLS_PROFILING_ENABLED
    spawning_state.get_scheduler().profiler_.task_prepare_stack_measure(spawning_state.get_thread_id(),
                                                                        spawned_task->stack_memory_,
                                                                        spawned_task->stack_size_);
    auto *child_dag_node = spawning_state.get_scheduler().profiler_.task_spawn_child(spawning_state.get_thread_id(),
                                                                                     last_task->profiling_node_);
    spawned_task->profiling_node_ = child_dag_node;
#endif

    auto continuation = spawned_task->run_as_task([last_task, spawned_task, lambda, &spawning_state](auto cont) {
      // allow stealing threads to continue the last task.
      last_task->continuation_ = std::move(cont);

      // we are now executing the new task, allow others to steal the last task's continuation.
      spawned_task->is_synchronized_ = true;
      spawning_state.set_active_task(spawned_task);

      spawning_state.get_task_manager().push_local_task(last_task);
#if PLS_SLEEP_WORKERS_ON_EMPTY
      data_structures::stamped_integer
          queue_empty_flag = spawning_state.get_queue_empty_flag().load(std::memory_order_relaxed);
      switch (queue_empty_flag.value_) {
        case EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY: {
          // The queue was not found empty, ignore it.
          break;
        }
        case EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY: {
          // Someone is trying to mark us empty and might be re-stealing right now.
          data_structures::stamped_integer
              queue_non_empty_flag{queue_empty_flag.stamp_ + 1, EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY};
          auto actual_empty_flag =
              spawning_state.get_queue_empty_flag().exchange(queue_non_empty_flag, std::memory_order_acq_rel);
          if (actual_empty_flag.value_ == EMPTY_QUEUE_STATE::QUEUE_EMPTY) {
            spawning_state.get_scheduler().empty_queue_decrease_counter_and_wake();
          }
          break;
        }
        case EMPTY_QUEUE_STATE::QUEUE_EMPTY: {
          // Someone already marked the queue empty, we must revert its action on the central queue.
          data_structures::stamped_integer
              queue_non_empty_flag{queue_empty_flag.stamp_ + 1, EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY};
          spawning_state.get_queue_empty_flag().store(queue_non_empty_flag, std::memory_order_release);
          spawning_state.get_scheduler().empty_queue_decrease_counter_and_wake();
          break;
        }
      }
#endif

      // execute the lambda itself, which could lead to a different thread returning.
#if PLS_PROFILING_ENABLED
      spawning_state.get_scheduler().profiler_.task_stop_running(spawning_state.get_thread_id(),
                                                                 last_task->profiling_node_);
      spawning_state.get_scheduler().profiler_.task_start_running(spawning_state.get_thread_id(),
                                                                  spawned_task->profiling_node_);
#endif
      lambda();

      thread_state &syncing_state = thread_state::get();
      PLS_ASSERT(syncing_state.get_active_task() == spawned_task,
                 "Task manager must always point its active task onto whats executing.");

      // try to pop a task off the syncing task manager.
      // possible outcomes:
      // - this is a different task manager, it must have an empty deque and fail
      // - this is the same task manager and someone stole the last task, thus this will fail
      // - this is the same task manager and no one stole the last task, thus this will succeed
      base_task *popped_task = syncing_state.get_task_manager().pop_local_task();
      if (popped_task) {
        // Fast path, simply continue execution where we left off before the spawn.
        PLS_ASSERT(popped_task == last_task,
                   "Fast path, nothing can have changed until here.");
        PLS_ASSERT(&spawning_state == &syncing_state,
                   "Fast path, we must only return if the task has not been stolen/moved to other thread.");
        PLS_ASSERT(last_task->continuation_.valid(),
                   "Fast path, no one can have continued working on the last task.");

        syncing_state.set_active_task(last_task);
#if PLS_PROFILING_ENABLED
        syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                          spawned_task->stack_memory_,
                                                                          spawned_task->stack_size_,
                                                                          spawned_task->profiling_node_);
        syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                  spawned_task->profiling_node_);
#endif
        return std::move(last_task->continuation_);
      } else {
        // Slow path, the last task was stolen. This path is common to sync() events.
#if PLS_PROFILING_ENABLED
        syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                          spawned_task->stack_memory_,
                                                                          spawned_task->stack_size_,
                                                                          spawned_task->profiling_node_);
        syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                  spawned_task->profiling_node_);
#endif
        auto continuation = slow_return(syncing_state, false);
        return continuation;
      }
    });

#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     last_task->profiling_node_);
#endif

    if (continuation.valid()) {
      // We jumped in here from the main loop, keep track!
      thread_state::get().main_continuation() = std::move(continuation);
    }
  } else {
    // Scheduler not active...
    lambda();
  }
}
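
// Illustrative fork-join pattern on top of spawn_internal() (a sketch; the
// public spawn()/sync() wrappers used below are assumed to forward to the
// corresponding *_internal functions and may be named differently):
//
//   int fib(int n) {
//     if (n < 2) return n;
//     int a, b;
//     scheduler::spawn([&] { a = fib(n - 1); });  // parent continuation becomes stealable
//     b = fib(n - 2);
//     scheduler::sync();                          // join before reading a and b
//     return a + b;
//   }
//
// Each nested spawn occupies one link of the pre-allocated task chain
// (computation_depth); once the chain is exhausted, spawn_internal() falls
// back to serial_internal() defined below.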

template<typename Function>
void scheduler::spawn_and_sync_internal(Function &&lambda) {
  thread_state &spawning_state = thread_state::get();

  base_task *last_task = spawning_state.get_active_task();
  base_task *spawned_task = last_task->next_;

  // We are at the end of our allocated tasks, go over to serial execution.
  if (spawned_task == nullptr) {
    scheduler::serial_internal(std::forward<Function>(lambda));
    return;
  }
  // We are in a serial section.
  // For now we simply stay serial. Later versions could add nested parallel sections.
  if (last_task->is_serial_section_) {
    lambda();
    return;
  }

  // Carry over the resources allocated by parents
  auto *attached_resources = last_task->attached_resources_.load(std::memory_order_relaxed);
  spawned_task->attached_resources_.store(attached_resources, std::memory_order_relaxed);

  auto continuation = spawned_task->run_as_task([last_task, spawned_task, lambda, &spawning_state](auto cont) {
    // allow stealing threads to continue the last task.
    last_task->continuation_ = std::move(cont);

    // we are now executing the new task, allow others to steal the last task's continuation.
    spawned_task->is_synchronized_ = true;
    spawning_state.set_active_task(spawned_task);

    // execute the lambda itself, which could lead to a different thread returning.
    lambda();

    thread_state &syncing_state = thread_state::get();
    PLS_ASSERT(syncing_state.get_active_task() == spawned_task,
               "Task manager must always point its active task onto whats executing.");

    if (&syncing_state == &spawning_state && last_task->is_synchronized_) {
      // Fast path, simply continue execution where we left off before the spawn.
      PLS_ASSERT(last_task->continuation_.valid(),
                 "Fast path, no one can have continued working on the last task.");

      syncing_state.set_active_task(last_task);
      return std::move(last_task->continuation_);
    } else {
      // Slow path, the last task was stolen. This path is common to sync() events.
      auto continuation = slow_return(syncing_state, false);
      return continuation;
    }
  });

  PLS_ASSERT(!continuation.valid(), "We must not jump to a not-published (synced) spawn.");
}
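
// spawn_and_sync_internal() covers the "last spawn before a sync" pattern:
// in contrast to spawn_internal(), the parent task is never pushed onto the
// work-stealing deque; the join happens right after the spawned lambda.
// A plausible use, assuming matching public wrappers (hypothetical names):
//
//   scheduler::spawn([&] { a = fib(n - 1); });
//   scheduler::spawn_and_sync([&] { b = fib(n - 2); });  // no separate sync() needed afterwards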

template<typename Function>
void scheduler::serial_internal(Function &&lambda) {
  thread_state &spawning_state = thread_state::get();
  base_task *active_task = spawning_state.get_active_task();

  if (active_task->is_serial_section_) {
    lambda();
  } else {
    active_task->is_serial_section_ = true;
    context_switcher::enter_context(spawning_state.get_serial_call_stack(),
                                    spawning_state.get_serial_call_stack_size(),
                                    [&](auto cont) {
                                      lambda();
                                      return cont;
                                    });
    active_task->is_serial_section_ = false;
  }
}
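
// Note on serial_internal(): the lambda runs on the thread's dedicated serial
// stack (sized by the serial_stack_size constructor argument) and the active
// task is flagged as a serial section, so nested spawn_internal() calls
// simply degrade to plain function calls until the section is left again.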

} // namespace pls::internal::scheduling

#endif //PLS_SCHEDULER_IMPL_H