scheduler.cpp

#include "pls/internal/scheduling/scheduler.h"

#include "context_switcher/context_switcher.h"

#include "pls/internal/scheduling/strain_local_resource.h"
#include "pls/internal/build_flavour.h"
#include "pls/internal/base/futex_wrapper.h"
#include "pls/internal/base/error_handling.h"

#include <thread>

namespace pls::internal::scheduling {

scheduler::scheduler(unsigned int num_threads,
                     size_t computation_depth,
                     size_t stack_size,
                     bool reuse_thread) : scheduler(num_threads,
                                                    computation_depth,
                                                    stack_size,
                                                    reuse_thread,
                                                    base::mmap_stack_allocator{}) {}

scheduler::~scheduler() {
  terminate();
}
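// Main loop of every worker thread: wait at the sync barrier until the main thread opens a
// work section, execute that section, then sync back with the main thread and wait again.
// The loop only exits once the scheduler is terminated.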
void scheduler::work_thread_main_loop() {
  auto &scheduler = thread_state::get().get_scheduler();
  while (true) {
    // Wait to be triggered
    scheduler.sync_barrier_.wait();

    // Check for shutdown
    if (scheduler.terminated_) {
      return;
    }

    scheduler.work_thread_work_section();

    // Sync back with main thread
    scheduler.sync_barrier_.wait();
  }
}

void scheduler::work_thread_work_section() {
  thread_state &my_state = thread_state::get();
  my_state.set_scheduler_active(true);
  unsigned const num_threads = my_state.get_scheduler().num_threads();

  if (my_state.get_thread_id() == 0) {
    // Main Thread, kick off by executing the user's main code block.
    main_thread_starter_function_->run();
  }

  unsigned int failed_steals = 0;
  while (!work_section_done_) {
#if PLS_PROFILING_ENABLED
    my_state.get_scheduler().profiler_.stealing_start(my_state.get_thread_id());
#endif
    PLS_ASSERT(check_task_chain(*my_state.get_active_task()), "Must start stealing with a clean task chain.");

    size_t target;
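    // Pick a steal victim uniformly at random among the other workers (classic randomized work stealing).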
    do {
      target = my_state.get_rand() % num_threads;
    } while (target == my_state.get_thread_id());

    thread_state &target_state = my_state.get_scheduler().thread_state_for(target);
#if PLS_SLEEP_WORKERS_ON_EMPTY
    queue_empty_flag_retry_steal:
    // TODO: relax atomics for empty flag
    data_structures::stamped_integer target_queue_empty_flag = target_state.get_queue_empty_flag().load();
#endif

    base_task *stolen_task;
    base_task *chain_after_stolen_task;
    bool cas_success;
    do {
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_cas_op(my_state.get_thread_id());
#endif
      std::tie(stolen_task, chain_after_stolen_task, cas_success) =
          target_state.get_task_manager().steal_task(my_state);
    } while (!cas_success); // re-try on CAS conflicts, as in classic work stealing

    if (stolen_task) {
      // Keep task chain consistent. We want to appear as if we are working on a branch upwards of the stolen task.
      stolen_task->next_ = chain_after_stolen_task;
      chain_after_stolen_task->prev_ = stolen_task;
      my_state.set_active_task(stolen_task);
      // Keep locally owned resources consistent.
      auto *stolen_resources = stolen_task->attached_resources_.load(std::memory_order_relaxed);
      strain_local_resource::acquire_locally(stolen_resources, my_state.get_thread_id());

      PLS_ASSERT(check_task_chain_forward(*my_state.get_active_task()),
                 "We are the sole owner of this chain, it has to be valid!");

      // Execute the stolen task by jumping to its continuation.
      PLS_ASSERT(stolen_task->continuation_.valid(),
                 "A task that we can steal must have a valid continuation for us to start working.");
      stolen_task->is_synchronized_ = false;
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), true);
      my_state.get_scheduler().profiler_.task_start_running(my_state.get_thread_id(),
                                                            stolen_task->profiling_node_);
#endif
      context_switcher::switch_context(std::move(stolen_task->continuation_));
      // We continue execution on this line once the stolen work is finished.
      failed_steals = 0;
    } else {
      // TODO: tune value for when we start yielding
      const unsigned YIELD_AFTER = 1;

      failed_steals++;
      if (failed_steals >= YIELD_AFTER) {
        std::this_thread::yield();
#if PLS_PROFILING_ENABLED
        my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), false);
#endif
#if PLS_SLEEP_WORKERS_ON_EMPTY
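        // The victim's queue looked empty. Its empty flag moves through a two-step protocol
        // (QUEUE_NON_EMPTY -> QUEUE_MAYBE_EMPTY -> QUEUE_EMPTY) using stamped CAS operations,
        // so races with the queue's owner are detected and we re-check before declaring it empty.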
        switch (target_queue_empty_flag.value) {
          case EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY: {
            // We found the queue empty, but the flag says it should still be full.
            // We want to declare it empty, but we need to re-check the queue in a sub-step to avoid races.
            data_structures::stamped_integer
                maybe_empty_flag{target_queue_empty_flag.stamp + 1, EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY};
            if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag,
                                                                            maybe_empty_flag)) {
              goto queue_empty_flag_retry_steal;
            }
            break;
          }
          case EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY: {
            // We found the queue empty and it was already marked as maybe empty.
            // We can safely mark it empty and increment the central counter.
            data_structures::stamped_integer
                empty_flag{target_queue_empty_flag.stamp + 1, EMPTY_QUEUE_STATE::QUEUE_EMPTY};
            if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag, empty_flag)) {
              // We marked it empty, now it's our duty to update the central counter
              my_state.get_scheduler().empty_queue_increase_counter();
            }
            break;
          }
          case EMPTY_QUEUE_STATE::QUEUE_EMPTY: {
            // The queue was already marked empty, just do nothing
            break;
          }
        }
        // Regardless of whether we found the queue empty, check if we can put ourselves to sleep
        my_state.get_scheduler().empty_queue_try_sleep_worker();
#endif
      }
    }
  }
  my_state.set_scheduler_active(false);
}

void scheduler::sync_internal() {
  if (thread_state::is_scheduler_active()) {
    thread_state &syncing_state = thread_state::get();

    base_task *active_task = syncing_state.get_active_task();
    base_task *spawned_task = active_task->next_;

#if PLS_PROFILING_ENABLED
    syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                      active_task->stack_memory_,
                                                                      active_task->stack_size_,
                                                                      active_task->profiling_node_);
    syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                              active_task->profiling_node_);
    auto *next_dag_node =
        syncing_state.get_scheduler().profiler_.task_sync(syncing_state.get_thread_id(), active_task->profiling_node_);
    active_task->profiling_node_ = next_dag_node;
#endif

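    // Fast path: the active task is still synchronized (no child was stolen from it), so we
    // already own everything below the sync point and can continue right away.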
    if (active_task->is_synchronized_) {
#if PLS_PROFILING_ENABLED
      thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                       thread_state::get().get_active_task()->profiling_node_);
#endif
      return; // We are already the sole owner of last_task
    } else {
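      // Slow path: we might not be the sole owner of the tasks below this sync point (work was
      // stolen). Save our continuation in the parent task, switch onto the spawned task's fiber
      // and let slow_return decide whether we continue right away or must return to the
      // scheduling loop.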
      auto continuation =
          spawned_task->run_as_task([active_task, spawned_task, &syncing_state](context_switcher::continuation cont) {
            active_task->continuation_ = std::move(cont);
            syncing_state.set_active_task(spawned_task);
            return slow_return(syncing_state);
          });

      PLS_ASSERT(!continuation.valid(),
                 "We only return to a sync point, never jump to it directly."
                 "This must therefore never return an unfinished fiber/continuation.");
#if PLS_PROFILING_ENABLED
      thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                       thread_state::get().get_active_task()->profiling_node_);
#endif
      return; // We cleanly synced to the last one finishing work on last_task
    }
  } else {
    // Scheduler not active
    return;
  }
}

context_switcher::continuation scheduler::slow_return(thread_state &calling_state) {
  base_task *this_task = calling_state.get_active_task();
  PLS_ASSERT(this_task->depth_ > 0,
             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");
  base_task *last_task = this_task->prev_;

  // Slow return means we try to finish the child 'this_task' of 'last_task' and we
  // do not know if we are the last child to finish.
  // If we are not the last one, we get a spare task chain for our resources and can return to the main scheduling loop.
  base_task *pop_result = calling_state.get_task_manager().pop_clean_task_chain(last_task);

  if (pop_result != nullptr) {
    base_task *clean_chain = pop_result;

    // We got a clean chain to fill up our resources.
    PLS_ASSERT(last_task->depth_ == clean_chain->depth_,
               "Resources must only reside in the correct depth!");
    PLS_ASSERT(last_task != clean_chain,
               "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
    PLS_ASSERT(check_task_chain_backward(*clean_chain),
               "Can only acquire clean chains for clean returns!");

    // Acquire it/merge it with our task chain.
    this_task->prev_ = clean_chain;
    clean_chain->next_ = this_task;

    // Keep locally owned resources consistent.
    auto *clean_resources = clean_chain->attached_resources_.load(std::memory_order_relaxed);
    strain_local_resource::acquire_locally(clean_resources, calling_state.get_thread_id());

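    // Walk the acquired chain up to its root task (depth 0) and make that the new active task.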
    base_task *active_task = clean_chain;
    while (active_task->depth_ > 0) {
      active_task = active_task->prev_;
    }
    calling_state.set_active_task(active_task);

    // Jump back to the continuation in main scheduling loop.
    context_switcher::continuation result_cont = std::move(thread_state::get().main_continuation());
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
    return result_cont;
  } else {
    // Make sure that we are owner of this full continuation/task chain.
    last_task->next_ = this_task;

    // We are the last one working on this task. Thus the sync must be finished, continue working.
    calling_state.set_active_task(last_task);
    last_task->is_synchronized_ = true;

    // Jump to parent task and continue working on it.
    context_switcher::continuation result_cont = std::move(last_task->continuation_);
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
    return result_cont;
  }
}

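// Select the task from our own chain that is traded away in exchange for a stolen task:
// the entry at the stolen task's depth, with the matching strain-local resources attached.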
base_task *scheduler::get_trade_task(base_task *stolen_task, thread_state &calling_state) {
  // Get the task itself: walk our chain to the entry at the stolen task's depth.
  base_task *result = calling_state.get_active_task();
  while (result->depth_ > stolen_task->depth_) {
    result = result->prev_;
  }
  while (result->depth_ < stolen_task->depth_) {
    result = result->next_;
  }

  // Attach other resources we need to trade to it
  auto *stolen_resources = stolen_task->attached_resources_.load(std::memory_order_relaxed);
  auto *traded_resources = strain_local_resource::get_local_copy(stolen_resources, calling_state.get_thread_id());
  result->attached_resources_.store(traded_resources, std::memory_order_relaxed);

  return result;
}

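// Shut the scheduler down: release the workers through the sync barrier so they observe
// terminated_ and leave their main loop, then join all worker threads (thread 0 is skipped
// when the calling thread is reused as a worker).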
void scheduler::terminate() {
  if (terminated_) {
    return;
  }

  terminated_ = true;
  sync_barrier_.wait();

  for (unsigned int i = 0; i < num_threads_; i++) {
    if (reuse_thread_ && i == 0) {
      continue;
    }
    worker_threads_[i].join();
  }
}

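// Sanity checks used in assertions: a task chain is valid if every neighbouring pair of
// tasks is linked consistently through next_/prev_ in the checked direction.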
bool scheduler::check_task_chain_forward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->next_) {
    if (current->next_->prev_ != current) {
      return false;
    }
    current = current->next_;
  }
  return true;
}

bool scheduler::check_task_chain_backward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->prev_) {
    if (current->prev_->next_ != current) {
      return false;
    }
    current = current->prev_;
  }
  return true;
}

bool scheduler::check_task_chain(base_task &start_task) {
  return check_task_chain_backward(start_task) && check_task_chain_forward(start_task);
}

#if PLS_SLEEP_WORKERS_ON_EMPTY
// TODO: relax memory orderings

void scheduler::empty_queue_try_sleep_worker() {
  int32_t counter_value = empty_queue_counter_.load();
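  // Only go to sleep once all queues have been marked empty (counter == num_threads()).
  // futex_wait re-checks the counter atomically; a woken worker passes the wakeup on below,
  // so all sleepers eventually resume once new work arrives.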
  if (counter_value == num_threads()) {
#if PLS_PROFILING_ENABLED
    get_profiler().sleep_start(thread_state::get().get_thread_id());
#endif
    threads_sleeping_++;
    base::futex_wait((int32_t *) &empty_queue_counter_, num_threads());
    threads_sleeping_--;
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
#if PLS_PROFILING_ENABLED
    get_profiler().sleep_stop(thread_state::get().get_thread_id());
#endif
  }
}

void scheduler::empty_queue_increase_counter() {
  empty_queue_counter_.fetch_add(1);
}

void scheduler::empty_queue_decrease_counter_and_wake() {
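  // A queue became non-empty again: take back one 'empty' vote and wake one sleeping worker,
  // if any, so stealing can resume.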
  empty_queue_counter_.fetch_sub(1);
  if (threads_sleeping_.load() > 0) {
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
  }
}
#endif

}