// scheduler_impl.h

#ifndef PLS_SCHEDULER_IMPL_H
#define PLS_SCHEDULER_IMPL_H

5
#include <memory>
#include <type_traits>
#include <utility>

#include "context_switcher/context_switcher.h"
#include "context_switcher/continuation.h"

#include "pls/internal/scheduling/task_manager.h"
#include "pls/internal/scheduling/base_task.h"
#include "base_task.h"

#include "pls/internal/profiling/dag_node.h"

16
namespace pls::internal::scheduling {
17

18 19 20 21 22 23 24 25 26 27 28 29 30 31
/**
 * Creates the scheduler together with its per-worker state and its worker threads.
 *
 * For every index in [0, num_threads) one task_manager and one thread_state are created.
 * A thread running work_thread_main_loop() is started for each index, except for index 0 when
 * reuse_thread is set: that slot only receives a placeholder entry, because the thread calling
 * perform_work() takes over the role of worker 0.
 *
 * @tparam ALLOC deduced type of the stack allocator; references/cv-qualifiers are stripped before
 *               the allocator is stored, so both lvalues (copied) and rvalues (moved) are accepted.
 * @param num_threads number of workers to set up.
 * @param computation_depth passed on to every task_manager.
 * @param stack_size passed on to every task_manager.
 * @param reuse_thread if true, no separate thread is started for worker 0.
 * @param stack_allocator allocator instance; placed into a shared_ptr that is handed to all task managers.
 */
template<typename ALLOC>
scheduler::scheduler(unsigned int num_threads,
                     size_t computation_depth,
                     size_t stack_size,
                     bool reuse_thread,
                     ALLOC &&stack_allocator) :
    num_threads_{num_threads},
    reuse_thread_{reuse_thread},
    // One barrier slot per worker, plus one for the thread calling perform_work()
    // unless that thread doubles as worker 0 (reuse_thread).
    sync_barrier_{num_threads + 1 - reuse_thread},
    worker_threads_{},
    thread_states_{},
    main_thread_starter_function_{nullptr},
    work_section_done_{false},
    terminated_{false},
    // decay_t: with an lvalue argument ALLOC is deduced as a reference type, which make_shared cannot create.
    stack_allocator_{std::make_shared<std::decay_t<ALLOC>>(std::forward<ALLOC>(stack_allocator))}
#if PLS_PROFILING_ENABLED
    , profiler_{num_threads}
#endif
{

  worker_threads_.reserve(num_threads);
  task_managers_.reserve(num_threads);
  thread_states_.reserve(num_threads);
  for (unsigned int i = 0; i < num_threads_; i++) {
    auto &this_task_manager =
        task_managers_.emplace_back(std::make_unique<task_manager>(i,
                                                                   computation_depth,
                                                                   stack_size,
                                                                   stack_allocator_));
    auto &this_thread_state = thread_states_.emplace_back(std::make_unique<thread_state>(*this, i, *this_task_manager));

    if (reuse_thread && i == 0) {
      worker_threads_.emplace_back();
      continue; // Skip over first/main thread when re-using the users thread, as this one will replace the first one.
    }

    // Hand the raw pointer to the thread; the owning unique_ptr stays in thread_states_.
    auto *this_thread_state_pointer = this_thread_state.get();
    worker_threads_.emplace_back([this_thread_state_pointer] {
      thread_state::set(this_thread_state_pointer);
      work_thread_main_loop();
    });
  }
}

62 63
/**
 * Abstract entry point of a work section.
 *
 * perform_work() stores a pointer to an implementation of this interface in
 * main_thread_starter_function_; run() is the single hook implementations provide.
 */
class scheduler::init_function {
 public:
  // Polymorphic base: keep destruction through a base pointer well-defined.
  virtual ~init_function() = default;
  virtual void run() = 0;
};
/**
 * init_function implementation that runs a user supplied callable as the root task.
 *
 * The callable is held by reference only, so it must outlive this object.
 *
 * @tparam F type of the wrapped callable.
 */
template<typename F>
class scheduler::init_function_impl : public init_function {
 public:
  explicit init_function_impl(F &function) : function_{function} {}

  /**
   * Executes the wrapped callable on the active task of the calling thread's thread_state.
   *
   * Inside the task the continuation handed in by run_as_task is stored as the main continuation,
   * the callable is invoked, work_section_done_ is set on the scheduler and finally the
   * (asserted to be valid) main continuation is moved out and returned.
   */
  void run() override {
    base_task *root_task = thread_state::get().get_active_task();

#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     root_task->profiling_node_);
    thread_state::get().get_scheduler().profiler_.task_prepare_stack_measure(thread_state::get().get_thread_id(),
                                                                             root_task->stack_memory_,
                                                                             root_task->stack_size_);
#endif

    root_task->run_as_task([root_task, this](auto cont) {
      root_task->is_synchronized_ = true;
      thread_state::get().main_continuation() = std::move(cont);
      function_();

      // NOTE(review): thread_state::get() is looked up again after function_() instead of being cached;
      // presumably execution can resume on a different worker thread at this point - confirm before caching.
      thread_state::get().get_scheduler().work_section_done_.store(true);
      PLS_ASSERT(thread_state::get().main_continuation().valid(), "Must return valid continuation from main task.");

#if PLS_PROFILING_ENABLED
      thread_state::get().get_scheduler().profiler_.task_finish_stack_measure(thread_state::get().get_thread_id(),
                                                                              root_task->stack_memory_,
                                                                              root_task->stack_size_,
                                                                              root_task->profiling_node_);
      thread_state::get().get_scheduler().profiler_.task_stop_running(thread_state::get().get_thread_id(),
                                                                      root_task->profiling_node_);
#endif

      return std::move(thread_state::get().main_continuation());
    });
  }
 private:
  F &function_;
};
104

105 106
template<typename Function>
void scheduler::perform_work(Function work_section) {
107
  // Prepare main root task
108 109
  init_function_impl<Function> starter_function{work_section};
  main_thread_starter_function_ = &starter_function;
110

111 112 113 114 115
#if PLS_PROFILING_ENABLED
  auto *root_task = thread_state_for(0).get_active_task();
  auto *root_node = profiler_.start_profiler_run();
  root_task->profiling_node_ = root_node;
#endif
116
  work_section_done_ = false;
117
  if (reuse_thread_) {
118 119
    auto &my_state = thread_state_for(0);
    thread_state::set(&my_state); // Make THIS THREAD become the main worker
120 121

    sync_barrier_.wait(); // Trigger threads to wake up
122
    work_thread_work_section(); // Simply also perform the work section on the main loop
123 124 125 126 127 128
    sync_barrier_.wait(); // Wait for threads to finish
  } else {
    // Simply trigger the others to do the work, this thread will sleep/wait for the time being
    sync_barrier_.wait(); // Trigger threads to wake up
    sync_barrier_.wait(); // Wait for threads to finish
  }
129 130 131
#if PLS_PROFILING_ENABLED
  profiler_.stop_profiler_run();
#endif
132 133
}

134 135
template<typename Function>
void scheduler::spawn(Function &&lambda) {
136 137 138 139 140
  thread_state &spawning_state = thread_state::get();

  base_task *last_task = spawning_state.get_active_task();
  base_task *spawned_task = last_task->next_;

141 142 143 144 145 146 147 148 149
#if PLS_PROFILING_ENABLED
  spawning_state.get_scheduler().profiler_.task_prepare_stack_measure(spawning_state.get_thread_id(),
                                                                      spawned_task->stack_memory_,
                                                                      spawned_task->stack_size_);
  auto *child_dag_node = spawning_state.get_scheduler().profiler_.task_spawn_child(spawning_state.get_thread_id(),
                                                                                   last_task->profiling_node_);
  spawned_task->profiling_node_ = child_dag_node;
#endif

150 151 152 153 154 155 156 157 158 159
  auto continuation = spawned_task->run_as_task([last_task, spawned_task, lambda, &spawning_state](auto cont) {
    // allow stealing threads to continue the last task.
    last_task->continuation_ = std::move(cont);

    // we are now executing the new task, allow others to steal the last task continuation.
    spawned_task->is_synchronized_ = true;
    spawning_state.set_active_task(spawned_task);
    spawning_state.get_task_manager().push_local_task(last_task);

    // execute the lambda itself, which could lead to a different thread returning.
160 161 162 163 164 165
#if PLS_PROFILING_ENABLED
    spawning_state.get_scheduler().profiler_.task_stop_running(spawning_state.get_thread_id(),
                                                               last_task->profiling_node_);
    spawning_state.get_scheduler().profiler_.task_start_running(spawning_state.get_thread_id(),
                                                                spawned_task->profiling_node_);
#endif
166
    lambda();
167

168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187
    thread_state &syncing_state = thread_state::get();
    PLS_ASSERT(syncing_state.get_active_task() == spawned_task,
               "Task manager must always point its active task onto whats executing.");

    // try to pop a task of the syncing task manager.
    // possible outcomes:
    // - this is a different task manager, it must have an empty deque and fail
    // - this is the same task manager and someone stole last tasks, thus this will fail
    // - this is the same task manager and no one stole the last task, this this will succeed
    base_task *popped_task = syncing_state.get_task_manager().pop_local_task();
    if (popped_task) {
      // Fast path, simply continue execution where we left of before spawn.
      PLS_ASSERT(popped_task == last_task,
                 "Fast path, nothing can have changed until here.");
      PLS_ASSERT(&spawning_state == &syncing_state,
                 "Fast path, we must only return if the task has not been stolen/moved to other thread.");
      PLS_ASSERT(last_task->continuation_.valid(),
                 "Fast path, no one can have continued working on the last task.");

      syncing_state.set_active_task(last_task);
188
#if PLS_PROFILING_ENABLED
FritzFlorian committed
189 190 191 192
      syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                        spawned_task->stack_memory_,
                                                                        spawned_task->stack_size_,
                                                                        spawned_task->profiling_node_);
193 194 195
      syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                spawned_task->profiling_node_);
#endif
196 197 198
      return std::move(last_task->continuation_);
    } else {
      // Slow path, the last task was stolen. This path is common to sync() events.
199
#if PLS_PROFILING_ENABLED
FritzFlorian committed
200 201 202 203
      syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                        spawned_task->stack_memory_,
                                                                        spawned_task->stack_size_,
                                                                        spawned_task->profiling_node_);
204 205 206
      syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                                spawned_task->profiling_node_);
#endif
FritzFlorian committed
207
      auto continuation = slow_return(syncing_state);
208
      return continuation;
209 210 211
    }
  });

212 213
#if PLS_PROFILING_ENABLED
  thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
FritzFlorian committed
214
                                                                   last_task->profiling_node_);
215 216
#endif

217 218 219 220
  if (continuation.valid()) {
    // We jumped in here from the main loop, keep track!
    thread_state::get().main_continuation() = std::move(continuation);
  }
221 222
}

223
}
224 225

#endif //PLS_SCHEDULER_IMPL_H