#include "pls/internal/scheduling/scheduler.h"

#include "context_switcher/context_switcher.h"

#include "pls/internal/build_flavour.h"
#include "pls/internal/base/futex_wrapper.h"
#include "pls/internal/base/error_handling.h"

#include <thread>

namespace pls::internal::scheduling {

scheduler::scheduler(unsigned int num_threads,
                     size_t computation_depth,
                     size_t stack_size,
                     bool reuse_thread) : scheduler(num_threads,
                                                    computation_depth,
                                                    stack_size,
                                                    reuse_thread,
                                                    base::mmap_stack_allocator{}) {}
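
// A minimal usage sketch, assuming the public entry point declared in scheduler.h
// (perform_work() plus the spawn()/sync() helpers; the exact names and default
// arguments are illustrative, not verified here):
//
//   scheduler my_scheduler{8, /*computation_depth=*/64, /*stack_size=*/1u << 16, true};
//   my_scheduler.perform_work([] {
//     // spawn child tasks here and sync() them before returning
//   });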

scheduler::~scheduler() {
  terminate();
}
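
// Entry point for every worker thread: block on the sync barrier until the main thread
// triggers a work section, run it, then sync back through the barrier. When terminated_
// is set the thread leaves the loop and exits.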
void scheduler::work_thread_main_loop() {
  auto &scheduler = thread_state::get().get_scheduler();
  while (true) {
    // Wait to be triggered
    scheduler.sync_barrier_.wait();

    // Check for shutdown
    if (scheduler.terminated_) {
      return;
    }

    scheduler.work_thread_work_section();

    // Sync back with main thread
    scheduler.sync_barrier_.wait();
  }
}

void scheduler::work_thread_work_section() {
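  // Thread 0 kicks off the parallel section by running the user's main code block;
  // afterwards all threads (including thread 0) steal from random victims until the
  // section is flagged as done.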
  thread_state &my_state = thread_state::get();
  unsigned const num_threads = my_state.get_scheduler().num_threads();

  if (my_state.get_thread_id() == 0) {
    // Main thread: kick things off by executing the user's main code block.
    main_thread_starter_function_->run();
  }

  unsigned int failed_steals = 0;
  while (!work_section_done_) {
#if PLS_PROFILING_ENABLED
    my_state.get_scheduler().profiler_.stealing_start(my_state.get_thread_id());
#endif
    PLS_ASSERT(check_task_chain(*my_state.get_active_task()), "Must start stealing with a clean task chain.");

    size_t target;
    do {
      target = my_state.get_rand() % num_threads;
    } while (target == my_state.get_thread_id());

    thread_state &target_state = my_state.get_scheduler().thread_state_for(target);
#if PLS_SLEEP_WORKERS_ON_EMPTY
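    // Snapshot the victim's stamped empty flag before the steal attempt; if the steal
    // fails we use this snapshot further below to advance the flag towards QUEUE_EMPTY.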
    queue_empty_flag_retry_steal:
    // TODO: relax atomics for empty flag
    data_structures::stamped_integer target_queue_empty_flag = target_state.get_queue_empty_flag().load();
#endif

    base_task *stolen_task;
    base_task *chain_after_stolen_task;
    bool cas_success;
    do {
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_cas_op(my_state.get_thread_id());
#endif
      std::tie(stolen_task, chain_after_stolen_task, cas_success) =
          target_state.get_task_manager().steal_task(my_state);
    } while (!cas_success); // re-try on cas-conflicts, as in classic ws

    if (stolen_task) {
      // Keep the task chain consistent. We want to appear as if we are working on a branch upwards of the stolen task.
      stolen_task->next_ = chain_after_stolen_task;
      chain_after_stolen_task->prev_ = stolen_task;
      my_state.set_active_task(stolen_task);

      PLS_ASSERT(check_task_chain_forward(*my_state.get_active_task()),
                 "We are sole owner of this chain, it has to be valid!");

      // Execute the stolen task by jumping to its continuation.
      PLS_ASSERT(stolen_task->continuation_.valid(),
                 "A task that we can steal must have a valid continuation for us to start working.");
      stolen_task->is_synchronized_ = false;
#if PLS_PROFILING_ENABLED
      my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), true);
      my_state.get_scheduler().profiler_.task_start_running(my_state.get_thread_id(),
                                                            stolen_task->profiling_node_);
#endif
      context_switcher::switch_context(std::move(stolen_task->continuation_));
      // Execution continues on this line once the stolen work is finished.
      failed_steals = 0;
    } else {
      // TODO: tune value for when we start yielding
      const unsigned YIELD_AFTER = 1;

      failed_steals++;
      if (failed_steals >= YIELD_AFTER) {
        std::this_thread::yield();
#if PLS_PROFILING_ENABLED
        my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), false);
#endif
#if PLS_SLEEP_WORKERS_ON_EMPTY
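        // Two-step empty-flag protocol: a steal that fails against QUEUE_NON_EMPTY only
        // advances the flag to QUEUE_MAYBE_EMPTY and retries; only a failure against an
        // already QUEUE_MAYBE_EMPTY flag marks the queue empty and bumps the central
        // counter. The stamp protects the CAS against concurrent flag updates.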
        switch (target_queue_empty_flag.value) {
          case EMPTY_QUEUE_STATE::QUEUE_NON_EMPTY: {
            // We found the queue empty, but the flag says it should still be full.
            // We want to declare it empty, but we need to re-check the queue in a sub-step to avoid races.
            data_structures::stamped_integer
                maybe_empty_flag{target_queue_empty_flag.stamp + 1, EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY};
            if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag,
                                                                            maybe_empty_flag)) {
              goto queue_empty_flag_retry_steal;
            }
            break;
          }
          case EMPTY_QUEUE_STATE::QUEUE_MAYBE_EMPTY: {
            // We found the queue empty and it was already marked as maybe empty.
            // We can safely mark it empty and increment the central counter.
            data_structures::stamped_integer
                empty_flag{target_queue_empty_flag.stamp + 1, EMPTY_QUEUE_STATE::QUEUE_EMPTY};
            if (target_state.get_queue_empty_flag().compare_exchange_strong(target_queue_empty_flag, empty_flag)) {
              // We marked it empty, so it is now our duty to update the central counter.
              my_state.get_scheduler().empty_queue_increase_counter();
            }
            break;
          }
          case EMPTY_QUEUE_STATE::QUEUE_EMPTY: {
            // The queue was already marked empty, just do nothing
            break;
          }
        }
        // Regardless of whether we found the queue empty, check if we can put ourselves to sleep.
        my_state.get_scheduler().empty_queue_try_sleep_worker();
#endif
      }
    }
  }
}

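// Sync point between the active task and the task it spawned (active_task->next_).
// Fast path: the active task is still synchronized, so we are the sole owner and can
// simply continue. Slow path: switch into the spawned task's continuation and let
// slow_return decide which thread gets to finish the parent task.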
void scheduler::sync() {
  thread_state &syncing_state = thread_state::get();

  base_task *active_task = syncing_state.get_active_task();
  base_task *spawned_task = active_task->next_;

#if PLS_PROFILING_ENABLED
  syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
                                                                    active_task->stack_memory_,
                                                                    active_task->stack_size_,
                                                                    active_task->profiling_node_);
  syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
                                                            active_task->profiling_node_);
  auto *next_dag_node =
      syncing_state.get_scheduler().profiler_.task_sync(syncing_state.get_thread_id(), active_task->profiling_node_);
  active_task->profiling_node_ = next_dag_node;
#endif

  if (active_task->is_synchronized_) {
#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     thread_state::get().get_active_task()->profiling_node_);
#endif
    return; // We are already the sole owner of last_task
  } else {
    auto continuation =
        spawned_task->run_as_task([active_task, spawned_task, &syncing_state](context_switcher::continuation cont) {
          active_task->continuation_ = std::move(cont);
          syncing_state.set_active_task(spawned_task);
          return slow_return(syncing_state);
        });

    PLS_ASSERT(!continuation.valid(),
               "We only return to a sync point, never jump to it directly."
               "This must therefore never return an unfinished fiber/continuation.");
#if PLS_PROFILING_ENABLED
    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
                                                                     thread_state::get().get_active_task()->profiling_node_);
#endif
    return; // We cleanly synced to the last one finishing work on last_task
  }
}

context_switcher::continuation scheduler::slow_return(thread_state &calling_state) {
  base_task *this_task = calling_state.get_active_task();
  PLS_ASSERT(this_task->depth_ > 0,
             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");
  base_task *last_task = this_task->prev_;

  // Slow return means we try to finish the child 'this_task' of 'last_task' and we
  // do not know if we are the last child to finish.
  // If we are not the last one, we get a spare task chain for our resources and can return to the main scheduling loop.
  base_task *pop_result = calling_state.get_task_manager().pop_clean_task_chain(last_task);

  if (pop_result != nullptr) {
    base_task *clean_chain = pop_result;

    // We got a clean chain to fill up our resources.
    PLS_ASSERT(last_task->depth_ == clean_chain->depth_,
               "Resources must only reside in the correct depth!");
    PLS_ASSERT(last_task != clean_chain,
               "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
    PLS_ASSERT(check_task_chain_backward(*clean_chain),
               "Can only acquire clean chains for clean returns!");

    // Acquire it/merge it with our task chain.
    this_task->prev_ = clean_chain;
    clean_chain->next_ = this_task;

    base_task *active_task = clean_chain;
    while (active_task->depth_ > 0) {
      active_task = active_task->prev_;
    }
    calling_state.set_active_task(active_task);

    // Jump back to the continuation in main scheduling loop.
    context_switcher::continuation result_cont = std::move(thread_state::get().main_continuation());
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
    return result_cont;
  } else {
    // Make sure that we are the owner of this full continuation/task chain.
    last_task->next_ = this_task;

    // We are the last child working on this task, so the sync is complete and we can continue working on the parent.
    calling_state.set_active_task(last_task);
    last_task->is_synchronized_ = true;

    // Jump to parent task and continue working on it.
    context_switcher::continuation result_cont = std::move(last_task->continuation_);
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
    return result_cont;
  }
}

base_task &scheduler::task_chain_at(unsigned int depth, thread_state &calling_state) {
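  // Walk the doubly linked task chain from the calling thread's active task to the
  // requested depth (upwards via prev_, downwards via next_).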
  // TODO: possibly optimize with a cache array updated at steal events
  base_task *result = calling_state.get_active_task();
  while (result->depth_ > depth) {
    result = result->prev_;
  }
  while (result->depth_ < depth) {
    result = result->next_;
  }

  return *result;
}

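// Shut the scheduler down: set terminated_ and release the workers through the sync
// barrier so they exit their main loop, then join all spawned threads (thread 0 is
// skipped when it is the reused calling thread).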
void scheduler::terminate() {
  if (terminated_) {
    return;
  }

  terminated_ = true;
  sync_barrier_.wait();

  for (unsigned int i = 0; i < num_threads_; i++) {
    if (reuse_thread_ && i == 0) {
      continue;
    }
    worker_threads_[i].join();
  }
}

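// Debug helpers: verify that the doubly linked task chain is consistent, i.e. every
// next_/prev_ pair of neighbouring tasks points back at the other.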
bool scheduler::check_task_chain_forward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->next_) {
    if (current->next_->prev_ != current) {
      return false;
    }
    current = current->next_;
  }
  return true;
}

bool scheduler::check_task_chain_backward(base_task &start_task) {
  base_task *current = &start_task;
  while (current->prev_) {
    if (current->prev_->next_ != current) {
      return false;
    }
    current = current->prev_;
  }
  return true;
}

bool scheduler::check_task_chain(base_task &start_task) {
  return check_task_chain_backward(start_task) && check_task_chain_forward(start_task);
}

#if PLS_SLEEP_WORKERS_ON_EMPTY
// TODO: relax memory orderings

void scheduler::empty_queue_try_sleep_worker() {
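  // Only sleep once every other worker's queue has been flagged empty, i.e. once the
  // counter reaches num_threads() - 1. Wakeups are chained: after being woken, each
  // worker wakes one further sleeper before decrementing threads_sleeping_.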
  int32_t counter_value = empty_queue_counter_.load();
//  printf("try sleep thread %d...\n", counter_value);
  if (counter_value == num_threads() - 1) {
//    printf("sleeping thread...\n");
    threads_sleeping_++;
    base::futex_wait((int32_t *) &empty_queue_counter_, num_threads() - 1);
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
    threads_sleeping_--;
//    printf("waked thread...\n");
  }
}

void scheduler::empty_queue_increase_counter() {
  int32_t old_value = empty_queue_counter_.fetch_add(1);
//  printf("increase counter from %d", old_value);
}

void scheduler::empty_queue_decrease_counter_and_wake() {
  empty_queue_counter_.fetch_sub(1);
  if (threads_sleeping_.load() > 0) {
    // Threads could be sleeping, we MUST wake them up
    base::futex_wakeup((int32_t *) &empty_queue_counter_, 1);
  }
}
#endif

}