#include "pls/internal/scheduling/scheduler.h"

#include "context_switcher/context_switcher.h"
#include "pls/internal/base/error_handling.h"

#include <thread>
#include <tuple>

namespace pls::internal::scheduling {

scheduler::scheduler(unsigned int num_threads,
                     size_t computation_depth,
                     size_t stack_size,
                     bool reuse_thread) : scheduler(num_threads,
                                                    computation_depth,
                                                    stack_size,
                                                    reuse_thread,
                                                    base::mmap_stack_allocator{}) {}

scheduler::~scheduler() {
  terminate();
}
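
// Construction/teardown sketch (illustrative only; the concrete parameter
// values are assumptions, and work is submitted through the scheduler's
// public API declared in the header, which is not shown in this file):
//
//   {
//     scheduler sched{4,     // number of worker threads
//                     16,    // maximum computation/spawn depth
//                     4096,  // stack size in bytes per task
//                     true}; // reuse the calling thread as worker 0
//     // ... run parallel work through the public API ...
//   } // ~scheduler() calls terminate(), which joins all worker threads.
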
void scheduler::work_thread_main_loop() {
  auto &scheduler = thread_state::get().get_scheduler();
  while (true) {
    // Wait to be triggered
    scheduler.sync_barrier_.wait();

    // Check for shutdown
    if (scheduler.terminated_) {
      return;
    }

    scheduler.work_thread_work_section();

    // Sync back with main thread
    scheduler.sync_barrier_.wait();
  }
}
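
// The loop above pairs two barrier waits per work section: the first wait
// blocks a worker until the coordinating thread releases it into a work
// section (or until terminate() below joins the barrier after setting
// terminated_, at which point the worker exits), and the second wait signals
// that the worker has left the work section. The matching waits on the
// coordinating side live outside this file.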

void scheduler::work_thread_work_section() {
  thread_state &my_state = thread_state::get();
  unsigned const num_threads = my_state.get_scheduler().num_threads();

  if (my_state.get_thread_id() == 0) {
    // Main thread: kick off the work section by executing the user's main code block.
    main_thread_starter_function_->run();
  }

  unsigned int failed_steals = 0;
  while (!work_section_done_) {
#if PLS_PROFILING_ENABLED
    my_state.get_scheduler().profiler_.stealing_start(my_state.get_thread_id());
#endif
    PLS_ASSERT(check_task_chain(*my_state.get_active_task()), "Must start stealing with a clean task chain.");

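    // Pick a random steal victim other than ourselves.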
    size_t target;
    do {
      target = my_state.get_rand() % num_threads;
    } while (target == my_state.get_thread_id());

    thread_state &target_state = my_state.get_scheduler().thread_state_for(target);

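    // Try to steal from the victim's task manager. The steal is published with
    // a CAS; if it fails (typically because another thread raced us on the
    // victim's chain), we simply retry the attempt.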
    base_task *stolen_task;
    base_task *chain_after_stolen_task;
    bool cas_success;
    do {
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_cas_op(my_state.get_thread_id());
#endif
      std::tie(stolen_task, chain_after_stolen_task, cas_success) =
          target_state.get_task_manager().steal_task(my_state);
    } while (!cas_success);

    if (stolen_task) {
      // Keep the task chain consistent. We want to appear as if we are working on a branch upwards of the stolen task.
      stolen_task->next_ = chain_after_stolen_task;
      chain_after_stolen_task->prev_ = stolen_task;
      my_state.set_active_task(stolen_task);

      PLS_ASSERT(check_task_chain_forward(*my_state.get_active_task()),
                 "We are sole owner of this chain, it has to be valid!");

      // Execute the stolen task by jumping to its continuation.
      PLS_ASSERT(stolen_task->continuation_.valid(),
                 "A task that we can steal must have a valid continuation for us to start working.");
      stolen_task->is_synchronized_ = false;
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), true);
      my_state.get_scheduler().profiler_.task_start_running(my_state.get_thread_id(),
                                                            stolen_task->profiling_node_);
#endif
      context_switcher::switch_context(std::move(stolen_task->continuation_));
      // We continue execution at this line once the stolen work is finished.
      failed_steals = 0;
    } else {
      failed_steals++;
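      // Back off with a yield after roughly one failed attempt per worker,
      // to avoid spinning while there is nothing to steal.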
      if (failed_steals >= num_threads) {
        std::this_thread::yield();
      }
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), false);
#endif
    }
  }
}

void scheduler::sync() {
  thread_state &syncing_state = thread_state::get();

  base_task *active_task = syncing_state.get_active_task();
  base_task *spawned_task = active_task->next_;

#if PLS_PROFILING_ENABLED
  syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                    active_task->stack_memory_,
                                                                    active_task->stack_size_,
                                                                    active_task->profiling_node_);
  syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                            active_task->profiling_node_);
  auto *next_dag_node =
      syncing_state.get_scheduler().profiler_.task_sync(syncing_state.get_thread_id(), active_task->profiling_node_);
  active_task->profiling_node_ = next_dag_node;
#endif

  if (active_task->is_synchronized_) {
#if PLS_PROFILING_ENABLED
    syncing_state.get_scheduler().profiler_.task_start_running(syncing_state.get_thread_id(),
                                                               active_task->profiling_node_);
#endif
    return; // We are already the sole owner of last_task
  } else {
    auto continuation =
        spawned_task->run_as_task([active_task, spawned_task, &syncing_state](context_switcher::continuation cont) {
          active_task->continuation_ = std::move(cont);
          syncing_state.set_active_task(spawned_task);
          return slow_return(syncing_state);
        });

    PLS_ASSERT(!continuation.valid(),
               "We only return to a sync point, never jump to it directly."
               "This must therefore never return an unfinished fiber/continuation.");
#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     thread_state::get().get_active_task()->profiling_node_);
#endif
    return; // We cleanly synced with the last child finishing work on last_task
  }
}
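
// Typical call pattern for the fast/slow sync path above (illustrative sketch;
// spawn() is assumed to be the scheduler's task-spawning entry point declared
// in the header, it is not defined in this file):
//
//   scheduler::spawn([] { work_left(); });  // runs as a child task
//   work_right();                           // continues on the parent task
//   scheduler::sync();                      // rejoins via the fast path or the
//                                           // slow path implemented above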

context_switcher::continuation scheduler::slow_return(thread_state &calling_state) {
  base_task *this_task = calling_state.get_active_task();
  PLS_ASSERT(this_task->depth_ > 0,
             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");
  base_task *last_task = this_task->prev_;

  // A slow return means we are finishing the child 'this_task' of 'last_task' without
  // knowing whether we are the last child to finish.
  // If we are not the last one, we take a spare (clean) task chain for our resources and
  // return to the main scheduling loop; if we are the last one, the sync on 'last_task'
  // is complete and we continue working on it.
  base_task *pop_result = calling_state.get_task_manager().pop_clean_task_chain(last_task);

  if (pop_result != nullptr) {
    base_task *clean_chain = pop_result;

    // We got a clean chain to fill up our resources.
    PLS_ASSERT(last_task->depth_ == clean_chain->depth_,
               "Resources must only reside in the correct depth!");
    PLS_ASSERT(last_task != clean_chain,
               "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
    PLS_ASSERT(check_task_chain_backward(*clean_chain),
               "Can only acquire clean chains for clean returns!");

    // Acquire it/merge it with our task chain.
    this_task->prev_ = clean_chain;
    clean_chain->next_ = this_task;

    base_task *active_task = clean_chain;
    while (active_task->depth_ > 0) {
      active_task = active_task->prev_;
    }
    calling_state.set_active_task(active_task);

    // Jump back to the continuation in main scheduling loop.
    context_switcher::continuation result_cont = std::move(thread_state::get().main_continuation());
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
    return std::move(result_cont);
  } else {
    // Make sure that we are the owner of this full continuation/task chain.
    last_task->next_ = this_task;

    // We are the last one working on this task, so the sync is finished; continue working.
    calling_state.set_active_task(last_task);
    last_task->is_synchronized_ = true;

    // Jump to parent task and continue working on it.
    context_switcher::continuation result_cont = std::move(last_task->continuation_);
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
    return std::move(result_cont);
  }
}

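// Walks the calling thread's active task chain towards prev_ or next_ until
// the task at the requested depth is reached.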
base_task &scheduler::task_chain_at(unsigned int depth, thread_state &calling_state) {
  // TODO: possibly optimize with a cache array that is updated at steal events
  base_task *result = calling_state.get_active_task();
  while (result->depth_ > depth) {
    result = result->prev_;
  }
  while (result->depth_ < depth) {
    result = result->next_;
  }

  return *result;
}

void scheduler::terminate() {
  if (terminated_) {
    return;
  }

  terminated_ = true;
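  // Release the workers blocked in work_thread_main_loop(); they observe
  // terminated_ and return, after which they can be joined below.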
  sync_barrier_.wait();

  for (unsigned int i = 0; i < num_threads_; i++) {
    if (reuse_thread_ && i == 0) {
      continue;
    }
    worker_threads_[i].join();
  }
}

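// The check_task_chain_* helpers validate the doubly linked task chain: every
// link must be symmetric (current->next_->prev_ == current and vice versa).
// They are meant to back assertions such as the PLS_ASSERT calls above.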
bool scheduler::check_task_chain_forward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->next_) {
    if (current->next_->prev_ != current) {
      return false;
    }
    current = current->next_;
  }
  return true;
}

bool scheduler::check_task_chain_backward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->prev_) {
    if (current->prev_->next_ != current) {
      return false;
    }
    current = current->prev_;
  }
  return true;
}

bool scheduler::check_task_chain(base_task &start_task) {
  return check_task_chain_backward(start_task) && check_task_chain_forward(start_task);
}

}