#include "pls/internal/scheduling/scheduler.h"
#include "pls/internal/scheduling/thread_state.h"
#include "pls/internal/scheduling/task.h"

#include "pls/internal/base/error_handling.h"

namespace pls {
namespace internal {
namespace scheduling {

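// Create a scheduler for num_threads worker threads backed by the pre-allocated
// scheduler_memory. If reuse_thread is set, the caller's thread acts as worker 0
// instead of an additional OS thread being spawned for it.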
scheduler::scheduler(scheduler_memory &memory, const unsigned int num_threads, bool reuse_thread) :
    num_threads_{num_threads},
    reuse_thread_{reuse_thread},
    memory_{memory},
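    // The barrier synchronizes all worker threads plus the coordinating main thread;
    // if the main thread is reused as worker 0, one participant less is needed.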
    sync_barrier_{num_threads + 1 - reuse_thread},
    terminated_{false} {
  if (num_threads_ > memory.max_threads()) {
    PLS_ERROR("Tried to create scheduler with more OS threads than pre-allocated memory.");
  }

  for (unsigned int i = 0; i < num_threads_; i++) {
    // Placement new is required, as the memory of `memory_` is not required to be initialized.
    memory.thread_state_for(i).scheduler_ = this;
    memory.thread_state_for(i).id_ = i;

    if (reuse_thread && i == 0) {
      continue; // Skip spawning the first/main thread when reusing the user's thread, as the user's thread takes its place.
    }
    memory.thread_for(i) = base::thread(&scheduler::work_thread_main_loop, &memory_.thread_state_for(i));
  }
}

scheduler::~scheduler() {
  terminate();
}

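// Main loop of each spawned worker thread: block on the sync barrier until a work
// section is started, execute it, then synchronize back with the main thread and wait
// for the next section (or return once the scheduler has been terminated).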
void scheduler::work_thread_main_loop() {
  auto &scheduler = thread_state::get().get_scheduler();
  while (true) {
    // Wait to be triggered
    scheduler.sync_barrier_.wait();

    // Check for shutdown
    if (scheduler.terminated_) {
      return;
    }

    scheduler.work_thread_work_section();

    // Sync back with main thread
    scheduler.sync_barrier_.wait();
  }
}

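// One work section as seen by a single thread: thread 0 seeds the section by running the
// user's main code block, then every thread alternates between executing locally pending
// fall-through continuations and stealing remote tasks until the section is marked done.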
void scheduler::work_thread_work_section() {
  auto &my_state = thread_state::get();
  auto &my_cont_manager = my_state.get_cont_manager();

  auto const num_threads = my_state.get_scheduler().num_threads();
  auto const my_id = my_state.get_id();

  if (my_state.get_id() == 0) {
    // Main thread (id 0): kick off the work section by executing the user's main code block.
    main_thread_starter_function_->run();
  }

  do {
    // Work off pending continuations we need to execute locally
    while (my_cont_manager.falling_through()) {
      my_cont_manager.execute_fall_through_code();
    }
    my_cont_manager.check_clean_chain();

    // Steal routine (executed continuously once there are no more local fall-throughs to work off).
    // TODO: move into separate function
    const size_t offset = my_state.random_() % num_threads;
    const size_t max_tries = num_threads - 1;
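    // Try up to num_threads - 1 other threads per round, starting at a random offset so
    // victims are probed in a randomized order.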
    for (size_t i = 0; i < max_tries; i++) {
      size_t target = (offset + i) % num_threads;

      // Skip ourselves as a steal target (if the slot maps to our own id, bump it by one).
      target = ((target == my_id) + target) % num_threads;

      auto &target_state = my_state.get_scheduler().thread_state_for(target);

      my_cont_manager.check_clean_chain();
      auto *stolen_task = target_state.get_task_manager().steal_remote_task(my_cont_manager);
      if (stolen_task != nullptr) {
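        // Successful steal: adopt the stolen task's continuation as our parent continuation,
        // mark this execution as a stolen (right) spawn and move one level deeper in the chain.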
        my_state.parent_cont_ = stolen_task->get_cont();
        my_state.right_spawn_ = true;
        my_cont_manager.set_active_depth(stolen_task->get_cont()->get_memory_block()->get_depth() + 1);
        my_cont_manager.check_clean_chain();
        stolen_task->execute();
        if (!my_cont_manager.falling_through()) {
          my_cont_manager.fall_through_and_notify_cont(stolen_task->get_cont(), true);
        }
        break;
      }
    }
  } while (!work_section_done_);

}

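// Shut down the scheduler: set the termination flag, release the worker threads from the
// sync barrier so they observe the flag and exit their main loop, then join all spawned
// threads (a reused main thread was never spawned and is therefore skipped).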
void scheduler::terminate() {
  if (terminated_) {
    return;
  }

  terminated_ = true;
  sync_barrier_.wait();

  for (unsigned int i = 0; i < num_threads_; i++) {
    if (reuse_thread_ && i == 0) {
      continue;
    }
    memory_.thread_for(i).join();
  }
}

thread_state &scheduler::thread_state_for(size_t id) { return memory_.thread_state_for(id); }

}
}
}