scheduler.cpp 7.29 KB
Newer Older
1
#include "pls/internal/scheduling/scheduler.h"
2

3
#include "context_switcher/context_switcher.h"
4
#include "pls/internal/base/error_handling.h"
5

6 7
#include <thread>

8
namespace pls::internal::scheduling {
9

10 11 12 13 14 15 16 17
/**
 * Convenience constructor that forwards all arguments to the allocator-taking constructor,
 * supplying a default-constructed base::mmap_stack_allocator as the stack allocator.
 *
 * @param num_threads       number of threads the scheduler manages
 * @param computation_depth forwarded unchanged to the delegated constructor
 * @param stack_size        forwarded unchanged to the delegated constructor
 * @param reuse_thread      forwarded unchanged; terminate() skips joining thread 0 when set
 */
scheduler::scheduler(unsigned int num_threads, size_t computation_depth, size_t stack_size, bool reuse_thread)
    : scheduler(num_threads, computation_depth, stack_size, reuse_thread, base::mmap_stack_allocator{}) {}

// Shuts the scheduler down through terminate(); terminate() returns immediately when the
// scheduler has already been terminated, so an earlier explicit call is harmless.
scheduler::~scheduler() {
  this->terminate();
}

/**
 * Round-based loop executed by worker threads (presumably the entry point handed to each
 * worker thread - confirm in the allocator-taking constructor).
 *
 * Each round consists of two rendezvous on sync_barrier_: the first releases the thread into a
 * new round, the second reports the round's completion. A round started after terminated_ was
 * set ends the loop instead of doing work.
 */
void scheduler::work_thread_main_loop() {
  auto &owner = thread_state::get().get_scheduler();

  for (;;) {
    // Block until the next round is triggered (new work or a shutdown request).
    owner.sync_barrier_.wait();

    // A shutdown request arrives as a regular trigger with terminated_ already set.
    if (owner.terminated_) {
      return;
    }

    owner.work_thread_work_section();

    // Rendezvous with the main thread to signal that this round is finished.
    owner.sync_barrier_.wait();
  }
}

void scheduler::work_thread_work_section() {
41 42
  thread_state &my_state = thread_state::get();
  unsigned const num_threads = my_state.get_scheduler().num_threads();
43

44
  if (my_state.get_thread_id() == 0) {
45
    // Main Thread, kick off by executing the user's main code block.
46
    main_thread_starter_function_->run();
47
  }
48

49
  unsigned int failed_steals = 0;
50
  while (!work_section_done_) {
51 52 53 54 55 56 57 58
    PLS_ASSERT(check_task_chain(*my_state.get_active_task()), "Must start stealing with a clean task chain.");

    size_t target;
    do {
      target = my_state.get_rand() % num_threads;
    } while (target == my_state.get_thread_id());

    thread_state &target_state = my_state.get_scheduler().thread_state_for(target);
59
    auto[stolen_task, traded_task] = target_state.get_task_manager().steal_task(my_state);
60
    if (stolen_task) {
61 62 63 64
      // Keep task chain consistent. We want to appear as if we are working an a branch upwards of the stolen task.
      base_task *next_own_task = traded_task->next_;
      stolen_task->next_ = next_own_task;
      next_own_task->prev_ = stolen_task;
65
      my_state.set_active_task(stolen_task);
66

67
      PLS_ASSERT(check_task_chain_forward(*my_state.get_active_task()),
68 69 70 71 72
                 "We are sole owner of this chain, it has to be valid!");

      // Execute the stolen task by jumping to it's continuation.
      PLS_ASSERT(stolen_task->continuation_.valid(),
                 "A task that we can steal must have a valid continuation for us to start working.");
73
      stolen_task->is_synchronized_ = false;
74 75
      context_switcher::switch_context(std::move(stolen_task->continuation_));
      // We will continue execution in this line when we finished the stolen work.
76 77 78 79
      failed_steals = 0;
    } else {
      failed_steals++;
      if (failed_steals >= num_threads) {
80
        std::this_thread::yield();
81
      }
82 83
    }
  }
84 85
}

86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
/**
 * Synchronizes the calling thread's active task with the child it spawned (active_task->next_).
 *
 * Fast path: if the active task is marked is_synchronized_, this thread already is the sole owner
 * of the task's chain and the call returns immediately.
 *
 * Slow path: switch onto the spawned child's task, park this function's continuation in
 * active_task->continuation_, and hand control to slow_return(). That function either resumes the
 * parked continuation (when this thread finishes last) or jumps back to the main scheduling loop.
 * Control therefore only ever comes back here as a finished continuation.
 */
void scheduler::sync() {
  thread_state &syncing_state = thread_state::get();

  base_task *active_task = syncing_state.get_active_task();
  base_task *spawned_task = active_task->next_;

  if (active_task->is_synchronized_) {
    return; // We are already the sole owner of active_task's chain.
  }

  auto continuation =
      spawned_task->run_as_task([active_task, spawned_task, &syncing_state](context_switcher::continuation cont) {
        active_task->continuation_ = std::move(cont);
        syncing_state.set_active_task(spawned_task);

        return slow_return(syncing_state);
      });

  // Adjacent string literals are concatenated verbatim, so the separating space must be explicit.
  PLS_ASSERT(!continuation.valid(),
             "We only return to a sync point, never jump to it directly. "
             "This must therefore never return an unfinished fiber/continuation.");
  // We cleanly synced with whoever finished the last piece of work on active_task.
}

/**
 * Finishes the active task of calling_state, a child of its prev_ task, when it is not known
 * whether this is the last child to finish.
 *
 * The task manager is asked for a clean task chain at the parent's depth, with two outcomes:
 *  - No chain available: this thread finished last. It takes ownership of the whole chain, marks
 *    the parent synchronized, and returns the parent's stored continuation.
 *  - Chain available: another party still works on the parent. The clean chain is spliced in as
 *    this thread's resources, the active task moves to the depth-0 task of that chain, and the
 *    thread's main (scheduling loop) continuation is returned.
 *
 * @param calling_state state of the thread returning from its active task
 * @return a valid continuation to switch to
 */
context_switcher::continuation scheduler::slow_return(thread_state &calling_state) {
  base_task *const finished_child = calling_state.get_active_task();
  PLS_ASSERT(finished_child->depth_ > 0,
             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");
  base_task *const parent = finished_child->prev_;

  base_task *const spare_chain = calling_state.get_task_manager().pop_clean_task_chain(parent);

  if (spare_chain == nullptr) {
    // Nobody else holds on to the parent: take over the full chain and resume the parent.
    parent->next_ = finished_child;
    calling_state.set_active_task(parent);
    parent->is_synchronized_ = true;

    context_switcher::continuation parent_cont = std::move(parent->continuation_);
    PLS_ASSERT(parent_cont.valid(), "Must return a valid continuation.");
    return parent_cont;
  }

  // Someone else still works on the parent: adopt the clean chain in place of the parent's.
  PLS_ASSERT(parent->depth_ == spare_chain->depth_,
             "Resources must only reside in the correct depth!");
  PLS_ASSERT(parent != spare_chain,
             "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
  PLS_ASSERT(check_task_chain_backward(*spare_chain),
             "Can only acquire clean chains for clean returns!");

  finished_child->prev_ = spare_chain;
  spare_chain->next_ = finished_child;

  // Idle position: the depth-0 task at the root of the adopted chain.
  base_task *chain_root = spare_chain;
  for (; chain_root->depth_ > 0; chain_root = chain_root->prev_) {
  }
  calling_state.set_active_task(chain_root);

  // Go back to the main scheduling loop.
  context_switcher::continuation main_cont = std::move(thread_state::get().main_continuation());
  PLS_ASSERT(main_cont.valid(), "Must return a valid continuation.");
  return main_cont;
}

/**
 * Locates a task in calling_state's task chain by walking from the active task.
 *
 * The walk first moves towards the root via prev_ while the current task is deeper than `depth`,
 * then towards the leaves via next_ while it is shallower.
 *
 * NOTE(review): there are no null checks. If the chain does not reach `depth` in the required
 * direction, the walk dereferences a null prev_/next_.
 */
base_task &scheduler::task_chain_at(unsigned int depth, thread_state &calling_state) {
  // TODO: possibly optimize with a cache array filled at steal events.
  base_task *cursor = calling_state.get_active_task();
  for (; cursor->depth_ > depth; cursor = cursor->prev_) {
  }
  for (; cursor->depth_ < depth; cursor = cursor->next_) {
  }
  return *cursor;
}

void scheduler::terminate() {
175 176 177 178 179 180 181
  if (terminated_) {
    return;
  }

  terminated_ = true;
  sync_barrier_.wait();

182 183 184
  for (unsigned int i = 0; i < num_threads_; i++) {
    if (reuse_thread_ && i == 0) {
      continue;
185
    }
186
    worker_threads_[i].join();
187 188 189
  }
}

190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
// Consistency check used in assertions: walking from start_task along next_, each successor's
// prev_ must point back to its predecessor. Returns false at the first broken link.
bool scheduler::check_task_chain_forward(base_task &start_task) {
  for (base_task *node = &start_task; node->next_ != nullptr; node = node->next_) {
    if (node->next_->prev_ != node) {
      return false;
    }
  }
  return true;
}

// Consistency check used in assertions: walking from start_task along prev_, each predecessor's
// next_ must point back to its successor. Returns false at the first broken link.
bool scheduler::check_task_chain_backward(base_task &start_task) {
  for (base_task *node = &start_task; node->prev_ != nullptr; node = node->prev_) {
    if (node->prev_->next_ != node) {
      return false;
    }
  }
  return true;
}

// Checks the whole chain around start_task: the backward half first, and the forward half only
// when the backward half is consistent.
bool scheduler::check_task_chain(base_task &start_task) {
  if (!check_task_chain_backward(start_task)) {
    return false;
  }
  return check_task_chain_forward(start_task);
}

}