#include "pls/internal/scheduling/scheduler.h"

#include "context_switcher/context_switcher.h"

#include "pls/internal/build_flavour.h"
#include "pls/internal/base/futex_wrapper.h"
#include "pls/internal/base/error_handling.h"

#include <thread>

namespace pls::internal::scheduling {

scheduler::scheduler(unsigned int num_threads,
                     size_t computation_depth,
                     size_t stack_size,
                     bool reuse_thread) : scheduler(num_threads,
                                                    computation_depth,
                                                    stack_size,
                                                    reuse_thread,
                                                    base::mmap_stack_allocator{}) {}

scheduler::~scheduler() {
  terminate();
}
void scheduler::work_thread_main_loop() {
  auto &scheduler = thread_state::get().get_scheduler();
  while (true) {
    // Wait to be triggered
    scheduler.sync_barrier_.wait();

    // Check for shutdown
    if (scheduler.terminated_) {
      return;
    }

    scheduler.work_thread_work_section();

    // Sync back with main thread
    scheduler.sync_barrier_.wait();
  }
}
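
// Illustrative only: the loop above implies a two-phase handshake with the main
// thread -- one barrier wait to release the workers, one to collect them after
// work_thread_work_section(). The sketch below models that handshake with
// std::barrier (C++20); it is a simplified stand-in, not the project's sync_barrier_.
#if 0  // illustrative sketch, excluded from compilation
#include <atomic>
#include <barrier>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  constexpr unsigned num_threads = 4;
  std::barrier sync_barrier(num_threads);
  std::atomic<bool> terminated{false};

  std::vector<std::thread> workers;
  for (unsigned i = 1; i < num_threads; i++) {
    workers.emplace_back([&] {
      while (true) {
        sync_barrier.arrive_and_wait();  // wait to be triggered
        if (terminated.load()) return;   // shutdown check
        std::printf("worker runs its work section\n");
        sync_barrier.arrive_and_wait();  // sync back with the main thread
      }
    });
  }

  // "Main thread" side: release the workers, take part in the work, collect them.
  sync_barrier.arrive_and_wait();
  std::printf("main thread runs its work section\n");
  sync_barrier.arrive_and_wait();

  // Shutdown, mirroring terminate(): set the flag, release the workers once more, join.
  terminated.store(true);
  sync_barrier.arrive_and_wait();
  for (auto &t : workers) t.join();
  return 0;
}
#endif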

void scheduler::work_thread_work_section() {
  thread_state &my_state = thread_state::get();
  my_state.set_scheduler_active(true);
  unsigned const num_threads = my_state.get_scheduler().num_threads();

  if (my_state.get_thread_id() == 0) {
    // Main thread: kick off by executing the user's main code block.
    main_thread_starter_function_->run();
  }

  unsigned int failed_steals = 0;
  while (!work_section_done_) {
#if PLS_PROFILING_ENABLED
    my_state.get_scheduler().profiler_.stealing_start(my_state.get_thread_id());
#endif
    PLS_ASSERT(check_task_chain(*my_state.get_active_task()), "Must start stealing with a clean task chain.");

    size_t target;
    do {
      target = my_state.get_rand() % num_threads;
    } while (target == my_state.get_thread_id());

    thread_state &target_state = my_state.get_scheduler().thread_state_for(target);
#if PLS_SLEEP_WORKERS_ON_EMPTY
    queue_empty_flag_retry_steal:
    // TODO: relax atomics for empty flag
    data_structures::stamped_integer target_queue_empty_flag = target_state.get_queue_empty_flag().load();
#endif

    base_task *stolen_task;
    base_task *chain_after_stolen_task;
    bool cas_success;
    do {
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_cas_op(my_state.get_thread_id());
#endif
      std::tie(stolen_task, chain_after_stolen_task, cas_success) =
          target_state.get_task_manager().steal_task(my_state);
    } while (!cas_success); // retry on CAS conflicts, as in classic work stealing

    if (stolen_task) {
      // Keep task chain consistent. We want to appear as if we are working on a branch upwards of the stolen task.
      stolen_task->next_ = chain_after_stolen_task;
      chain_after_stolen_task->prev_ = stolen_task;
      my_state.set_active_task(stolen_task);

      PLS_ASSERT(check_task_chain_forward(*my_state.get_active_task()),
                 "We are sole owner of this chain, it has to be valid!");

      // Execute the stolen task by jumping to its continuation.
      PLS_ASSERT(stolen_task->continuation_.valid(),
                 "A task that we can steal must have a valid continuation for us to start working.");
      stolen_task->is_synchronized_ = false;
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), true);
      my_state.get_scheduler().profiler_.task_start_running(my_state.get_thread_id(),
                                                            stolen_task->profiling_node_);
#endif
      context_switcher::switch_context(std::move(stolen_task->continuation_));
      // Execution continues here once the stolen work is finished.
      failed_steals = 0;
    } else {
      // TODO: tune value for when we start yielding
      const unsigned YIELD_AFTER = 1;

      failed_steals++;
      if (failed_steals >= YIELD_AFTER) {
        std::this_thread::yield();
#if PLS_PROFILING_ENABLED
        my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), false);
#endif
#if PLS_SLEEP_WORKERS_ON_EMPTY
        switch (target_queue_empty_flag.value) {
          case EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY: {
            // We found the queue empty, but the flag says it should still be full.
            // We want to declare it empty, but we need to re-check the queue in an intermediate step to avoid races.
            data_structures::stamped_integer
                maybe_empty_flag{target_queue_empty_flag.stamp + 1, EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY};
            if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag,
                                                                            maybe_empty_flag)) {
              goto queue_empty_flag_retry_steal;
            }
            break;
          }
          case EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY: {
            // We found the queue empty and it was already marked as maybe empty.
            // We can safely mark it empty and increment the central counter.
            data_structures::stamped_integer
                empty_flag{target_queue_empty_flag.stamp + 1, EMPTY_QUEUE_STATE::QUEUE_EMPTY};
            if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag, empty_flag)) {
              // We marked it empty, so it is now our duty to update the central counter.
              my_state.get_scheduler().empty_queue_increase_counter();
            }
            break;
          }
          case EMPTY_QUEUE_STATE::QUEUE_EMPTY: {
            // The queue was already marked empty, just do nothing
            break;
          }
        }
        // Regardless of whether we found the queue empty, check if we can put ourselves to sleep.
        my_state.get_scheduler().empty_queue_try_sleep_worker();
#endif
      }
    }
  }
  my_state.set_scheduler_active(false);
}
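
// Illustrative only: the PLS_SLEEP_WORKERS_ON_EMPTY branch above drives a small
// state machine (QUEUE_NON_EMPTY -> QUEUE_MAYBE_EMPTY -> QUEUE_EMPTY) through a
// stamped compare-and-swap, so a queue that is refilled between two observations
// (an ABA situation) bumps the stamp and makes the stale CAS fail. The sketch
// below is a standalone model of that idea with a simplified stamped type; it is
// not the project's data_structures::stamped_integer.
#if 0  // illustrative sketch, excluded from compilation
#include <atomic>
#include <cstdint>
#include <cstdio>

enum queue_state : std::uint32_t { QUEUE_NON_EMPTY = 0, QUEUE_MAYBE_EMPTY = 1, QUEUE_EMPTY = 2 };

struct stamped {
  std::uint32_t stamp;
  std::uint32_t value;
};

int main() {
  std::atomic<stamped> flag{{0, QUEUE_NON_EMPTY}};

  // A thief observed the queue empty: try to move NON_EMPTY -> MAYBE_EMPTY.
  stamped observed = flag.load();
  stamped maybe_empty{observed.stamp + 1, QUEUE_MAYBE_EMPTY};
  bool first_step = flag.compare_exchange_strong(observed, maybe_empty);

  // If the owner pushed work in between and bumped the stamp, the CAS above
  // fails and the thief re-checks the queue instead of declaring it empty.
  std::printf("first step %s\n", first_step ? "succeeded" : "lost a race, re-check the queue");

  // A second empty observation moves MAYBE_EMPTY -> EMPTY; on success the
  // scheduler's central empty-queue counter would be incremented.
  observed = flag.load();
  if (observed.value == QUEUE_MAYBE_EMPTY) {
    stamped empty{observed.stamp + 1, QUEUE_EMPTY};
    if (flag.compare_exchange_strong(observed, empty)) {
      std::printf("queue marked empty, increment central counter\n");
    }
  }
  return 0;
}
#endif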

void scheduler::sync_internal() {
  if (thread_state::is_scheduler_active()) {
    thread_state &syncing_state = thread_state::get();

    base_task *active_task = syncing_state.get_active_task();
    base_task *spawned_task = active_task->next_;

#if PLS_PROFILING_ENABLED
    syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                      active_task->stack_memory_,
                                                                      active_task->stack_size_,
                                                                      active_task->profiling_node_);
    syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                              active_task->profiling_node_);
    auto *next_dag_node =
        syncing_state.get_scheduler().profiler_.task_sync(syncing_state.get_thread_id(), active_task->profiling_node_);
    active_task->profiling_node_ = next_dag_node;
#endif

    if (active_task->is_synchronized_) {
#if PLS_PROFILING_ENABLED
      thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                       thread_state::get().get_active_task()->profiling_node_);
#endif
      return; // We are already the sole owner of last_task
    } else {
      auto continuation =
          spawned_task->run_as_task([active_task, spawned_task, &syncing_state](context_switcher::continuation cont) {
            active_task->continuation_ = std::move(cont);
            syncing_state.set_active_task(spawned_task);
            return slow_return(syncing_state);
          });

      PLS_ASSERT(!continuation.valid(),
                 "We only return to a sync point, never jump to it directly."
                 "This must therefore never return an unfinished fiber/continuation.");
#if PLS_PROFILING_ENABLED
      thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                       thread_state::get().get_active_task()->profiling_node_);
#endif
      return; // We cleanly synced to the last one finishing work on last_task
    }
  } else {
    // Scheduler not active
    return;
  }
}

context_switcher::continuation scheduler::slow_return(thread_state &calling_state) {
  base_task *this_task = calling_state.get_active_task();
  PLS_ASSERT(this_task->depth_ > 0,
             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");
  base_task *last_task = this_task->prev_;

  // A slow return means we are finishing the child 'this_task' of 'last_task'
  // without knowing whether we are the last child to finish.
  // If we are not the last one, we take a spare task chain for our resources and can return to the main scheduling loop.
  base_task *pop_result = calling_state.get_task_manager().pop_clean_task_chain(last_task);

  if (pop_result != nullptr) {
    base_task *clean_chain = pop_result;

    // We got a clean chain to fill up our resources.
    PLS_ASSERT(last_task->depth_ == clean_chain->depth_,
               "Resources must only reside in the correct depth!");
    PLS_ASSERT(last_task != clean_chain,
               "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
    PLS_ASSERT(check_task_chain_backward(*clean_chain),
               "Can only acquire clean chains for clean returns!");

    // Acquire it/merge it with our task chain.
    this_task->prev_ = clean_chain;
    clean_chain->next_ = this_task;

    base_task *active_task = clean_chain;
    while (active_task->depth_ > 0) {
      active_task = active_task->prev_;
    }
    calling_state.set_active_task(active_task);

    // Jump back to the continuation in main scheduling loop.
    context_switcher::continuation result_cont = std::move(thread_state::get().main_continuation());
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
    return result_cont;
  } else {
    // Make sure that we are owner of this full continuation/task chain.
    last_task->next_ = this_task;

    // We are the last one working on this task. Thus the sync must be finished, continue working.
    calling_state.set_active_task(last_task);
    last_task->is_synchronized_ = true;

    // Jump to parent task and continue working on it.
    context_switcher::continuation result_cont = std::move(last_task->continuation_);
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
    return result_cont;
  }
}
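
// Illustrative only: the clean-chain branch above splices the acquired chain in
// front of 'this_task' and then walks back to depth 0 to find the new active
// task. The sketch below models that pointer surgery on a simplified node type
// (not the real base_task), so the doubly-linked invariant checked by
// check_task_chain stays visible.
#if 0  // illustrative sketch, excluded from compilation
#include <array>
#include <cassert>
#include <cstddef>

struct node {
  std::size_t depth;
  node *prev = nullptr;
  node *next = nullptr;
};

int main() {
  // The task we are returning from, at depth 2.
  node this_task{2};

  // A clean chain handed back by the task manager, ending at the parent's depth (1).
  std::array<node, 2> clean{{{0}, {1}}};
  clean[0].next = &clean[1];
  clean[1].prev = &clean[0];
  node *clean_chain = &clean[1];

  // Splice: the clean chain replaces everything above this_task.
  this_task.prev = clean_chain;
  clean_chain->next = &this_task;

  // Walk up to depth 0 to find the new active task, as slow_return does.
  node *active = clean_chain;
  while (active->depth > 0) active = active->prev;
  assert(active == &clean[0]);

  // The invariant checked by check_task_chain: next/prev pointers agree.
  for (node *cur = active; cur->next != nullptr; cur = cur->next) {
    assert(cur->next->prev == cur);
  }
  return 0;
}
#endif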

base_task &scheduler::task_chain_at(unsigned int depth, thread_state &calling_state) {
  // TODO: possibly optimize with a cache array updated at steal events
  base_task *result = calling_state.get_active_task();
  while (result->depth_ > depth) {
    result = result->prev_;
  }
  while (result->depth_ < depth) {
    result = result->next_;
  }

  return *result;
}

void scheduler::terminate() {
  if (terminated_) {
    return;
  }

  terminated_ = true;
  sync_barrier_.wait();

  for (unsigned int i = 0; i < num_threads_; i++) {
    if (reuse_thread_ && i == 0) {
      continue;
    }
    worker_threads_[i].join();
  }
}

bool scheduler::check_task_chain_forward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->next_) {
    if (current->next_->prev_ != current) {
      return false;
    }
    current = current->next_;
  }
  return true;
}

bool scheduler::check_task_chain_backward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->prev_) {
    if (current->prev_->next_ != current) {
      return false;
    }
    current = current->prev_;
  }
  return true;
}

bool scheduler::check_task_chain(base_task &start_task) {
  return check_task_chain_backward(start_task) && check_task_chain_forward(start_task);
}

#if PLS_SLEEP_WORKERS_ON_EMPTY
// TODO: relax memory orderings

void scheduler::empty_queue_try_sleep_worker() {
  int32_t counter_value = empty_queue_counter_.load();
  if (counter_value == num_threads()) {
#if PLS_PROFILING_ENABLED
    get_profiler().sleep_start(thread_state::get().get_thread_id());
#endif
    threads_sleeping_++;
    base::futex_wait((int32_t *) &empty_queue_counter_, num_threads());
    threads_sleeping_--;
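    // Chain the wakeup: presumably each woken worker passes the wakeup on to one
    // more sleeper, so a single external wake eventually releases all sleeping workers.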
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
#if PLS_PROFILING_ENABLED
    get_profiler().sleep_stop(thread_state::get().get_thread_id());
#endif
  }
}

void scheduler::empty_queue_increase_counter() {
  empty_queue_counter_.fetch_add(1);
}

void scheduler::empty_queue_decrease_counter_and_wake() {
  empty_queue_counter_.fetch_sub(1);
  if (threads_sleeping_.load() > 0) {
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
  }
}
#endif
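
// Illustrative only: the sleeping code above relies on base::futex_wait and
// base::futex_wakeup, whose implementation lives in futex_wrapper.h and is not
// shown in this file. The Linux-only sketch below shows the assumed semantics:
// wait blocks while the word still holds the expected value, wakeup releases up
// to 'count' waiters. The names, flags, and behavior are assumptions, not the
// project's code.
#if 0  // illustrative sketch, excluded from compilation
#include <linux/futex.h>
#include <sys/syscall.h>
#include <unistd.h>
#include <cstdint>

namespace sketch {
// Block until the futex word changes (or a spurious wakeup occurs), but only if
// it still equals 'expected' when the kernel checks it.
inline void futex_wait(std::int32_t *word, std::int32_t expected) {
  syscall(SYS_futex, word, FUTEX_WAIT_PRIVATE, expected, nullptr, nullptr, 0);
}

// Wake up to 'count' threads currently blocked on the futex word.
inline void futex_wakeup(std::int32_t *word, std::int32_t count) {
  syscall(SYS_futex, word, FUTEX_WAKE_PRIVATE, count, nullptr, nullptr, 0);
}
}  // namespace sketch
#endif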

} // namespace pls::internal::scheduling