#include "pls/internal/scheduling/scheduler.h"

#include "context_switcher/context_switcher.h"

#include "pls/internal/scheduling/strain_local_resource.h"
#include "pls/internal/build_flavour.h"
#include "pls/internal/base/error_handling.h"

#include <thread>

namespace pls::internal::scheduling {

scheduler::scheduler(unsigned int num_threads,
                     size_t computation_depth,
                     size_t stack_size,
                     bool reuse_thread,
                     size_t serial_stack_size) : scheduler(num_threads,
                                                           computation_depth,
                                                           stack_size,
                                                           reuse_thread,
                                                           serial_stack_size,
                                                           base::mmap_stack_allocator{}) {}

scheduler::~scheduler() {
  terminate();
}
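
// Main loop executed by every worker thread: block on the sync barrier until the main thread
// either starts a work section or requests termination, then sync back once the section is done.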
void scheduler::work_thread_main_loop() {
  auto &scheduler = thread_state::get().get_scheduler();
  while (true) {
    // Wait to be triggered
    scheduler.sync_barrier_.wait();

    // Check for shutdown
    if (scheduler.terminated_) {
      return;
    }

    scheduler.work_thread_work_section();

    // Sync back with main thread
    scheduler.sync_barrier_.wait();
  }
}

void scheduler::work_thread_work_section() {
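  // One work section: the main thread (id 0) kicks off the user's main code block, then every
  // thread runs the randomized work-stealing loop below until work_section_done_ is set.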
  thread_state &my_state = thread_state::get();
  my_state.set_scheduler_active(true);
  unsigned const num_threads = my_state.get_scheduler().num_threads();

  if (my_state.get_thread_id() == 0) {
    // Main thread: kick off the work section by executing the user's main code block.
    main_thread_starter_function_->run();
  }

  unsigned int failed_steals = 0;
  while (!work_section_done_) {
#if PLS_PROFILING_ENABLED
    my_state.get_scheduler().profiler_.stealing_start(my_state.get_thread_id());
#endif
    PLS_ASSERT(check_task_chain(*my_state.get_active_task()), "Must start stealing with a clean task chain.");

    size_t target;
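    // Repeatedly draw a random victim until it differs from our own thread id.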
    do {
      target = my_state.get_rand() % num_threads;
    } while (target == my_state.get_thread_id());

    thread_state &target_state = my_state.get_scheduler().thread_state_for(target);
#if PLS_SLEEP_WORKERS_ON_EMPTY
    queue_empty_flag_retry_steal:
    // TODO: relax atomics for empty flag
    data_structures::stamped_integer target_queue_empty_flag = target_state.get_queue_empty_flag().load();
#endif

    base_task *stolen_task;
    base_task *chain_after_stolen_task;
    bool cas_success;
    do {
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_cas_op(my_state.get_thread_id());
#endif
      std::tie(stolen_task, chain_after_stolen_task, cas_success) =
          target_state.get_task_manager().steal_task(my_state);
    } while (!cas_success); // re-try on CAS conflicts, as in classic work stealing

    if (stolen_task) {
      // Keep task chain consistent. We want to appear as if we are working on a branch upwards of the stolen task.
      stolen_task->next_ = chain_after_stolen_task;
      chain_after_stolen_task->prev_ = stolen_task;
      my_state.set_active_task(stolen_task);
      // Keep locally owned resources consistent.
      auto *stolen_resources = stolen_task->attached_resources_.load(std::memory_order_relaxed);
      strain_local_resource::acquire_locally(stolen_resources, my_state.get_thread_id());

      PLS_ASSERT(check_task_chain_forward(*my_state.get_active_task()),
                 "We are the sole owner of this chain, so it has to be valid!");

      // Execute the stolen task by jumping to its continuation.
      PLS_ASSERT(stolen_task->continuation_.valid(),
                 "A task that we can steal must have a valid continuation for us to start working.");
      stolen_task->is_synchronized_ = false;
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), true);
      my_state.get_scheduler().profiler_.task_start_running(my_state.get_thread_id(),
                                                            stolen_task->profiling_node_);
#endif
      context_switcher::switch_context(std::move(stolen_task->continuation_));
      // We will continue execution at this line once the stolen work is finished.
      failed_steals = 0;
    } else {
      // TODO: tune value for when we start yielding
      const unsigned YIELD_AFTER = 1;

      failed_steals++;
      if (failed_steals >= YIELD_AFTER) {
        std::this_thread::yield();
#if PLS_PROFILING_ENABLED
        my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), false);
#endif
#if PLS_SLEEP_WORKERS_ON_EMPTY
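        // Two-step empty-flag protocol: a queue observed empty is first marked QUEUE_MAYBE_EMPTY
        // and the steal is retried; only a second failed attempt marks it QUEUE_EMPTY and counts
        // it towards the central empty-queue counter.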
        switch (target_queue_empty_flag.value) {
          case EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY: {
            // We found the queue empty, but the flag says it should still be full.
            // We want to declare it empty, but we need to re-check the queue in a sub-step to avoid races.
            data_structures::stamped_integer
                maybe_empty_flag{target_queue_empty_flag.stamp + 1, EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY};
            if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag,
                                                                            maybe_empty_flag)) {
              goto queue_empty_flag_retry_steal;
            }
            break;
          }
          case EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY: {
            // We found the queue empty and it was already marked as maybe empty.
            // We can safely mark it empty and increment the central counter.
            data_structures::stamped_integer
                empty_flag{target_queue_empty_flag.stamp + 1, EMPTY_QUEUE_STATE::QUEUE_EMPTY};
            if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag, empty_flag)) {
              // We marked it empty, now it's our duty to modify the central counter.
              my_state.get_scheduler().empty_queue_increase_counter();
            }
            break;
          }
          case EMPTY_QUEUE_STATE::QUEUE_EMPTY: {
            // The queue was already marked empty, just do nothing
            break;
          }
        }
        // Regardless of whether we found the queue empty, check if we can put ourselves to sleep.
        my_state.get_scheduler().empty_queue_try_sleep_worker();
#endif
      }
    }
  }
  my_state.set_scheduler_active(false);
}

void scheduler::sync_internal() {
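  // Fast path: the active task is still synchronized, so we are the sole owner of the chain and
  // return immediately. Slow path: switch into the spawned task's continuation and let
  // slow_return decide whether we resume here or fall back to the main scheduling loop.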
  if (thread_state::is_scheduler_active()) {
    thread_state &syncing_state = thread_state::get();

    base_task *active_task = syncing_state.get_active_task();
    base_task *spawned_task = active_task->next_;

#if PLS_PROFILING_ENABLED
    syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                      active_task->stack_memory_,
                                                                      active_task->stack_size_,
                                                                      active_task->profiling_node_);
    syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                              active_task->profiling_node_);
    auto *next_dag_node =
        syncing_state.get_scheduler().profiler_.task_sync(syncing_state.get_thread_id(), active_task->profiling_node_);
    active_task->profiling_node_ = next_dag_node;
#endif

    if (active_task->is_synchronized_) {
#if PLS_PROFILING_ENABLED
      thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                       thread_state::get().get_active_task()->profiling_node_);
#endif
      return; // We are already the sole owner of last_task
    } else {
      auto continuation =
          spawned_task->run_as_task([active_task, spawned_task, &syncing_state](context_switcher::continuation cont) {
            active_task->continuation_ = std::move(cont);
            syncing_state.set_active_task(spawned_task);
            return slow_return(syncing_state, true);
          });

      PLS_ASSERT(!continuation.valid(),
                 "We only return to a sync point, never jump to it directly. "
                 "This must therefore never return an unfinished fiber/continuation.");
#if PLS_PROFILING_ENABLED
      thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                       thread_state::get().get_active_task()->profiling_node_);
#endif
      return; // We cleanly synced to the last one finishing work on last_task
    }
  } else {
    // Scheduler not active
    return;
  }
}

context_switcher::continuation scheduler::slow_return(thread_state &calling_state, bool in_sync) {
  base_task *this_task = calling_state.get_active_task();
  PLS_ASSERT(this_task->depth_ > 0,
             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");
  base_task *last_task = this_task->prev_;

  // Slow return means we try to finish the child 'this_task' of 'last_task' and we
  // do not know if we are the last child to finish.
  // If we are not the last one, we get a spare task chain for our resources and can return to the main scheduling loop.
  base_task *pop_result = calling_state.get_task_manager().pop_clean_task_chain(last_task);

  if (pop_result != nullptr) {
    base_task *clean_chain = pop_result;

    // We got a clean chain to fill up our resources.
    PLS_ASSERT(last_task->depth_ == clean_chain->depth_,
               "Resources must only reside in the correct depth!");
    PLS_ASSERT(last_task != clean_chain,
               "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
    PLS_ASSERT(check_task_chain_backward(*clean_chain),
               "Can only acquire clean chains for clean returns!");

    // Acquire it/merge it with our task chain.
    this_task->prev_ = clean_chain;
    clean_chain->next_ = this_task;

    // Keep locally owned resources consistent.
    auto *clean_resources = clean_chain->attached_resources_.load(std::memory_order_relaxed);
    strain_local_resource::acquire_locally(clean_resources, calling_state.get_thread_id());

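    // Walk the acquired chain up to its root task (depth 0) and make that the active task before
    // jumping back to the scheduling loop.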
    base_task *active_task = clean_chain;
    while (active_task->depth_ > 0) {
      active_task = active_task->prev_;
    }
    calling_state.set_active_task(active_task);

    // Jump back to the continuation in the main scheduling loop.
    context_switcher::continuation result_cont = std::move(thread_state::get().main_continuation());
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation (to main).");
    return result_cont;
  } else {
    // Make sure that we are the owner of this full continuation/task chain.
    last_task->next_ = this_task;

    // We are the last one working on this task. Thus the sync must be finished; continue working.
    calling_state.set_active_task(last_task);
    last_task->is_synchronized_ = true;

    // Jump to parent task and continue working on it.
    if (in_sync) {
      PLS_ASSERT(last_task->continuation_.valid(), "Must return a valid continuation (to last task) in sync.");
    } else {
      PLS_ASSERT(last_task->continuation_.valid(), "Must return a valid continuation (to last task) in spawn.");
    }

    context_switcher::continuation result_cont = std::move(last_task->continuation_);
    return result_cont;
  }
}

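// Find the task in the calling thread's chain at the stolen task's depth and attach the traded
// strain-local resource copies to it; this is the task handed over in exchange for the stolen one.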
base_task *scheduler::get_trade_task(base_task *stolen_task, thread_state &calling_state) {
  // Find the task in our own chain at the same depth as the stolen task.
  base_task *result = calling_state.get_active_task();
  while (result->depth_ > stolen_task->depth_) {
    result = result->prev_;
  }
  while (result->depth_ < stolen_task->depth_) {
    result = result->next_;
  }

  // Attach the traded local copies of the stolen task's resources to it.
  auto *stolen_resources = stolen_task->attached_resources_.load(std::memory_order_relaxed);
  auto *traded_resources = strain_local_resource::get_local_copy(stolen_resources, calling_state.get_thread_id());
  result->attached_resources_.store(traded_resources, std::memory_order_relaxed);

  return result;
}

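// Shut down the scheduler: set terminated_, release the workers from the sync barrier and join
// all worker threads (skipping thread 0 when the main thread is reused as a worker).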
void scheduler::terminate() {
  if (terminated_) {
    return;
  }

  terminated_ = true;
  sync_barrier_.wait();

  for (unsigned int i = 0; i < num_threads_; i++) {
    if (reuse_thread_ && i == 0) {
      continue;
    }
    worker_threads_[i].join();
  }
}

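// Debug helpers: verify that the doubly linked task chain is consistent in both directions.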
bool scheduler::check_task_chain_forward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->next_) {
    if (current->next_->prev_ != current) {
      return false;
    }
    current = current->next_;
  }
  return true;
}

bool scheduler::check_task_chain_backward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->prev_) {
    if (current->prev_->next_ != current) {
      return false;
    }
    current = current->prev_;
  }
  return true;
}

bool scheduler::check_task_chain(base_task &start_task) {
  return check_task_chain_backward(start_task) && check_task_chain_forward(start_task);
}

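// Worker sleeping (PLS_SLEEP_WORKERS_ON_EMPTY): empty_queue_counter_ tracks how many queues are
// known to be empty. Once it equals num_threads() a worker may futex-wait on it; decreasing the
// counter wakes one sleeper, and a woken worker wakes the next sleeper before continuing.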
#if PLS_SLEEP_WORKERS_ON_EMPTY
// TODO: relax memory orderings

void scheduler::empty_queue_try_sleep_worker() {
  int32_t counter_value = empty_queue_counter_.load();
  if (counter_value == num_threads()) {
#if PLS_PROFILING_ENABLED
    get_profiler().sleep_start(thread_state::get().get_thread_id());
#endif
    threads_sleeping_++;
    base::futex_wait((int32_t *) &empty_queue_counter_, num_threads());
    threads_sleeping_--;
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
#if PLS_PROFILING_ENABLED
    get_profiler().sleep_stop(thread_state::get().get_thread_id());
#endif
  }
}

void scheduler::empty_queue_increase_counter() {
  empty_queue_counter_.fetch_add(1);
}

void scheduler::empty_queue_decrease_counter_and_wake() {
  empty_queue_counter_.fetch_sub(1);
  if (threads_sleeping_.load() > 0) {
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
  }
}
#endif

}