
#ifndef PLS_SCHEDULER_IMPL_H
#define PLS_SCHEDULER_IMPL_H

#include <utility>

#include "context_switcher/context_switcher.h"
#include "context_switcher/continuation.h"

#include "pls/internal/scheduling/task_manager.h"
#include "pls/internal/scheduling/base_task.h"

#include "pls/internal/profiling/dag_node.h"

namespace pls::internal::scheduling {

template<typename ALLOC>
scheduler::scheduler(unsigned int num_threads,
                     size_t computation_depth,
                     size_t stack_size,
                     bool reuse_thread,
                     size_t serial_stack_size,
                     ALLOC &&stack_allocator) :
    num_threads_{num_threads},
    reuse_thread_{reuse_thread},
    sync_barrier_{num_threads + 1 - reuse_thread},
    worker_threads_{},
    thread_states_{},
    main_thread_starter_function_{nullptr},
    work_section_done_{false},
    terminated_{false},
    stack_allocator_{std::make_shared<ALLOC>(std::forward<ALLOC>(stack_allocator))},
    serial_stack_size_{serial_stack_size}
#if PLS_PROFILING_ENABLED
, profiler_{num_threads}
#endif
{

  worker_threads_.reserve(num_threads);
  task_managers_.reserve(num_threads);
  thread_states_.reserve(num_threads);
  std::atomic<unsigned> num_spawned{0};
  for (unsigned int i = 0; i < num_threads_; i++) {
    auto &this_task_manager =
        task_managers_.emplace_back(std::make_unique<task_manager>(i,
                                                                   computation_depth,
                                                                   stack_size,
                                                                   stack_allocator_));
    auto &this_thread_state = thread_states_.emplace_back(std::make_unique<thread_state>(*this,
                                                                                         i,
                                                                                         *this_task_manager,
                                                                                         stack_allocator_,
                                                                                         serial_stack_size));

    if (reuse_thread && i == 0) {
      worker_threads_.emplace_back();
      num_spawned++;
      continue; // Skip over the first/main thread when re-using the user's thread, as it will replace the first worker.
    }

    auto *this_thread_state_pointer = this_thread_state.get();

    worker_threads_.emplace_back([this_thread_state_pointer, &num_spawned] {
      thread_state::set(this_thread_state_pointer);
      num_spawned++;
      work_thread_main_loop();
    });
  }

  while (num_spawned < num_threads) {
    std::this_thread::yield();
  }
}
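// Illustrative usage sketch (not part of the implementation): constructing a scheduler.
// The allocator object below is a placeholder for whatever stack-allocator type the
// library provides; only the constructor signature above is authoritative.
//
//   // 8 workers, spawn depth of 64, 512 KiB task stacks, re-use the calling thread
//   // as worker 0, and 4 MiB stacks for serial sections.
//   scheduler my_scheduler{8, 64, 512 * 1024, true, 4 * 1024 * 1024, my_stack_allocator};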

class scheduler::init_function {
 public:
  virtual void run() = 0;
};
template<typename F>
class scheduler::init_function_impl : public init_function {
 public:
  explicit init_function_impl(F &function) : function_{function} {}
  void run() override {
    base_task *root_task = thread_state::get().get_active_task();
    root_task->attached_resources_.store(nullptr, std::memory_order_relaxed);

#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     root_task->profiling_node_);
    thread_state::get().get_scheduler().profiler_.task_prepare_stack_measure(thread_state::get().get_thread_id(),
                                                                             root_task->stack_memory_,
                                                                             root_task->stack_size_);
#endif

    root_task->run_as_task([root_task, this](auto cont) {
      root_task->is_synchronized_ = true;
      thread_state::get().main_continuation() = std::move(cont);
      function_();

      thread_state::get().get_scheduler().work_section_done_.store(true);
      PLS_ASSERT(thread_state::get().main_continuation().valid(), "Must return valid continuation from main task.");

#if PLS_SLEEP_WORKERS_ON_EMPTY
      thread_state::get().get_scheduler().empty_queue_counter_.store(-1000);
      thread_state::get().get_scheduler().empty_queue_decrease_counter_and_wake();
#endif

#if PLS_PROFILING_ENABLED
      thread_state::get().get_scheduler().profiler_.task_finish_stack_measure(thread_state::get().get_thread_id(),
                                                                              root_task->stack_memory_,
                                                                              root_task->stack_size_,
                                                                              root_task->profiling_node_);
      thread_state::get().get_scheduler().profiler_.task_stop_running(thread_state::get().get_thread_id(),
                                                                      root_task->profiling_node_);
#endif

      return std::move(thread_state::get().main_continuation());
    });
  }
 private:
  F &function_;
};
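// Note: init_function/init_function_impl type-erase the user's work-section lambda so that
// perform_work() can publish it to the worker running the root task through a plain
// init_function* (main_thread_starter_function_), independent of the lambda's concrete type.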

template<typename Function>
void scheduler::perform_work(Function work_section) {
  // Prepare main root task
  init_function_impl<Function> starter_function{work_section};
  main_thread_starter_function_ = &starter_function;

#if PLS_PROFILING_ENABLED
  auto *root_task = thread_state_for(0).get_active_task();
  auto *root_node = profiler_.start_profiler_run();
  root_task->profiling_node_ = root_node;
#endif
#if PLS_SLEEP_WORKERS_ON_EMPTY
  empty_queue_counter_.store(0);
  for (auto &thread_state : thread_states_) {
    thread_state->get_queue_empty_flag().store({EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY});
  }
#endif

  work_section_done_ = false;
  if (reuse_thread_) {
    auto &my_state = thread_state_for(0);
    thread_state::set(&my_state); // Make THIS THREAD become the main worker

    sync_barrier_.wait(); // Trigger threads to wake up
    work_thread_work_section(); // Also perform the work section on this (the main) thread
    sync_barrier_.wait(); // Wait for threads to finish
  } else {
    // Simply trigger the others to do the work; this thread will sleep/wait for the time being.
    sync_barrier_.wait(); // Trigger threads to wake up
    sync_barrier_.wait(); // Wait for threads to finish
  }
#if PLS_PROFILING_ENABLED
  profiler_.stop_profiler_run();
#endif
}
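// Illustrative usage sketch (assumption): a fork-join work section driven through
// perform_work. The scheduler::spawn/scheduler::sync calls below are assumed to be the
// public wrappers around the *_internal helpers defined further down in this header.
//
//   my_scheduler.perform_work([] {
//     scheduler::spawn([] { /* left half of the computation */ });
//     scheduler::spawn([] { /* right half of the computation */ });
//     scheduler::sync(); // join both spawned tasks before leaving the work section
//   });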

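// spawn_internal implements a continuation-stealing spawn: the spawned lambda runs inline on
// the current thread while the parent's continuation is published via push_local_task() so idle
// workers can steal it. When the lambda returns, a successful pop_local_task() means the parent
// was not stolen and we resume it directly (fast path); otherwise slow_return() coordinates the
// hand-off with the thief (slow path).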
template<typename Function>
void scheduler::spawn_internal(Function &&lambda) {
  if (thread_state::is_scheduler_active()) {
    thread_state &spawning_state = thread_state::get();

    base_task *last_task = spawning_state.get_active_task();
    base_task *spawned_task = last_task->next_;

    // We are at the end of our allocated tasks; switch over to serial execution.
    if (spawned_task == nullptr) {
      scheduler::serial_internal(std::forward<Function>(lambda));
      return;
    }
    // We are in a serial section.
    // For now we simply stay serial. Later versions could add nested parallel sections.
    if (last_task->is_serial_section_) {
      lambda();
      return;
    }

    // Carry over the resources allocated by parents
    auto *attached_resources = last_task->attached_resources_.load(std::memory_order_relaxed);
    spawned_task->attached_resources_.store(attached_resources, std::memory_order_relaxed);

#if PLS_PROFILING_ENABLED
    spawning_state.get_scheduler().profiler_.task_prepare_stack_measure(spawning_state.get_thread_id(),
                                                                        spawned_task->stack_memory_,
                                                                        spawned_task->stack_size_);
    auto *child_dag_node = spawning_state.get_scheduler().profiler_.task_spawn_child(spawning_state.get_thread_id(),
                                                                                     last_task->profiling_node_);
    spawned_task->profiling_node_ = child_dag_node;
#endif

    auto continuation = spawned_task->run_as_task([last_task, spawned_task, lambda, &spawning_state](auto cont) {
      // allow stealing threads to continue the last task.
      last_task->continuation_ = std::move(cont);

      // we are now executing the new task, allow others to steal the last task's continuation.
      spawned_task->is_synchronized_ = true;
      spawning_state.set_active_task(spawned_task);

      spawning_state.get_task_manager().push_local_task(last_task);
#if PLS_SLEEP_WORKERS_ON_EMPTY
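      // Wake-up bookkeeping for sleeping workers (three-state flag protocol, handled below):
      //  - QUEUE_NON_EMPTY:   nobody considers our queue empty, nothing to do.
      //  - QUEUE_MAYBE_EMPTY: a thief is in the middle of marking us empty; flip the flag back
      //    atomically and only wake sleepers if the thief already reached QUEUE_EMPTY.
      //  - QUEUE_EMPTY:       we were marked empty; revert the flag and wake sleepers.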
      data_structures::stamped_integer
          queue_empty_flag = spawning_state.get_queue_empty_flag().load(std::memory_order_relaxed);
      switch (queue_empty_flag.value_) {
        case EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY: {
          // The queue was not found empty, ignore it.
          break;
        }
        case EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY: {
          // Someone is trying to mark us empty and might be re-stealing right now.
          data_structures::stamped_integer
              queue_non_empty_flag{queue_empty_flag.stamp_++, EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY};
          auto actual_empty_flag =
              spawning_state.get_queue_empty_flag().exchange(queue_non_empty_flag, std::memory_order_acq_rel);
          if (actual_empty_flag.value_ == EMPTY_QUEUE_STATE::QUEUE_EMPTY) {
            spawning_state.get_scheduler().empty_queue_decrease_counter_and_wake();
          }
          break;
        }
        case EMPTY_QUEUE_STATE::QUEUE_EMPTY: {
          // Someone already marked the queue empty; we must revert that action on the central queue.
          data_structures::stamped_integer
              queue_non_empty_flag{queue_empty_flag.stamp_++, EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY};
          spawning_state.get_queue_empty_flag().store(queue_non_empty_flag, std::memory_order_release);
          spawning_state.get_scheduler().empty_queue_decrease_counter_and_wake();
          break;
        }
      }
#endif

      // execute the lambda itself, which could lead to a different thread returning.
#if PLS_PROFILING_ENABLED
      spawning_state.get_scheduler().profiler_.task_stop_running(spawning_state.get_thread_id(),
                                                                 last_task->profiling_node_);
      spawning_state.get_scheduler().profiler_.task_start_running(spawning_state.get_thread_id(),
                                                                  spawned_task->profiling_node_);
#endif
      lambda();

      thread_state &syncing_state = thread_state::get();
      PLS_ASSERT(syncing_state.get_active_task() == spawned_task,
                 "Task manager must always point its active task onto whats executing.");

      // try to pop a task off the syncing task manager.
      // possible outcomes:
      // - this is a different task manager, it must have an empty deque and the pop will fail
      // - this is the same task manager and someone stole the last task, thus the pop will fail
      // - this is the same task manager and no one stole the last task, thus the pop will succeed
      base_task *popped_task = syncing_state.get_task_manager().pop_local_task();
      if (popped_task) {
        // Fast path, simply continue execution where we left off before the spawn.
        PLS_ASSERT(popped_task == last_task,
                   "Fast path, nothing can have changed until here.");
        PLS_ASSERT(&spawning_state == &syncing_state,
                   "Fast path, we must only return if the task has not been stolen/moved to other thread.");
        PLS_ASSERT(last_task->continuation_.valid(),
                   "Fast path, no one can have continued working on the last task.");

        syncing_state.set_active_task(last_task);
#if PLS_PROFILING_ENABLED
        syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                          spawned_task->stack_memory_,
                                                                          spawned_task->stack_size_,
                                                                          spawned_task->profiling_node_);
        syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                  spawned_task->profiling_node_);
#endif
        return std::move(last_task->continuation_);
      } else {
        // Slow path, the last task was stolen. This path is common to sync() events.
#if PLS_PROFILING_ENABLED
        syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                          spawned_task->stack_memory_,
                                                                          spawned_task->stack_size_,
                                                                          spawned_task->profiling_node_);
        syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                  spawned_task->profiling_node_);
#endif
        auto continuation = slow_return(syncing_state, false);
        return continuation;
      }
    });

#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     last_task->profiling_node_);
#endif

    if (continuation.valid()) {
      // We jumped in here from the main loop, keep track!
      thread_state::get().main_continuation() = std::move(continuation);
    }
  } else {
    // Scheduler not active, simply run the lambda sequentially on the calling thread.
    lambda();
  }
}

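// spawn_and_sync_internal behaves like spawn_internal immediately followed by a sync: the
// parent's continuation is kept private (never pushed to the work queue), so the fast path only
// requires that we are back on the spawning thread and the parent is still synchronized;
// otherwise the stolen/unsynchronized case goes through slow_return().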
template<typename Function>
void scheduler::spawn_and_sync_internal(Function &&lambda) {
  thread_state &spawning_state = thread_state::get();

  base_task *last_task = spawning_state.get_active_task();
  base_task *spawned_task = last_task->next_;

  // We are at the end of our allocated tasks; switch over to serial execution.
  if (spawned_task == nullptr) {
    scheduler::serial_internal(std::forward<Function>(lambda));
    return;
  }
  // We are in a serial section.
  // For now we simply stay serial. Later versions could add nested parallel sections.
  if (last_task->is_serial_section_) {
    lambda();
    return;
  }

  // Carry over the resources allocated by parents
  auto *attached_resources = last_task->attached_resources_.load(std::memory_order_relaxed);
  spawned_task->attached_resources_.store(attached_resources, std::memory_order_relaxed);

#if PLS_PROFILING_ENABLED
  spawning_state.get_scheduler().profiler_.task_prepare_stack_measure(spawning_state.get_thread_id(),
                                                                      spawned_task->stack_memory_,
                                                                      spawned_task->stack_size_);
  auto *child_dag_node = spawning_state.get_scheduler().profiler_.task_spawn_child(spawning_state.get_thread_id(),
                                                                                   last_task->profiling_node_);
  spawned_task->profiling_node_ = child_dag_node;
#endif

  auto continuation = spawned_task->run_as_task([last_task, spawned_task, lambda, &spawning_state](auto cont) {
    // allow stealing threads to continue the last task.
    last_task->continuation_ = std::move(cont);

    // we are now executing the new task, allow others to steal the last task's continuation.
    spawned_task->is_synchronized_ = true;
    spawning_state.set_active_task(spawned_task);

    // execute the lambda itself, which could lead to a different thread returning.
#if PLS_PROFILING_ENABLED
    spawning_state.get_scheduler().profiler_.task_finish_stack_measure(spawning_state.get_thread_id(),
                                                                       last_task->stack_memory_,
                                                                       last_task->stack_size_,
                                                                       last_task->profiling_node_);
    spawning_state.get_scheduler().profiler_.task_stop_running(spawning_state.get_thread_id(),
                                                               last_task->profiling_node_);
    auto *next_dag_node =
        spawning_state.get_scheduler().profiler_.task_sync(spawning_state.get_thread_id(), last_task->profiling_node_);
    last_task->profiling_node_ = next_dag_node;
    spawning_state.get_scheduler().profiler_.task_start_running(spawning_state.get_thread_id(),
                                                                spawned_task->profiling_node_);
#endif
    lambda();

    thread_state &syncing_state = thread_state::get();
    PLS_ASSERT(syncing_state.get_active_task() == spawned_task,
               "Task manager must always point its active task onto whats executing.");

    if (&syncing_state == &spawning_state && last_task->is_synchronized_) {
      // Fast path, simply continue execution where we left off before the spawn.
      PLS_ASSERT(last_task->continuation_.valid(),
                 "Fast path, no one can have continued working on the last task.");

      syncing_state.set_active_task(last_task);
#if PLS_PROFILING_ENABLED
      syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                        spawned_task->stack_memory_,
                                                                        spawned_task->stack_size_,
                                                                        spawned_task->profiling_node_);
      syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                spawned_task->profiling_node_);
#endif
      return std::move(last_task->continuation_);
    } else {
      // Slow path, the last task was stolen. This path is common to sync() events.
#if PLS_PROFILING_ENABLED
      syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                        spawned_task->stack_memory_,
                                                                        spawned_task->stack_size_,
                                                                        spawned_task->profiling_node_);
      syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                spawned_task->profiling_node_);
#endif
      auto continuation = slow_return(syncing_state, false);
      return continuation;
    }
  });

#if PLS_PROFILING_ENABLED
  thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                   last_task->profiling_node_);
#endif

  PLS_ASSERT(!continuation.valid(), "We must not jump to a not-published (synced) spawn.");
}

template<typename Function>
void scheduler::serial_internal(Function &&lambda) {
  thread_state &spawning_state = thread_state::get();
  base_task *active_task = spawning_state.get_active_task();

  if (active_task->is_serial_section_) {
    lambda();
  } else {
    active_task->is_serial_section_ = true;
    context_switcher::enter_context(spawning_state.get_serial_call_stack(),
                                    spawning_state.get_serial_call_stack_size(),
                                    [&](auto cont) {
                                      lambda();
                                      return cont;
                                    });
    active_task->is_serial_section_ = false;
  }
}
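// Note: serial_internal is the fallback taken by both spawn variants once the pre-allocated
// task chain is exhausted (spawned_task == nullptr). The lambda then runs on a dedicated
// serial call stack via context_switcher::enter_context, and nested spawns inside it remain
// serial while is_serial_section_ is set.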

}

#endif //PLS_SCHEDULER_IMPL_H