
#ifndef PLS_SCHEDULER_IMPL_H
#define PLS_SCHEDULER_IMPL_H

#include <atomic>
#include <memory>
#include <thread>
#include <utility>

#include "context_switcher/context_switcher.h"
#include "context_switcher/continuation.h"

#include "pls/internal/scheduling/task_manager.h"
#include "pls/internal/scheduling/base_task.h"
#include "base_task.h"

#include "pls/internal/profiling/dag_node.h"

namespace pls::internal::scheduling {

template<typename ALLOC>
scheduler::scheduler(unsigned int num_threads,
                     size_t computation_depth,
                     size_t stack_size,
                     bool reuse_thread,
                     size_t serial_stack_size,
                     ALLOC &&stack_allocator) :
    num_threads_{num_threads},
    reuse_thread_{reuse_thread},
    sync_barrier_{num_threads + 1 - reuse_thread},
    worker_threads_{},
    thread_states_{},
    main_thread_starter_function_{nullptr},
    work_section_done_{false},
    terminated_{false},
    stack_allocator_{std::make_shared<ALLOC>(std::forward<ALLOC>(stack_allocator))},
    serial_stack_size_{serial_stack_size}
#if PLS_PROFILING_ENABLED
, profiler_{num_threads}
#endif
{

  worker_threads_.reserve(num_threads);
  task_managers_.reserve(num_threads);
  thread_states_.reserve(num_threads);
  for (unsigned int i = 0; i < num_threads_; i++) {
    auto &this_task_manager =
        task_managers_.emplace_back(std::make_unique<task_manager>(i,
                                                                   computation_depth,
                                                                   stack_size,
                                                                   stack_allocator_));
    auto &this_thread_state = thread_states_.emplace_back(std::make_unique<thread_state>(*this,
                                                                                         i,
                                                                                         *this_task_manager,
                                                                                         stack_allocator_,
                                                                                         serial_stack_size));

    if (reuse_thread && i == 0) {
      worker_threads_.emplace_back();
      continue; // Skip over the first/main thread when re-using the user's thread, as it will take the role of this first worker.
    }

    auto *this_thread_state_pointer = this_thread_state.get();
    worker_threads_.emplace_back([this_thread_state_pointer] {
      thread_state::set(this_thread_state_pointer);
      work_thread_main_loop();
    });
  }

  // Make sure all threads are created and have touched their stacks.
  // Executing a work section ensures one wake-up/sleep cycle of all workers,
  // and explicitly forcing one task per worker makes each of them initialize its stack.
  std::atomic<unsigned> num_spawned{0};
  this->perform_work([&]() {
    for (unsigned i = 0; i < num_threads; i++) {
      spawn([&]() {
        num_spawned++;
        while (num_spawned < num_threads) std::this_thread::yield();
      });
    }
    sync();
  });
}
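
// Illustrative construction sketch (not part of the scheduler itself):
// my_stack_allocator is a placeholder for whatever stack allocator type the
// surrounding code base provides; the argument order matches the template
// constructor above.
//
//   my_stack_allocator alloc{};
//   scheduler sched{8,          // num_threads
//                   6,          // computation_depth: pre-allocated task chain per worker
//                   1u << 16u,  // stack_size per task
//                   true,       // reuse_thread: the calling thread acts as worker 0
//                   1u << 16u,  // serial_stack_size
//                   std::move(alloc)};
//
// The warm-up work section at the end of the constructor runs exactly one task
// on every worker, so all coroutine stacks are touched before the first real
// perform_work() call.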

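// init_function/init_function_impl type-erase the templated work section so the
// non-template worker code can start it through the plain
// main_thread_starter_function_ pointer that perform_work() sets below.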
class scheduler::init_function {
 public:
  virtual void run() = 0;
};
template<typename F>
class scheduler::init_function_impl : public init_function {
 public:
  explicit init_function_impl(F &function) : function_{function} {}
  void run() override {
    base_task *root_task = thread_state::get().get_active_task();
    root_task->attached_resources_.store(nullptr, std::memory_order_relaxed);

#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     root_task->profiling_node_);
    thread_state::get().get_scheduler().profiler_.task_prepare_stack_measure(thread_state::get().get_thread_id(),
                                                                             root_task->stack_memory_,
                                                                             root_task->stack_size_);
#endif

    root_task->run_as_task([root_task, this](auto cont) {
      root_task->is_synchronized_ = true;
      thread_state::get().main_continuation() = std::move(cont);
      function_();

      thread_state::get().get_scheduler().work_section_done_.store(true);
      PLS_ASSERT(thread_state::get().main_continuation().valid(), "Must return valid continuation from main task.");

#if PLS_SLEEP_WORKERS_ON_EMPTY
      thread_state::get().get_scheduler().empty_queue_counter_.store(-1000);
      thread_state::get().get_scheduler().empty_queue_decrease_counter_and_wake();
#endif

#if PLS_PROFILING_ENABLED
      thread_state::get().get_scheduler().profiler_.task_finish_stack_measure(thread_state::get().get_thread_id(),
                                                                              root_task->stack_memory_,
                                                                              root_task->stack_size_,
                                                                              root_task->profiling_node_);
      thread_state::get().get_scheduler().profiler_.task_stop_running(thread_state::get().get_thread_id(),
                                                                      root_task->profiling_node_);
#endif

      return std::move(thread_state::get().main_continuation());
    });
  }
 private:
  F &function_;
};

template<typename Function>
void scheduler::perform_work(Function work_section) {
  // Prepare main root task
  init_function_impl<Function> starter_function{work_section};
  main_thread_starter_function_ = &starter_function;

#if PLS_PROFILING_ENABLED
  auto *root_task = thread_state_for(0).get_active_task();
  auto *root_node = profiler_.start_profiler_run();
  root_task->profiling_node_ = root_node;
#endif
#if PLS_SLEEP_WORKERS_ON_EMPTY
  empty_queue_counter_.store(0);
  for (auto &thread_state : thread_states_) {
    thread_state->get_queue_empty_flag().store({EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY});
  }
#endif

  work_section_done_ = false;
  if (reuse_thread_) {
    auto &my_state = thread_state_for(0);
    thread_state::set(&my_state); // Make THIS THREAD become the main worker

    sync_barrier_.wait(); // Trigger threads to wake up
    work_thread_work_section(); // Simply also perform the work section on the main loop
    sync_barrier_.wait(); // Wait for threads to finish
  } else {
    // Simply trigger the others to do the work; this thread will sleep/wait for the time being.
    sync_barrier_.wait(); // Trigger threads to wake up
    sync_barrier_.wait(); // Wait for threads to finish
  }
#if PLS_PROFILING_ENABLED
  profiler_.stop_profiler_run();
#endif
}

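// Usage sketch for perform_work() (illustrative; left/right/work() are
// placeholders, and spawn()/sync() are used the same way the constructor's
// warm-up section above uses them):
//
//   int left, right;
//   sched.perform_work([&] {
//     spawn([&] { left = work(0); });
//     spawn([&] { right = work(1); });
//     sync();  // both spawned tasks have completed here
//   });
//
// perform_work() blocks the calling thread until the work section is done; with
// reuse_thread_ set, the calling thread itself acts as worker 0 for the section.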
template<typename Function>
void scheduler::spawn_internal(Function &&lambda) {
  if (thread_state::is_scheduler_active()) {
    thread_state &spawning_state = thread_state::get();

    base_task *last_task = spawning_state.get_active_task();
    base_task *spawned_task = last_task->next_;

    // We are at the end of our allocated tasks; fall back to serial execution.
    if (spawned_task == nullptr) {
      scheduler::serial_internal(std::forward<Function>(lambda));
      return;
    }
    // We are in a serial section.
    // For now we simply stay serial. Later versions could add nested parallel sections.
    if (last_task->is_serial_section_) {
      lambda();
      return;
    }

    // Carry over the resources allocated by parents
    auto *attached_resources = last_task->attached_resources_.load(std::memory_order_relaxed);
    spawned_task->attached_resources_.store(attached_resources, std::memory_order_relaxed);

#if PLS_PROFILING_ENABLED
    spawning_state.get_scheduler().profiler_.task_prepare_stack_measure(spawning_state.get_thread_id(),
                                                                        spawned_task->stack_memory_,
                                                                        spawned_task->stack_size_);
    auto *child_dag_node = spawning_state.get_scheduler().profiler_.task_spawn_child(spawning_state.get_thread_id(),
                                                                                     last_task->profiling_node_);
    spawned_task->profiling_node_ = child_dag_node;
#endif

    auto continuation = spawned_task->run_as_task([last_task, spawned_task, lambda, &spawning_state](auto cont) {
      // allow stealing threads to continue the last task.
      last_task->continuation_ = std::move(cont);

      // we are now executing the new task, allow others to steal the last task continuation.
      spawned_task->is_synchronized_ = true;
      spawning_state.set_active_task(spawned_task);

      spawning_state.get_task_manager().push_local_task(last_task);
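      // We just made the last task stealable again. When PLS_SLEEP_WORKERS_ON_EMPTY
      // is enabled, undo any 'empty' marking idle workers may have placed on this
      // queue in the meantime and wake sleepers, so the freshly pushed work is not missed.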
#if PLS_SLEEP_WORKERS_ON_EMPTY
      data_structures::stamped_integer
          queue_empty_flag = spawning_state.get_queue_empty_flag().load(std::memory_order_relaxed);
      switch (queue_empty_flag.value_) {
        case EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY: {
          // The queue was not found empty, ignore it.
          break;
        }
        case EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY: {
          // Someone tries to mark us empty and might be re-stealing right now.
          data_structures::stamped_integer
              queue_non_empty_flag{queue_empty_flag.stamp_ + 1, EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY};
          auto actual_empty_flag =
              spawning_state.get_queue_empty_flag().exchange(queue_non_empty_flag, std::memory_order_acq_rel);
          if (actual_empty_flag.value_ == EMPTY_QUEUE_STATE::QUEUE_EMPTY) {
            spawning_state.get_scheduler().empty_queue_decrease_counter_and_wake();
          }
          break;
        }
        case EMPTY_QUEUE_STATE::QUEUE_EMPTY: {
          // Someone already marked the queue empty, we must revert its action on the central queue.
          data_structures::stamped_integer
              queue_non_empty_flag{queue_empty_flag.stamp_ + 1, EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY};
          spawning_state.get_queue_empty_flag().store(queue_non_empty_flag, std::memory_order_release);
          spawning_state.get_scheduler().empty_queue_decrease_counter_and_wake();
          break;
        }
      }
#endif

      // execute the lambda itself, which could lead to a different thread returning.
#if PLS_PROFILING_ENABLED
      spawning_state.get_scheduler().profiler_.task_stop_running(spawning_state.get_thread_id(),
                                                                 last_task->profiling_node_);
      spawning_state.get_scheduler().profiler_.task_start_running(spawning_state.get_thread_id(),
                                                                  spawned_task->profiling_node_);
#endif
      lambda();

      thread_state &syncing_state = thread_state::get();
      PLS_ASSERT(syncing_state.get_active_task() == spawned_task,
                 "Task manager must always point its active task onto what is executing.");

      // try to pop a task off the syncing task manager.
      // possible outcomes:
      // - this is a different task manager, it must have an empty deque and the pop will fail
      // - this is the same task manager and someone stole the last task, thus the pop will fail
      // - this is the same task manager and no one stole the last task, thus the pop will succeed
      base_task *popped_task = syncing_state.get_task_manager().pop_local_task();
      if (popped_task) {
        // Fast path, simply continue execution where we left off before the spawn.
        PLS_ASSERT(popped_task == last_task,
                   "Fast path, nothing can have changed until here.");
        PLS_ASSERT(&spawning_state == &syncing_state,
                   "Fast path, we must only return if the task has not been stolen/moved to other thread.");
        PLS_ASSERT(last_task->continuation_.valid(),
                   "Fast path, no one can have continued working on the last task.");

        syncing_state.set_active_task(last_task);
#if PLS_PROFILING_ENABLED
        syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                          spawned_task->stack_memory_,
                                                                          spawned_task->stack_size_,
                                                                          spawned_task->profiling_node_);
        syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                  spawned_task->profiling_node_);
#endif
        return std::move(last_task->continuation_);
      } else {
        // Slow path, the last task was stolen. This code path is shared with sync() events.
#if PLS_PROFILING_ENABLED
        syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                          spawned_task->stack_memory_,
                                                                          spawned_task->stack_size_,
                                                                          spawned_task->profiling_node_);
        syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                  spawned_task->profiling_node_);
#endif
        auto continuation = slow_return(syncing_state, false);
        return continuation;
      }
    });

#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     last_task->profiling_node_);
#endif

    if (continuation.valid()) {
      // We jumped in here from the main loop, keep track!
      thread_state::get().main_continuation() = std::move(continuation);
    }
  } else {
    // Scheduler not active...
    lambda();
  }
}

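// Illustrative fork-join usage of the two spawn variants (fib() is a placeholder;
// scheduler::spawn/spawn_and_sync are assumed to be the public wrappers that
// forward to spawn_internal()/spawn_and_sync_internal()):
//
//   int fib(int n) {
//     if (n <= 1) return n;
//     int left, right;
//     scheduler::spawn([&, n] { left = fib(n - 1); });
//     scheduler::spawn_and_sync([&, n] { right = fib(n - 2); });
//     return left + right;
//   }
//
// In both helpers the fast path resumes on the spawning thread; if the last task
// was stolen, the slow path hands control back through slow_return().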
template<typename Function>
void scheduler::spawn_and_sync_internal(Function &&lambda) {
  thread_state &spawning_state = thread_state::get();

  base_task *last_task = spawning_state.get_active_task();
  base_task *spawned_task = last_task->next_;

  // We are at the end of our allocated tasks; fall back to serial execution.
  if (spawned_task == nullptr) {
    scheduler::serial_internal(std::forward<Function>(lambda));
    return;
  }
  // We are in a serial section.
  // For now we simply stay serial. Later versions could add nested parallel sections.
  if (last_task->is_serial_section_) {
    lambda();
    return;
  }

  // Carry over the resources allocated by parents
  auto *attached_resources = last_task->attached_resources_.load(std::memory_order_relaxed);
  spawned_task->attached_resources_.store(attached_resources, std::memory_order_relaxed);

#if PLS_PROFILING_ENABLED
  spawning_state.get_scheduler().profiler_.task_prepare_stack_measure(spawning_state.get_thread_id(),
                                                                      spawned_task->stack_memory_,
                                                                      spawned_task->stack_size_);
  auto *child_dag_node = spawning_state.get_scheduler().profiler_.task_spawn_child(spawning_state.get_thread_id(),
                                                                                   last_task->profiling_node_);
  spawned_task->profiling_node_ = child_dag_node;
#endif

  auto continuation = spawned_task->run_as_task([last_task, spawned_task, lambda, &spawning_state](auto cont) {
    // allow stealing threads to continue the last task.
    last_task->continuation_ = std::move(cont);

    // we are now executing the new task, allow others to steal the last task continuation.
    spawned_task->is_synchronized_ = true;
    spawning_state.set_active_task(spawned_task);

    // execute the lambda itself, which could lead to a different thread returning.
#if PLS_PROFILING_ENABLED
    spawning_state.get_scheduler().profiler_.task_finish_stack_measure(spawning_state.get_thread_id(),
                                                                       last_task->stack_memory_,
                                                                       last_task->stack_size_,
                                                                       last_task->profiling_node_);
    spawning_state.get_scheduler().profiler_.task_stop_running(spawning_state.get_thread_id(),
                                                               last_task->profiling_node_);
    auto *next_dag_node =
        spawning_state.get_scheduler().profiler_.task_sync(spawning_state.get_thread_id(), last_task->profiling_node_);
    last_task->profiling_node_ = next_dag_node;
    spawning_state.get_scheduler().profiler_.task_start_running(spawning_state.get_thread_id(),
                                                                spawned_task->profiling_node_);
#endif
    lambda();

    thread_state &syncing_state = thread_state::get();
    PLS_ASSERT(syncing_state.get_active_task() == spawned_task,
               "Task manager must always point its active task onto what is executing.");

    if (&syncing_state == &spawning_state && last_task->is_synchronized_) {
      // Fast path, simply continue execution where we left off before the spawn.
      PLS_ASSERT(last_task->continuation_.valid(),
                 "Fast path, no one can have continued working on the last task.");

      syncing_state.set_active_task(last_task);
#if PLS_PROFILING_ENABLED
      syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                        spawned_task->stack_memory_,
                                                                        spawned_task->stack_size_,
                                                                        spawned_task->profiling_node_);
      syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                spawned_task->profiling_node_);
#endif
      return std::move(last_task->continuation_);
    } else {
      // Slow path, the last task was stolen. This code path is shared with sync() events.
#if PLS_PROFILING_ENABLED
      syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                        spawned_task->stack_memory_,
                                                                        spawned_task->stack_size_,
                                                                        spawned_task->profiling_node_);
      syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                spawned_task->profiling_node_);
#endif
      auto continuation = slow_return(syncing_state, false);
      return continuation;
    }
  });

#if PLS_PROFILING_ENABLED
  thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                   last_task->profiling_node_);
#endif

  PLS_ASSERT(!continuation.valid(), "We must not jump to a not-published (synced) spawn.");
}

template<typename Function>
void scheduler::serial_internal(Function &&lambda) {
  thread_state &spawning_state = thread_state::get();
  base_task *active_task = spawning_state.get_active_task();

  if (active_task->is_serial_section_) {
    lambda();
  } else {
    active_task->is_serial_section_ = true;
    context_switcher::enter_context(spawning_state.get_serial_call_stack(),
                                    spawning_state.get_serial_call_stack_size(),
                                    [&](auto cont) {
                                      lambda();
                                      return cont;
                                    });
    active_task->is_serial_section_ = false;
  }
}

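// serial_internal() runs the lambda on the pre-allocated per-thread serial stack.
// The spawn helpers above fall back to it once the chain of pre-allocated tasks
// (computation_depth) is exhausted, and calls inside an already active serial
// section simply run inline, as nested parallel sections are not supported yet.
// Sketch, assuming a public serial(...) wrapper forwards here:
//
//   scheduler::serial([&] {
//     deeply_recursive_algorithm(data);  // placeholder call
//   });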
}

#endif //PLS_SCHEDULER_IMPL_H