#include "pls/internal/scheduling/scheduler.h"
2

3
#include "context_switcher/context_switcher.h"
4

5
#include "pls/internal/scheduling/strain_local_resource.h"
6
#include "pls/internal/build_flavour.h"
7
#include "pls/internal/base/error_handling.h"
8
#include "pls/internal/base/futex_wrapper.h"
9

10 11
#include <thread>

12
namespace pls::internal::scheduling {
13

scheduler::scheduler(unsigned int num_threads,
                     size_t computation_depth,
                     size_t stack_size,
                     bool reuse_thread,
                     size_t serial_stack_size) : scheduler(num_threads,
                                                           computation_depth,
                                                           stack_size,
                                                           reuse_thread,
                                                           serial_stack_size,
                                                           base::mmap_stack_allocator{}) {}

scheduler::~scheduler() {
  terminate();
}
void scheduler::work_thread_main_loop() {
  auto &scheduler = thread_state::get().get_scheduler();
  while (true) {
    // Wait to be triggered
    scheduler.sync_barrier_.wait();

    // Check for shutdown
    if (scheduler.terminated_) {
      return;
    }

    scheduler.work_thread_work_section();

    // Sync back with main thread
    scheduler.sync_barrier_.wait();
  }
}

void scheduler::work_thread_work_section() {
  thread_state &my_state = thread_state::get();
  my_state.set_scheduler_active(true);
  unsigned const num_threads = my_state.get_scheduler().num_threads();

  if (my_state.get_thread_id() == 0) {
    // Main thread: kick off the work section by executing the user's main code block.
    main_thread_starter_function_->run();
  }

  unsigned int failed_steals = 0;
  while (!work_section_done_) {
#if PLS_PROFILING_ENABLED
    my_state.get_scheduler().profiler_.stealing_start(my_state.get_thread_id());
#endif
    PLS_ASSERT_EXPENSIVE(check_task_chain(*my_state.get_active_task()),
                         "Must start stealing with a clean task chain.");

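    // Pick a random victim thread other than ourselves to steal from.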
    size_t target;
    do {
      target = my_state.get_rand() % num_threads;
    } while (target == my_state.get_thread_id());

    thread_state &target_state = my_state.get_scheduler().thread_state_for(target);
#if PLS_SLEEP_WORKERS_ON_EMPTY
    queue_empty_flag_retry_steal:
    data_structures::stamped_integer
        target_queue_empty_flag = target_state.get_queue_empty_flag().load(std::memory_order_relaxed);
#endif

    base_task *stolen_task;
    base_task *chain_after_stolen_task;
    bool cas_success;
    do {
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_cas_op(my_state.get_thread_id());
#endif
      std::tie(stolen_task, chain_after_stolen_task, cas_success) =
          target_state.get_task_manager().steal_task(my_state);
    } while (!cas_success); // re-try on cas-conflicts, as in classic work stealing

    if (stolen_task) {
      // Keep the task chain consistent. We want to appear as if we are working on a branch upwards of the stolen task.
      stolen_task->next_ = chain_after_stolen_task;
      chain_after_stolen_task->prev_ = stolen_task;
      my_state.set_active_task(stolen_task);
      // Keep locally owned resources consistent.
      auto *stolen_resources = stolen_task->attached_resources_.load(std::memory_order_relaxed);
      strain_local_resource::acquire_locally(stolen_resources, my_state.get_thread_id());

      PLS_ASSERT_EXPENSIVE(check_task_chain_forward(*my_state.get_active_task()),
                           "We are sole owner of this chain, it has to be valid!");

      // Execute the stolen task by jumping to its continuation.
      PLS_ASSERT(stolen_task->continuation_.valid(),
                 "A task that we can steal must have a valid continuation for us to start working.");
      stolen_task->is_synchronized_ = false;
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), true);
      my_state.get_scheduler().profiler_.task_start_running(my_state.get_thread_id(),
                                                            stolen_task->profiling_node_);
#endif
      context_switcher::switch_context(std::move(stolen_task->continuation_));
      // We will continue execution at this line once the stolen work is finished.
      failed_steals = 0;
    } else {
      // TODO: tune value for when we start yielding
      const unsigned YIELD_AFTER = 1;

      failed_steals++;
      if (failed_steals >= YIELD_AFTER) {
        std::this_thread::yield();
#if PLS_PROFILING_ENABLED
        my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), false);
#endif
#if PLS_SLEEP_WORKERS_ON_EMPTY
        switch (target_queue_empty_flag.value_) {
          case EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY: {
            // We found the queue empty, but the flag says it should still be full.
            // We want to declare it empty, but we need to re-check the queue in a sub-step to avoid races.
            data_structures::stamped_integer
                maybe_empty_flag{target_queue_empty_flag.stamp_ + 1, EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY};
            if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag,
                                                                            maybe_empty_flag,
                                                                            std::memory_order_acq_rel)) {
              goto queue_empty_flag_retry_steal;
            }
            break;
          }
          case EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY: {
            // We found the queue empty and it was already marked as maybe empty.
            // We can safely mark it empty and increment the central counter.
            data_structures::stamped_integer
                empty_flag{target_queue_empty_flag.stamp_ + 1, EMPTY_QUEUE_STATE::QUEUE_EMPTY};
            if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag,
                                                                            empty_flag,
                                                                            std::memory_order_acq_rel)) {
              // We marked it empty, so it is now our duty to update the central counter.
              my_state.get_scheduler().empty_queue_increase_counter();
            }
            break;
          }
          case EMPTY_QUEUE_STATE::QUEUE_EMPTY: {
            // The queue was already marked empty; there is nothing to do.
            break;
          }
        }
        // Regardless of whether we found the queue empty, check if we can put ourselves to sleep.
        my_state.get_scheduler().empty_queue_try_sleep_worker();
#endif
      }
    }
  }
  my_state.set_scheduler_active(false);
}

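// Synchronizes the active task with its spawned child (the next task in the chain).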
void scheduler::sync_internal() {
  thread_state &syncing_state = thread_state::get();

  base_task *active_task = syncing_state.get_active_task();
  base_task *spawned_task = active_task->next_;

#if PLS_PROFILING_ENABLED
  syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                    active_task->stack_memory_,
                                                                    active_task->stack_size_,
                                                                    active_task->profiling_node_);
  syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                            active_task->profiling_node_);
  auto *next_dag_node =
      syncing_state.get_scheduler().profiler_.task_sync(syncing_state.get_thread_id(), active_task->profiling_node_);
  active_task->profiling_node_ = next_dag_node;
#endif

  if (active_task->is_synchronized_) {
#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     thread_state::get().get_active_task()->profiling_node_);
#endif
    return; // We are already the sole owner of last_task
  } else {
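    // Not synchronized: save our continuation on the active task, switch onto the spawned
    // task's stack and let slow_return decide whether we are the last child to finish.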
    auto continuation =
        spawned_task->run_as_task([active_task, spawned_task, &syncing_state](context_switcher::continuation cont) {
          active_task->continuation_ = std::move(cont);
          syncing_state.set_active_task(spawned_task);
          return slow_return(syncing_state, true);
        });

    PLS_ASSERT(!continuation.valid(),
               "We only return to a sync point, never jump to it directly."
               "This must therefore never return an unfinished fiber/continuation.");
#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     thread_state::get().get_active_task()->profiling_node_);
#endif
    return; // We cleanly synced to the last one finishing work on last_task
  }
}

context_switcher::continuation scheduler::slow_return(thread_state &calling_state, bool in_sync) {
  base_task *this_task = calling_state.get_active_task();
  PLS_ASSERT(this_task->depth_ > 0,
             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");
  base_task *last_task = this_task->prev_;

  // A slow return means we are finishing the child 'this_task' of 'last_task' without
  // knowing whether we are the last child to finish.
  // If we are not the last one, we take a spare task chain for our resources and can return to the main scheduling loop.
  base_task *pop_result = calling_state.get_task_manager().pop_clean_task_chain(last_task);

  if (pop_result != nullptr) {
    base_task *clean_chain = pop_result;

    // We got a clean chain to fill up our resources.
    PLS_ASSERT(last_task->depth_ == clean_chain->depth_,
               "Resources must only reside in the correct depth!");
    PLS_ASSERT(last_task != clean_chain,
               "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
    PLS_ASSERT_EXPENSIVE(check_task_chain_backward(*clean_chain),
                         "Can only acquire clean chains for clean returns!");

    // Acquire it/merge it with our task chain.
    this_task->prev_ = clean_chain;
    clean_chain->next_ = this_task;

    // Keep locally owned resources consistent.
    auto *clean_resources = clean_chain->attached_resources_.load(std::memory_order_relaxed);
    strain_local_resource::acquire_locally(clean_resources, calling_state.get_thread_id());

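    // Walk the clean chain up to its root task (depth 0) and make that our new active task.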
    base_task *active_task = clean_chain;
    while (active_task->depth_ > 0) {
      active_task = active_task->prev_;
    }
    calling_state.set_active_task(active_task);

    // Jump back to the continuation in main scheduling loop.
    context_switcher::continuation result_cont = std::move(thread_state::get().main_continuation());
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation (to main).");
    return result_cont;
  } else {
    // Make sure that we are owner of this full continuation/task chain.
    last_task->next_ = this_task;

    // We are the last one working on this task. Thus the sync is finished; continue working.
    calling_state.set_active_task(last_task);
    last_task->is_synchronized_ = true;

    // Jump to parent task and continue working on it.
    if (in_sync) {
      PLS_ASSERT(last_task->continuation_.valid(), "Must return a valid continuation (to last task) in sync.");
    } else {
      PLS_ASSERT(last_task->continuation_.valid(), "Must return a valid continuation (to last task) in spawn.");
    }

    context_switcher::continuation result_cont = std::move(last_task->continuation_);
    return result_cont;
  }
}

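// Selects the task from our own chain that we trade to the victim in exchange for the stolen task.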
base_task *scheduler::get_trade_task(base_task *stolen_task, thread_state &calling_state) {
  // Walk our own chain to the task at the same depth as the stolen task.
  base_task *result = calling_state.get_active_task();
  while (result->depth_ > stolen_task->depth_) {
    result = result->prev_;
  }
  while (result->depth_ < stolen_task->depth_) {
    result = result->next_;
  }

  // Attach other resources we need to trade to it
  auto *stolen_resources = stolen_task->attached_resources_.load(std::memory_order_relaxed);
  auto *traded_resources = strain_local_resource::get_local_copy(stolen_resources, calling_state.get_thread_id());
  result->attached_resources_.store(traded_resources, std::memory_order_relaxed);

  return result;
}

void scheduler::terminate() {
  if (terminated_) {
    return;
  }

  terminated_ = true;
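  // Release the workers from the barrier so they observe terminated_ and exit their main loop.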
  sync_barrier_.wait();

  for (unsigned int i = 0; i < num_threads_; i++) {
    if (reuse_thread_ && i == 0) {
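      // With reuse_thread_, thread 0 runs on the calling thread itself, so there is no worker thread to join.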
      continue;
    }
    worker_threads_[i].join();
  }
}

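// Debug helpers: verify that the doubly linked task chain is consistent in both directions.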
bool scheduler::check_task_chain_forward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->next_) {
    if (current->next_->prev_ != current) {
      return false;
    }
    current = current->next_;
  }
  return true;
}

bool scheduler::check_task_chain_backward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->prev_) {
    if (current->prev_->next_ != current) {
      return false;
    }
    current = current->prev_;
  }
  return true;
}

bool scheduler::check_task_chain(base_task &start_task) {
  return check_task_chain_backward(start_task) && check_task_chain_forward(start_task);
}

#if PLS_SLEEP_WORKERS_ON_EMPTY
// TODO: relax memory orderings

void scheduler::empty_queue_try_sleep_worker() {
  int32_t counter_value = empty_queue_counter_.load();
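  // Only sleep once every worker queue has been flagged empty (the counter then equals the thread count).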
  if (counter_value == num_threads()) {
#if PLS_PROFILING_ENABLED
    get_profiler().sleep_start(thread_state::get().get_thread_id());
#endif
    threads_sleeping_++;
    base::futex_wait((int32_t *) &empty_queue_counter_, num_threads());
    threads_sleeping_--;
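    // Chain the wake-up: release one more sleeping worker so all sleepers eventually wake.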
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
#if PLS_PROFILING_ENABLED
    get_profiler().sleep_stop(thread_state::get().get_thread_id());
#endif
  }
}

void scheduler::empty_queue_increase_counter() {
  empty_queue_counter_.fetch_add(1);
}

void scheduler::empty_queue_decrease_counter_and_wake() {
  empty_queue_counter_.fetch_sub(1);
  if (threads_sleeping_.load() > 0) {
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
  }
}
#endif

} // namespace pls::internal::scheduling