#include "pls/internal/scheduling/scheduler.h"

#include "context_switcher/context_switcher.h"

#include "pls/internal/scheduling/strain_local_resource.h"
#include "pls/internal/build_flavour.h"
#include "pls/internal/base/error_handling.h"
#include "pls/internal/base/futex_wrapper.h"

#include <thread>

namespace pls::internal::scheduling {

scheduler::scheduler(unsigned int num_threads,
                     size_t computation_depth,
                     size_t stack_size,
                     bool reuse_thread,
                     size_t serial_stack_size) : scheduler(num_threads,
                                                           computation_depth,
                                                           stack_size,
                                                           reuse_thread,
                                                           serial_stack_size,
                                                           base::mmap_stack_allocator{}) {}

scheduler::~scheduler() {
  terminate();
}
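
// Usage sketch (comment only, not compiled): how an application is expected to
// drive this scheduler. The constructor arguments mirror the delegating
// constructor above; `perform_work` is the assumed entry point that installs
// main_thread_starter_function_ and releases the workers via sync_barrier_.
// Consult the scheduler header for the authoritative API and parameter names.
//
//   pls::internal::scheduling::scheduler sched{/*num_threads=*/8,
//                                              /*computation_depth=*/64,
//                                              /*stack_size=*/1u << 16};
//   sched.perform_work([] {
//     // user code; may spawn further tasks and sync on them
//   });
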
void scheduler::work_thread_main_loop() {
  auto &scheduler = thread_state::get().get_scheduler();
  while (true) {
    // Wait to be triggered
    scheduler.sync_barrier_.wait();

    // Check for shutdown
    if (scheduler.terminated_) {
      return;
    }

    scheduler.work_thread_work_section();

    // Sync back with main thread
    scheduler.sync_barrier_.wait();
  }
}
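
// The two sync_barrier_.wait() calls above form the fork/join handshake with the
// main thread: code outside this file releases the workers into a work section by
// entering the barrier, then waits on the same barrier until every worker has left
// work_thread_work_section() again.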

void scheduler::work_thread_work_section() {
  thread_state &my_state = thread_state::get();
  my_state.set_scheduler_active(true);
  unsigned const num_threads = my_state.get_scheduler().num_threads();

  if (my_state.get_thread_id() == 0) {
    // Main thread: kick off by executing the user's main code block.
    main_thread_starter_function_->run();
  }

  while (!work_section_done_) {
#if PLS_SLEEP_WORKERS_ON_EMPTY
    // Mark ourselves empty when we begin stealing; this spares other threads from having to discover it.
    data_structures::stamped_integer my_empty_flag = my_state.get_queue_empty_flag().load(std::memory_order_relaxed);
    if (my_empty_flag.value_ != EMPTY_QUEUE_STATE::QUEUE_EMPTY) {
      data_structures::stamped_integer target_empty_flag{my_empty_flag.stamp_, EMPTY_QUEUE_STATE::QUEUE_EMPTY};
      if (my_state.get_queue_empty_flag().compare_exchange_strong(my_empty_flag,
                                                                  target_empty_flag,
                                                                  std::memory_order_relaxed)) {
        // Only increase the counter if we were the ones to mark ourselves empty
        // (someone else may already have marked us).
        empty_queue_increase_counter();
      }
    }
#endif
    PLS_ASSERT_EXPENSIVE(check_task_chain(*my_state.get_active_task()), "Must start stealing with a clean task chain.");

    // Pick a random steal victim other than ourselves.
    size_t target;
    do {
      target = my_state.get_rand() % num_threads;
    } while (target == my_state.get_thread_id());

    thread_state &target_state = my_state.get_scheduler().thread_state_for(target);
#if PLS_SLEEP_WORKERS_ON_EMPTY
    queue_empty_flag_retry_steal:
    data_structures::stamped_integer target_queue_empty_flag =
        target_state.get_queue_empty_flag().load(std::memory_order_relaxed);
#endif
#if PLS_PROFILING_ENABLED
    my_state.get_scheduler().profiler_.stealing_start(my_state.get_thread_id());
#endif

    base_task *stolen_task;
    base_task *chain_after_stolen_task;
    bool cas_success;
    do {
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_cas_op(my_state.get_thread_id());
#endif
      std::tie(stolen_task, chain_after_stolen_task, cas_success) =
          target_state.get_task_manager().steal_task(my_state);
    } while (!cas_success); // Retry on CAS conflicts, as in classic work stealing.
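    // steal_task() reports the stolen task (nullptr if nothing could be taken), the
    // chain to link in below it, and whether the claiming CAS succeeded; only a CAS
    // conflict warrants another attempt.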

    if (stolen_task) {
      // Keep the task chain consistent. We want to appear as if we are working on a branch upwards of the stolen task.
      stolen_task->next_ = chain_after_stolen_task;
      chain_after_stolen_task->prev_ = stolen_task;
      my_state.set_active_task(stolen_task);
      // Keep locally owned resources consistent.
      auto *stolen_resources = stolen_task->attached_resources_.load(std::memory_order_relaxed);
      strain_local_resource::acquire_locally(stolen_resources, my_state.get_thread_id());

      PLS_ASSERT_EXPENSIVE(check_task_chain_forward(*my_state.get_active_task()),
                           "We are sole owner of this chain, it has to be valid!");

      // Execute the stolen task by jumping to its continuation.
      PLS_ASSERT(stolen_task->continuation_.valid(),
                 "A task that we can steal must have a valid continuation for us to start working.");
      stolen_task->is_synchronized_ = false;
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), true);
      my_state.get_scheduler().profiler_.task_start_running(my_state.get_thread_id(),
                                                            stolen_task->profiling_node_);
#endif
      context_switcher::switch_context(std::move(stolen_task->continuation_));
      // Execution continues on this line once the stolen work has finished.
    } else {
      // Always yield on failed steals
      std::this_thread::yield();
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), false);
#endif
#if PLS_SLEEP_WORKERS_ON_EMPTY
      switch (target_queue_empty_flag.value_) {
        case EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY: {
          // We found the queue empty, but the flag says it should still be full.
          // We want to declare it empty, but we need to re-check the queue in a sub-step to avoid races.
          data_structures::stamped_integer
              maybe_empty_flag{target_queue_empty_flag.stamp_ + 1, EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY};
          if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag,
                                                                          maybe_empty_flag,
                                                                          std::memory_order_acq_rel)) {
            goto queue_empty_flag_retry_steal;
          }
          break;
        }
        case EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY: {
          // We found the queue empty and it was already marked as maybe empty.
          // We can safely mark it empty and increment the central counter.
          data_structures::stamped_integer
              empty_flag{target_queue_empty_flag.stamp_ + 1, EMPTY_QUEUE_STATE::QUEUE_EMPTY};
          if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag,
                                                                          empty_flag,
                                                                          std::memory_order_acq_rel)) {
            // We marked it empty, so now it's our duty to modify the central counter.
            my_state.get_scheduler().empty_queue_increase_counter();
          }
          break;
        }
        case EMPTY_QUEUE_STATE::QUEUE_EMPTY: {
          // The queue was already marked empty, so do nothing.
          break;
        }
        default: {
          PLS_ASSERT(false, "The sleeping flag only has three possible states!");
          break;
        }
      }
      // Regardless of whether we found the target empty, check if we can put ourselves to sleep.
      my_state.get_scheduler().empty_queue_try_sleep_worker();
#endif
    }
  }
  my_state.set_scheduler_active(false);
}
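
// Summary of the PLS_SLEEP_WORKERS_ON_EMPTY protocol used above (the owner-side
// reset of the flag lives outside this file): every per-thread queue carries a
// stamped three-state empty flag. A thief that finds a victim's queue empty first
// moves the flag QUEUE_NON_EMPTY -> QUEUE_MAYBE_EMPTY (bumping the stamp) and
// retries the steal, so a queue that was refilled in the meantime is not falsely
// declared empty. Only when a retry also fails and the flag already reads
// QUEUE_MAYBE_EMPTY is it moved to QUEUE_EMPTY and the global empty_queue_counter_
// incremented, which eventually allows workers to sleep (see
// empty_queue_try_sleep_worker further down).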

void scheduler::sync_internal() {
  thread_state &syncing_state = thread_state::get();

  base_task *active_task = syncing_state.get_active_task();
  base_task *spawned_task = active_task->next_;

#if PLS_PROFILING_ENABLED
  syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                    active_task->stack_memory_,
                                                                    active_task->stack_size_,
                                                                    active_task->profiling_node_);
  syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                            active_task->profiling_node_);
  auto *next_dag_node =
      syncing_state.get_scheduler().profiler_.task_sync(syncing_state.get_thread_id(), active_task->profiling_node_);
  active_task->profiling_node_ = next_dag_node;
#endif

  if (active_task->is_synchronized_) {
#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     thread_state::get().get_active_task()->profiling_node_);
#endif
    return; // We are already the sole owner of last_task.
  } else {
    auto continuation =
        spawned_task->run_as_task([active_task, spawned_task, &syncing_state](context_switcher::continuation cont) {
          active_task->continuation_ = std::move(cont);
          syncing_state.set_active_task(spawned_task);
          return slow_return(syncing_state, true);
        });

    PLS_ASSERT(!continuation.valid(),
               "We only return to a sync point, never jump to it directly. "
               "This must therefore never return an unfinished fiber/continuation.");
#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     thread_state::get().get_active_task()->profiling_node_);
#endif
    return; // We cleanly synced to the last child finishing work on last_task.
  }
}
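
// sync_internal() has two outcomes: if the active task is already synchronized,
// this thread is the sole owner and the sync is a plain return. Otherwise the
// current execution is captured as a continuation on the active task, the spawned
// child becomes active, and slow_return() decides whether we were the last child to
// finish. The thread that completes the last child later resumes the stored
// continuation, which is why run_as_task() must never hand back a still-valid
// continuation here.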

context_switcher::continuation scheduler::slow_return(thread_state &calling_state, bool in_sync) {
  base_task *this_task = calling_state.get_active_task();
  PLS_ASSERT(this_task->depth_ > 0,
             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");
  base_task *last_task = this_task->prev_;

  // A slow return means we try to finish the child 'this_task' of 'last_task' and we
  // do not know if we are the last child to finish.
  // If we are not the last one, we get a spare task chain for our resources and can return to the main scheduling loop.
  base_task *pop_result = calling_state.get_task_manager().pop_clean_task_chain(last_task);

  if (pop_result != nullptr) {
    base_task *clean_chain = pop_result;

    // We got a clean chain to fill up our resources.
    PLS_ASSERT(last_task->depth_ == clean_chain->depth_,
               "Resources must only reside in the correct depth!");
    PLS_ASSERT(last_task != clean_chain,
               "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
    PLS_ASSERT_EXPENSIVE(check_task_chain_backward(*clean_chain),
                         "Can only acquire clean chains for clean returns!");

    // Acquire it/merge it with our task chain.
    this_task->prev_ = clean_chain;
    clean_chain->next_ = this_task;

    // Keep locally owned resources consistent.
    auto *clean_resources = clean_chain->attached_resources_.load(std::memory_order_relaxed);
    strain_local_resource::acquire_locally(clean_resources, calling_state.get_thread_id());

    // Walk up to the root of the newly acquired chain and make it the active task.
    base_task *active_task = clean_chain;
    while (active_task->depth_ > 0) {
      active_task = active_task->prev_;
    }
    calling_state.set_active_task(active_task);

    // Jump back to the continuation in the main scheduling loop.
    context_switcher::continuation result_cont = std::move(thread_state::get().main_continuation());
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation (to main).");
    return result_cont;
  } else {
    // Make sure that we are the owner of this full continuation/task chain.
    last_task->next_ = this_task;

    // We are the last one working on this task. Thus the sync must be finished; continue working.
    calling_state.set_active_task(last_task);
    last_task->is_synchronized_ = true;

    // Jump to the parent task and continue working on it.
    if (in_sync) {
      PLS_ASSERT(last_task->continuation_.valid(), "Must return a valid continuation (to last task) in sync.");
    } else {
      PLS_ASSERT(last_task->continuation_.valid(), "Must return a valid continuation (to last task) in spawn.");
    }

    context_switcher::continuation result_cont = std::move(last_task->continuation_);
    return result_cont;
  }
}
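
// slow_return() distinguishes the two situations a finishing child can be in: if
// pop_clean_task_chain() yields a chain, at least one sibling is still running, so
// this thread splices in the clean chain as its resource backing and returns to the
// stealing loop via main_continuation(). If it yields nothing, this thread finished
// the last child; it then owns the parent, marks it synchronized and resumes the
// parent's continuation directly.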

base_task *scheduler::get_trade_task(base_task *stolen_task, thread_state &calling_state) {
  // Get the task in our own chain that sits at the same depth as the stolen task.
  base_task *result = calling_state.get_active_task();
  while (result->depth_ > stolen_task->depth_) {
    result = result->prev_;
  }
  while (result->depth_ < stolen_task->depth_) {
    result = result->next_;
  }

  // Attach the other resources we need to trade to it.
  auto *stolen_resources = stolen_task->attached_resources_.load(std::memory_order_relaxed);
  auto *traded_resources = strain_local_resource::get_local_copy(stolen_resources, calling_state.get_thread_id());
  result->attached_resources_.store(traded_resources, std::memory_order_relaxed);

  return result;
}

void scheduler::terminate() {
  if (terminated_) {
    return;
  }

  terminated_ = true;
  sync_barrier_.wait();

  for (unsigned int i = 0; i < num_threads_; i++) {
    if (reuse_thread_ && i == 0) {
      continue;
    }
    worker_threads_[i].join();
  }
}
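
// terminate() is idempotent: the first call flips terminated_ and enters the sync
// barrier, which releases the workers blocked at the top of work_thread_main_loop()
// so they observe the flag and exit. All worker threads are then joined, except
// worker 0 when reuse_thread_ is set, as that slot has no dedicated std::thread to
// join.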

bool scheduler::check_task_chain_forward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->next_) {
    if (current->next_->prev_ != current) {
      return false;
    }
    current = current->next_;
  }
  return true;
}

bool scheduler::check_task_chain_backward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->prev_) {
    if (current->prev_->next_ != current) {
      return false;
    }
    current = current->prev_;
  }
  return true;
}

bool scheduler::check_task_chain(base_task &start_task) {
  return check_task_chain_backward(start_task) && check_task_chain_forward(start_task);
}
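
// The three checks above are debug helpers for the doubly linked task chain: they
// verify that every next_/prev_ pair is mutually consistent and back the
// PLS_ASSERT_EXPENSIVE calls in the scheduling paths above.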

#if PLS_SLEEP_WORKERS_ON_EMPTY
// TODO: relax memory orderings

void scheduler::empty_queue_try_sleep_worker() {
  int32_t counter_value = empty_queue_counter_.load();
  if ((int) counter_value >= (int) num_threads()) {
#if PLS_PROFILING_ENABLED
    get_profiler().sleep_start(thread_state::get().get_thread_id());
#endif

    // Start sleeping
    base::futex_wait((int32_t *) &empty_queue_counter_, counter_value);
    // Stop sleeping and wake up others
    std::this_thread::yield();
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
#if PLS_PROFILING_ENABLED
    get_profiler().sleep_stop(thread_state::get().get_thread_id());
#endif
  }
}

void scheduler::empty_queue_increase_counter() {
  empty_queue_counter_.fetch_add(1);
}

void scheduler::empty_queue_decrease_counter_and_wake() {
  auto old_counter = empty_queue_counter_.fetch_sub(1, std::memory_order_acq_rel);
  if ((int) old_counter >= (int) num_threads()) {
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
  }
}
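
// Sketch of the sleeping mechanism built on empty_queue_counter_ (summarizing the
// code in this block): the counter approximates how many per-thread queues are
// currently marked empty. A worker only sleeps, via futex_wait on the counter's
// address, once the counter reaches num_threads(), i.e. when every queue appears
// empty. Whoever makes new work available is expected to call
// empty_queue_decrease_counter_and_wake(), which wakes a single sleeper; each woken
// worker wakes one more, so sleepers resume in a chain rather than via a broadcast.
// empty_queue_reset_and_wake_all() below first drives the counter far negative so
// that none of the woken workers immediately goes back to sleep.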

void scheduler::empty_queue_reset_and_wake_all() {
  if (num_threads() == 1) {
    return;
  }

  // The large negative value keeps the sleep condition in empty_queue_try_sleep_worker()
  // false while the subsequent wakeups propagate.
  empty_queue_counter_.store(-1000);
  base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
}
#endif

} // namespace pls::internal::scheduling