#include "pls/internal/scheduling/scheduler.h"

#include "context_switcher/context_switcher.h"

#include "pls/internal/scheduling/strain_local_resource.h"
#include "pls/internal/build_flavour.h"
#include "pls/internal/base/error_handling.h"

#include <thread>

namespace pls::internal::scheduling {

scheduler::scheduler(unsigned int num_threads,
                     size_t computation_depth,
                     size_t stack_size,
                     bool reuse_thread,
                     size_t serial_stack_size) : scheduler(num_threads,
                                                           computation_depth,
                                                           stack_size,
                                                           reuse_thread,
                                                           serial_stack_size,
                                                           base::mmap_stack_allocator{}) {}
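
// Rough usage sketch (illustrative only: the spawn/sync front end and any default
// constructor arguments live in the public headers, so the names below are assumed
// rather than taken from this file):
//
//   scheduler sched{/*num_threads=*/8, /*computation_depth=*/64, /*stack_size=*/1u << 16};
//   sched.perform_work([] {
//     scheduler::spawn([] { /* child task */ });
//     scheduler::sync();  // ends up in sync_internal() below
//   });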

scheduler::~scheduler() {
  terminate();
}
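
// Main loop of every worker thread: block on the sync barrier until the main thread
// opens a work section, execute it, then sync back on the same barrier. A set
// terminated_ flag observed after the barrier shuts the worker down.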
void scheduler::work_thread_main_loop() {
  auto &scheduler = thread_state::get().get_scheduler();
  while (true) {
    // Wait to be triggered
    scheduler.sync_barrier_.wait();

    // Check for shutdown
    if (scheduler.terminated_) {
      return;
    }

    scheduler.work_thread_work_section();

    // Sync back with main thread
    scheduler.sync_barrier_.wait();
  }
}

void scheduler::work_thread_work_section() {
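  // One work section: thread 0 kicks off the user's main code block, then all
  // threads keep stealing and executing work until work_section_done_ is set.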
  thread_state &my_state = thread_state::get();
  my_state.set_scheduler_active(true);
  unsigned const num_threads = my_state.get_scheduler().num_threads();

  if (my_state.get_thread_id() == 0) {
    // Main thread: kick off the work section by executing the user's main code block.
    main_thread_starter_function_->run();
  }

  unsigned int failed_steals = 0;
  while (!work_section_done_) {
#if PLS_PROFILING_ENABLED
    my_state.get_scheduler().profiler_.stealing_start(my_state.get_thread_id());
#endif
    PLS_ASSERT(check_task_chain(*my_state.get_active_task()), "Must start stealing with a clean task chain.");

    size_t target;
    do {
      target = my_state.get_rand() % num_threads;
    } while (target == my_state.get_thread_id());

    thread_state &target_state = my_state.get_scheduler().thread_state_for(target);
#if PLS_SLEEP_WORKERS_ON_EMPTY
    queue_empty_flag_retry_steal:
    // TODO: relax atomics for empty flag
    data_structures::stamped_integer target_queue_empty_flag = target_state.get_queue_empty_flag().load();
#endif

    base_task *stolen_task;
    base_task *chain_after_stolen_task;
    bool cas_success;
    do {
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_cas_op(my_state.get_thread_id());
#endif
      std::tie(stolen_task, chain_after_stolen_task, cas_success) =
          target_state.get_task_manager().steal_task(my_state);
    } while (!cas_success); // re-try on cas-conflicts, as in classic ws

    if (stolen_task) {
      // Keep the task chain consistent. We want to appear as if we are working on a branch upwards of the stolen task.
      stolen_task->next_ = chain_after_stolen_task;
      chain_after_stolen_task->prev_ = stolen_task;
      my_state.set_active_task(stolen_task);
      // Keep locally owned resources consistent.
      auto *stolen_resources = stolen_task->attached_resources_.load(std::memory_order_relaxed);
      strain_local_resource::acquire_locally(stolen_resources, my_state.get_thread_id());

      PLS_ASSERT(check_task_chain_forward(*my_state.get_active_task()),
                 "We are the sole owner of this chain, it has to be valid!");

      // Execute the stolen task by jumping to its continuation.
      PLS_ASSERT(stolen_task->continuation_.valid(),
                 "A task that we can steal must have a valid continuation for us to start working.");
      stolen_task->is_synchronized_ = false;
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), true);
      my_state.get_scheduler().profiler_.task_start_running(my_state.get_thread_id(),
                                                            stolen_task->profiling_node_);
#endif
      context_switcher::switch_context(std::move(stolen_task->continuation_));
      // Execution resumes on this line once the stolen work has finished.
      failed_steals = 0;
    } else {
      // TODO: tune value for when we start yielding
      const unsigned YIELD_AFTER = 1;

      failed_steals++;
      if (failed_steals >= YIELD_AFTER) {
        std::this_thread::yield();
#if PLS_PROFILING_ENABLED
        my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), false);
#endif
#if PLS_SLEEP_WORKERS_ON_EMPTY
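        // Empty-flag protocol: a victim queue found empty is first marked
        // QUEUE_MAYBE_EMPTY and only marked QUEUE_EMPTY (bumping the central counter)
        // when it is found empty a second time. The stamp makes the compare_exchange
        // fail if the flag changed concurrently, e.g. because the owner pushed new work.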
        switch (target_queue_empty_flag.value) {
          case EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY: {
            // We found the queue empty, but the flag says it should still be full.
            // We want to declare it empty, but we need to re-check the queue in a sub-step to avoid races.
            data_structures::stamped_integer
                maybe_empty_flag{target_queue_empty_flag.stamp + 1, EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY};
            if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag,
                                                                            maybe_empty_flag)) {
              goto queue_empty_flag_retry_steal;
            }
            break;
          }
          case EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY: {
            // We found the queue empty and it was already marked as maybe empty.
            // We can safely mark it empty and increment the central counter.
            data_structures::stamped_integer
                empty_flag{target_queue_empty_flag.stamp + 1, EMPTY_QUEUE_STATE::QUEUE_EMPTY};
            if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag, empty_flag)) {
              // We marked it empty, now it's our duty to update the central counter.
              my_state.get_scheduler().empty_queue_increase_counter();
            }
            break;
          }
          case EMPTY_QUEUE_STATE::QUEUE_EMPTY: {
            // The queue was already marked empty, just do nothing
            break;
          }
        }
        // Regardless of whether we found the queue empty, check if we can put ourselves to sleep.
        my_state.get_scheduler().empty_queue_try_sleep_worker();
#endif
      }
    }
  }
  my_state.set_scheduler_active(false);
}

void scheduler::sync_internal() {
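  // Synchronize the active task with its spawned child. If nothing was stolen
  // (is_synchronized_), we already own the whole chain and return right away.
  // Otherwise we store our continuation in the active task, switch onto the spawned
  // task's fiber and let slow_return decide whether we resume the parent directly
  // or fall back to the main scheduling loop.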
  if (thread_state::is_scheduler_active()) {
    thread_state &syncing_state = thread_state::get();

    base_task *active_task = syncing_state.get_active_task();
    base_task *spawned_task = active_task->next_;

#if PLS_PROFILING_ENABLED
    syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                      active_task->stack_memory_,
                                                                      active_task->stack_size_,
                                                                      active_task->profiling_node_);
    syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                              active_task->profiling_node_);
    auto *next_dag_node =
        syncing_state.get_scheduler().profiler_.task_sync(syncing_state.get_thread_id(), active_task->profiling_node_);
    active_task->profiling_node_ = next_dag_node;
#endif

    if (active_task->is_synchronized_) {
#if PLS_PROFILING_ENABLED
      thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                       thread_state::get().get_active_task()->profiling_node_);
#endif
      return; // We are already the sole owner of last_task
    } else {
      auto continuation =
          spawned_task->run_as_task([active_task, spawned_task, &syncing_state](context_switcher::continuation cont) {
            active_task->continuation_ = std::move(cont);
            syncing_state.set_active_task(spawned_task);
            return slow_return(syncing_state);
          });

      PLS_ASSERT(!continuation.valid(),
                 "We only return to a sync point, never jump to it directly. "
                 "This must therefore never return an unfinished fiber/continuation.");
#if PLS_PROFILING_ENABLED
      thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                       thread_state::get().get_active_task()->profiling_node_);
#endif
      return; // We cleanly synced to the last one finishing work on last_task
    }
  } else {
    // Scheduler not active
    return;
  }
}

context_switcher::continuation scheduler::slow_return(thread_state &calling_state) {
  base_task *this_task = calling_state.get_active_task();
  PLS_ASSERT(this_task->depth_ > 0,
             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");
  base_task *last_task = this_task->prev_;

  // A slow return means we are finishing the child 'this_task' of 'last_task' without
  // knowing whether we are the last child to finish.
  // If we are not the last one, we take a spare task chain for our resources and return to the main scheduling loop.
  base_task *pop_result = calling_state.get_task_manager().pop_clean_task_chain(last_task);

  if (pop_result != nullptr) {
    base_task *clean_chain = pop_result;

    // We got a clean chain to fill up our resources.
    PLS_ASSERT(last_task->depth_ == clean_chain->depth_,
               "Resources must only reside in the correct depth!");
    PLS_ASSERT(last_task != clean_chain,
               "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
    PLS_ASSERT(check_task_chain_backward(*clean_chain),
               "Can only acquire clean chains for clean returns!");

    // Acquire it/merge it with our task chain.
    this_task->prev_ = clean_chain;
    clean_chain->next_ = this_task;

    // Keep locally owned resources consistent.
    auto *clean_resources = clean_chain->attached_resources_.load(std::memory_order_relaxed);
    strain_local_resource::acquire_locally(clean_resources, calling_state.get_thread_id());

    // Walk up to the root (depth 0) of the newly merged chain and make it the active task.
    base_task *active_task = clean_chain;
    while (active_task->depth_ > 0) {
      active_task = active_task->prev_;
    }
    calling_state.set_active_task(active_task);

    // Jump back to the continuation in main scheduling loop.
    context_switcher::continuation result_cont = std::move(thread_state::get().main_continuation());
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
    return result_cont;
  } else {
    // Make sure that we are owner of this full continuation/task chain.
    last_task->next_ = this_task;

    // We are the last one working on this task. Thus the sync must be finished, continue working.
    calling_state.set_active_task(last_task);
    last_task->is_synchronized_ = true;

    // Jump to parent task and continue working on it.
    context_switcher::continuation result_cont = std::move(last_task->continuation_);
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
    return result_cont;
  }
}

base_task *scheduler::get_trade_task(base_task *stolen_task, thread_state &calling_state) {
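  // Assemble the task that is traded back to the victim in exchange for the stolen
  // task, together with local copies of its strain-local resources.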
  // Get the task itself: walk our own chain to the entry at the stolen task's depth.
  base_task *result = calling_state.get_active_task();
  while (result->depth_ > stolen_task->depth_) {
    result = result->prev_;
  }
  while (result->depth_ < stolen_task->depth_) {
    result = result->next_;
  }

  // Attach local copies of the stolen task's resources so they are traded along with it.
  auto *stolen_resources = stolen_task->attached_resources_.load(std::memory_order_relaxed);
  auto *traded_resources = strain_local_resource::get_local_copy(stolen_resources, calling_state.get_thread_id());
  result->attached_resources_.store(traded_resources, std::memory_order_relaxed);

  return result;
}

void scheduler::terminate() {
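  // Signal shutdown, release the workers from the sync barrier and join them.
  // Thread 0 is skipped when the calling thread is reused as a worker.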
  if (terminated_) {
    return;
  }

  terminated_ = true;
  sync_barrier_.wait();

  for (unsigned int i = 0; i < num_threads_; i++) {
    if (reuse_thread_ && i == 0) {
      continue;
    }
    worker_threads_[i].join();
  }
}

bool scheduler::check_task_chain_forward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->next_) {
    if (current->next_->prev_ != current) {
      return false;
    }
    current = current->next_;
  }
  return true;
}

bool scheduler::check_task_chain_backward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->prev_) {
    if (current->prev_->next_ != current) {
      return false;
    }
    current = current->prev_;
  }
  return true;
}

bool scheduler::check_task_chain(base_task &start_task) {
  return check_task_chain_backward(start_task) && check_task_chain_forward(start_task);
}

#if PLS_SLEEP_WORKERS_ON_EMPTY
// TODO: relax memory orderings
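// Sleeping workers: empty_queue_counter_ counts the queues currently marked empty.
// Once it reaches num_threads() there is (supposedly) no work left, so a stealing
// worker futex_waits on the counter. Each woken worker wakes one further sleeper,
// so a single wake-up chains through all sleeping threads.
// empty_queue_decrease_counter_and_wake is meant to be called once a queue becomes
// non-empty again.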

void scheduler::empty_queue_try_sleep_worker() {
  int32_t counter_value = empty_queue_counter_.load();
  if (counter_value == num_threads()) {
#if PLS_PROFILING_ENABLED
    get_profiler().sleep_start(thread_state::get().get_thread_id());
#endif
    threads_sleeping_++;
    base::futex_wait((int32_t *) &empty_queue_counter_, num_threads());
    threads_sleeping_--;
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
#if PLS_PROFILING_ENABLED
    get_profiler().sleep_stop(thread_state::get().get_thread_id());
#endif
  }
}

void scheduler::empty_queue_increase_counter() {
  empty_queue_counter_.fetch_add(1);
}

void scheduler::empty_queue_decrease_counter_and_wake() {
  empty_queue_counter_.fetch_sub(1);
  if (threads_sleeping_.load() > 0) {
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
  }
}
#endif

}  // namespace pls::internal::scheduling