#include <tuple>

#include "pls/internal/scheduling/task_manager.h"

#include "pls/internal/scheduling/task.h"
#include "pls/internal/scheduling/thread_state.h"

namespace pls {
namespace internal {
namespace scheduling {

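// Initializes this thread's statically allocated tasks: the first num_tasks - 1
// tasks each receive a slice of the pre-allocated stack space and are linked
// into a doubly linked chain ordered by their depth.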
task_manager::task_manager(task *tasks,
                           data_structures::aligned_stack static_stack_space,
                           size_t num_tasks,
                           size_t stack_size,
                           external_trading_deque &deque) : num_tasks_{num_tasks},
                                                            this_thread_tasks_{tasks},
                                                            active_task_{&tasks[0]},
                                                            deque_{deque} {
  for (size_t i = 0; i < num_tasks - 1; i++) {
    tasks[i].init(static_stack_space.push_bytes(stack_size), stack_size, i, 0);
    if (i > 0) {
      tasks[i].prev_ = &tasks[i - 1];
    }
    if (i < num_tasks - 2) {
      tasks[i].next_ = &tasks[i + 1];
    }
  }
}

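// Resolves a (thread id, depth) pair to the task object statically owned by
// that thread at the given depth.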
static task *find_task(unsigned id, unsigned depth) {
  return thread_state::get().get_scheduler().thread_state_for(id).get_task_manager().get_this_thread_task(depth);
}

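// Steals the top task of this thread's deque on behalf of the given stealing
// task manager. The stealer walks its own clean chain to the task of matching
// depth and offers it as a trade-in; if the pop succeeds, the stolen task is
// spliced into the stealer's chain and the traded task is returned, otherwise
// the stealer's chain is restored and nullptr is returned.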
task *task_manager::steal_task(task_manager &stealing_task_manager) {
  PLS_ASSERT(stealing_task_manager.active_task_->depth_ == 0, "Must only steal with clean task chain.");
  PLS_ASSERT(stealing_task_manager.check_task_chain(), "Must only steal with clean task chain.");

  auto peek = deque_.peek_top();
  if (peek.top_task_) {
    // search for the task we want to trade in
    task *stolen_task = *peek.top_task_;
    task *traded_task = stealing_task_manager.active_task_;
    for (unsigned i = 0; i < stolen_task->depth_; i++) {
      traded_task = traded_task->next_;
    }

    // remember the part of the task chain that we keep
    task *next_own_task = traded_task->next_;
    // 'unchain' the traded tasks (to help us find bugs)
    traded_task->next_ = nullptr;

    // perform the actual pop operation
    auto pop_result_task = deque_.pop_top(traded_task, peek.top_pointer_);
    if (pop_result_task) {
      PLS_ASSERT(stolen_task->thread_id_ != traded_task->thread_id_,
                 "It is impossible to steal a task we already own!");
      PLS_ASSERT(*pop_result_task == stolen_task,
                 "We must only steal the task that we peeked at!");

      // the steal was a success, link the chain so we own the stolen part
      stolen_task->next_ = next_own_task;
      next_own_task->prev_ = stolen_task;
      stealing_task_manager.active_task_ = stolen_task;

      return traded_task;
    } else {
      // the steal failed, reset our chain to its old, clean state (re-link what we have broken)
      traded_task->next_ = next_own_task;

      return nullptr;
    }
  } else {
    return nullptr;
  }
}

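// Pushes a spare task chain onto the target task's lock-free resource stack.
// The stack root only holds the owning thread's id + 1 (0 meaning empty) plus
// a stamp that is bumped on every modification; the per-task
// resource_stack_next_ pointers form the rest of the stack.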
void task_manager::push_resource_on_task(task *target_task, task *spare_task_chain) {
  PLS_ASSERT(check_task_chain_backward(spare_task_chain), "Must only push proper task chains.");
  PLS_ASSERT(target_task->thread_id_ != spare_task_chain->thread_id_,
             "Makes no sense to push task onto itself, as it is not clean by definition.");
  PLS_ASSERT(target_task->depth_ == spare_task_chain->depth_, "Must only push tasks with correct depth.");

  data_structures::stamped_integer current_root;
  data_structures::stamped_integer target_root;
  do {
    current_root = target_task->resource_stack_root_.load();
    target_root.stamp = current_root.stamp + 1;
    target_root.value = spare_task_chain->thread_id_ + 1;

    if (current_root.value == 0) {
      // Empty, simply push in with no successor
      spare_task_chain->resource_stack_next_.store(nullptr, std::memory_order_relaxed);
    } else {
      // Already an entry. Find its corresponding task and set it as our successor.
      auto *current_root_task = find_task(current_root.value - 1, target_task->depth_);
      spare_task_chain->resource_stack_next_.store(current_root_task, std::memory_order_relaxed);
    }

  } while (!target_task->resource_stack_root_.compare_exchange_strong(current_root, target_root));
}

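// Pops the top spare task chain from the target task's resource stack, or
// returns nullptr if it is empty. Mirrors push_resource_on_task: the root is
// swapped to the next stack entry with a CAS loop on the stamped root.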
task *task_manager::pop_resource_from_task(task *target_task) {
  data_structures::stamped_integer current_root;
  data_structures::stamped_integer target_root;
  task *output_task;
  do {
    current_root = target_task->resource_stack_root_.load();
    if (current_root.value == 0) {
      // Empty...
      return nullptr;
    } else {
      // Found something, try to pop it
      auto *current_root_task = find_task(current_root.value - 1, target_task->depth_);
      auto *next_stack_task = current_root_task->resource_stack_next_.load(std::memory_order_relaxed);

      target_root.stamp = current_root.stamp + 1;
      target_root.value = next_stack_task != nullptr ? next_stack_task->thread_id_ + 1 : 0;

      output_task = current_root_task;
    }
  } while (!target_task->resource_stack_root_.compare_exchange_strong(current_root, target_root));

  PLS_ASSERT(check_task_chain_backward(output_task), "Must only pop proper task chains.");
  output_task->resource_stack_next_.store(nullptr, std::memory_order_relaxed);
  return output_task;
}

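// Synchronization point after a spawn: switches onto the spawned task
// (run_as_task), stores the current continuation in the last task and then
// either returns cleanly to the main scheduling loop or resumes the last task
// directly, depending on whether try_clean_return can hand back a clean chain.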
void task_manager::sync() {
  auto *spawning_task_manager = this;
  auto *last_task = spawning_task_manager->active_task_;
  auto *spawned_task = spawning_task_manager->active_task_->next_;

  auto continuation = spawned_task->run_as_task([=](context_switcher::continuation cont) {
    last_task->continuation_ = std::move(cont);
    spawning_task_manager->active_task_ = spawned_task;

    context_switcher::continuation result_cont;
    if (spawning_task_manager->try_clean_return(result_cont)) {
      // We return to the main scheduling loop
      active_task_->clean_ = true;
      return result_cont;
    } else {
      // We finish up the last task
      active_task_->clean_ = false;
      return result_cont;
    }
  });

  PLS_ASSERT(!continuation.valid(),
             "We only return to a sync point, never jump to it directly. "
             "This must therefore never return an unfinished fiber/continuation.");
}

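// Tries to perform a 'clean' return to the scheduler: acquire a clean task
// chain for the last task (from its resource stack or via the deque's trade
// object), link it in and set result_cont to the main scheduling continuation
// (returns true). If no clean chain is available we are the last ones working
// on the chain, so result_cont is set to the last task's stored continuation
// and false is returned.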
bool task_manager::try_clean_return(context_switcher::continuation &result_cont) {
  task *this_task = active_task_;
  task *last_task = active_task_->prev_;

  PLS_ASSERT(last_task != nullptr,
             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");

  // Try to get a clean resource chain to go back to the main stealing loop
  task *clean_chain = pop_resource_from_task(last_task);
  if (clean_chain == nullptr) {
    // double-check if we are really the last one or just had unlucky timing
    auto optional_cas_task = external_trading_deque::get_trade_object(last_task);
    if (optional_cas_task) {
      clean_chain = *optional_cas_task;
    } else {
      clean_chain = pop_resource_from_task(last_task);
    }
  }

  if (clean_chain != nullptr) {
    // We got a clean chain to continue working on.
    PLS_ASSERT(last_task->depth_ == clean_chain->depth_,
               "Resources must only reside at the correct depth!");
    PLS_ASSERT(clean_chain != last_task,
               "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
    PLS_ASSERT(check_task_chain_backward(clean_chain),
               "Can only acquire clean chains for clean returns!");
    this_task->prev_ = clean_chain;
    clean_chain->next_ = this_task;

    // Walk back chain to make first task active
    active_task_ = clean_chain;
    while (active_task_->prev_ != nullptr) {
      active_task_ = active_task_->prev_;
    }

    // jump back to the continuation in the main scheduling loop, time to steal some work
    result_cont = thread_state::get().get_main_continuation();
    return true;
  } else {
    // Make sure that we are the owner of this full continuation/task chain.
    last_task->next_ = this_task;
    this_task->prev_ = last_task;

    // We are the last one working on this task. Thus the sync must be finished, continue working.
    active_task_ = last_task;

    result_cont = std::move(last_task->continuation_);
    return false;
  }
}

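// Debug helpers: walk a task chain forward/backward and assert that all
// prev_/next_ pointers are consistent.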
bool task_manager::check_task_chain_forward(task *start_task) {
  while (start_task->next_ != nullptr) {
    PLS_ASSERT(start_task->next_->prev_ == start_task, "Chain must have correct prev/next fields for linked list!");
    start_task = start_task->next_;
  }
  return true;
}

bool task_manager::check_task_chain_backward(task *start_task) {
  while (start_task->prev_ != nullptr) {
    PLS_ASSERT(start_task->prev_->next_ == start_task, "Chain must have correct prev/next fields for linked list!");
    start_task = start_task->prev_;
  }
  return true;
}

bool task_manager::check_task_chain() {
  check_task_chain_backward(active_task_);
  check_task_chain_forward(active_task_);

  return true;
}

} // namespace scheduling
} // namespace internal
} // namespace pls