// task_manager.cpp
#include <tuple>

#include "pls/internal/scheduling/task_manager.h"

#include "pls/internal/scheduling/task.h"
#include "pls/internal/scheduling/thread_state.h"

namespace pls {
namespace internal {
namespace scheduling {

/**
 * Set up the task chain owned by one thread.
 *
 * The first num_tasks - 1 entries of `tasks` each receive their own stack of `stack_size` bytes
 * carved out of `static_stack_space`. Each is initialized with its index (presumably its depth)
 * and 0 as the last init() argument. The entries are linked into a doubly linked list through
 * prev_/next_, and tasks[0] starts out as the active task.
 *
 * NOTE(review): tasks[num_tasks - 1] is neither initialized nor linked, and num_tasks == 0 would
 * underflow the loop bound. Confirm both are intended.
 */
task_manager::task_manager(task *tasks,
                           data_structures::aligned_stack static_stack_space,
                           size_t num_tasks,
                           size_t stack_size,
                           external_trading_deque &deque) : num_tasks_{num_tasks},
                                                            this_thread_tasks_{tasks},
                                                            active_task_{&tasks[0]},
                                                            deque_{deque} {
  const size_t chained_count = num_tasks - 1;
  for (size_t index = 0; index < chained_count; index++) {
    task &current = tasks[index];
    auto stack_memory = static_stack_space.push_bytes(stack_size);
    current.init(stack_memory, stack_size, index, 0);

    // Link back to the predecessor (the first entry has none).
    if (index != 0) {
      current.prev_ = &tasks[index - 1];
    }
    // Link forward to the successor (the last chained entry has none).
    if (index < num_tasks - 2) {
      current.next_ = &tasks[index + 1];
    }
  }
}

/**
 * Look up a task that belongs to another thread's task manager.
 *
 * @param id    Thread id whose task manager is queried, resolved through the scheduler of the
 *              calling thread.
 * @param depth Depth passed on to that manager's get_this_thread_task().
 * @return Whatever get_this_thread_task(depth) of the owning manager returns.
 */
static task *find_task(unsigned id, unsigned depth) {
  auto &&current_scheduler = thread_state::get().get_scheduler();
  auto &&owner_state = current_scheduler.thread_state_for(id);
  auto &&owner_manager = owner_state.get_task_manager();
  return owner_manager.get_this_thread_task(depth);
}

35
task *task_manager::steal_task(task_manager &stealing_task_manager) {
36
  PLS_ASSERT(stealing_task_manager.active_task_->depth_ == 0, "Must only steal with clean task chain.");
37
  PLS_ASSERT(stealing_task_manager.check_task_chain(), "Must only steal with clean task chain.");
38 39

  auto peek = deque_.peek_top();
40
  if (peek.top_task_) {
41
    // search for the task we want to trade in
42
    task *stolen_task = *peek.top_task_;
43
    task *traded_task = stealing_task_manager.active_task_;
44
    for (unsigned i = 0; i < stolen_task->depth_; i++) {
45 46 47
      traded_task = traded_task->next_;
    }

48 49
    // keep a reference to the rest of the task chain that we keep
    task *next_own_task = traded_task->next_;
50
    // 'unchain' the traded tasks (to help us find bugs)
51 52
    traded_task->next_ = nullptr;

53 54 55 56
    // perform the actual pop operation
    auto pop_result_task = deque_.pop_top(traded_task, peek.top_pointer_);
    if (pop_result_task) {
      PLS_ASSERT(stolen_task->thread_id_ != traded_task->thread_id_,
57
                 "It is impossible to steal an task we already own!");
58
      PLS_ASSERT(*pop_result_task == stolen_task,
59
                 "We must only steal the task that we peeked at!");
60

61
      // the steal was a success, link the chain so we own the stolen part
62 63 64
      stolen_task->next_ = next_own_task;
      next_own_task->prev_ = stolen_task;
      stealing_task_manager.active_task_ = stolen_task;
65

66
      return traded_task;
67
    } else {
68 69 70
      // the steal failed, reset our chain to its old, clean state (re-link what we have broken)
      traded_task->next_ = next_own_task;

71
      return nullptr;
72 73
    }
  } else {
74
    return nullptr;
75
  }
76 77
}

78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
void task_manager::push_resource_on_task(task *target_task, task *spare_task_chain) {
  PLS_ASSERT(check_task_chain_backward(spare_task_chain), "Must only push proper task chains.");
  PLS_ASSERT(target_task->thread_id_ != spare_task_chain->thread_id_,
             "Makes no sense to push task onto itself, as it is not clean by definition.");
  PLS_ASSERT(target_task->depth_ == spare_task_chain->depth_, "Must only push tasks with correct depth.");

  data_structures::stamped_integer current_root;
  data_structures::stamped_integer target_root;
  do {
    current_root = target_task->resource_stack_root_.load();
    target_root.stamp = current_root.stamp + 1;
    target_root.value = spare_task_chain->thread_id_ + 1;

    if (current_root.value == 0) {
      // Empty, simply push in with no successor
93
      spare_task_chain->resource_stack_next_.store(nullptr, std::memory_order_relaxed);
94 95 96
    } else {
      // Already an entry. Find it's corresponding task and set it as our successor.
      auto *current_root_task = find_task(current_root.value - 1, target_task->depth_);
97
      spare_task_chain->resource_stack_next_.store(current_root_task, std::memory_order_relaxed);
98 99 100 101 102
    }

  } while (!target_task->resource_stack_root_.compare_exchange_strong(current_root, target_root));
}

103 104 105
task *task_manager::pop_resource_from_task(task *target_task) {
  data_structures::stamped_integer current_root;
  data_structures::stamped_integer target_root;
106
  task *output_task;
107 108 109 110 111 112 113 114
  do {
    current_root = target_task->resource_stack_root_.load();
    if (current_root.value == 0) {
      // Empty...
      return nullptr;
    } else {
      // Found something, try to pop it
      auto *current_root_task = find_task(current_root.value - 1, target_task->depth_);
115 116 117
      auto *next_stack_task = current_root_task->resource_stack_next_.load(std::memory_order_relaxed);

      target_root.stamp = current_root.stamp + 1;
118
      target_root.value = next_stack_task != nullptr ? next_stack_task->thread_id_ + 1 : 0;
119 120

      output_task = current_root_task;
121
    }
122 123
  } while (!target_task->resource_stack_root_.compare_exchange_strong(current_root, target_root));

124
  PLS_ASSERT(check_task_chain_backward(output_task), "Must only pop proper task chains.");
125
  output_task->resource_stack_next_.store(nullptr, std::memory_order_relaxed);
126
  return output_task;
127 128
}

129
void task_manager::sync() {
130 131 132
  auto *spawning_task_manager = this;
  auto *last_task = spawning_task_manager->active_task_;
  auto *spawned_task = spawning_task_manager->active_task_->next_;
133

134
  auto continuation = spawned_task->run_as_task([=](context_switcher::continuation cont) {
135
    last_task->continuation_ = std::move(cont);
136
    spawning_task_manager->active_task_ = spawned_task;
137 138

    context_switcher::continuation result_cont;
139
    if (spawning_task_manager->try_clean_return(result_cont)) {
140 141 142 143 144 145 146 147
      // We return back to the main scheduling loop
      return result_cont;
    } else {
      // We finish up the last task
      return result_cont;
    }
  });

148 149 150
  PLS_ASSERT(!continuation.valid(),
             "We only return to a sync point, never jump to it directly."
             "This must therefore never return an unfinished fiber/continuation.");
151 152 153 154 155 156
}

/**
 * Try to return from the active task by switching to a clean task chain.
 *
 * A clean chain at the depth of the active task's predecessor (last_task) is looked for in three
 * places, in order:
 *   1. last_task's resource stack;
 *   2. external_trading_deque::get_trade_object(last_task);
 *   3. the resource stack once more.
 *
 * - If one is found, it is linked in as the active task's predecessor, the first task of that
 *   chain becomes active, and result_cont is set to this thread's main continuation.
 * - Otherwise we keep last_task: it is re-linked with the active task and made active, and its
 *   stored continuation is moved into result_cont.
 *
 * @param result_cont Out-parameter receiving the continuation to switch to.
 * @return true if a clean chain was acquired; false if execution continues with last_task.
 */
bool task_manager::try_clean_return(context_switcher::continuation &result_cont) {
  task *this_task = active_task_;
  task *last_task = active_task_->prev_;

  PLS_ASSERT(last_task != nullptr,
             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");

  // Try to get a clean resource chain to go back to the main stealing loop.
  task *clean_chain = pop_resource_from_task(last_task);
  if (clean_chain == nullptr) {
    // Double-check whether we are really the last one or only have unlucky timing.
    auto optional_cas_task = external_trading_deque::get_trade_object(last_task);
    if (optional_cas_task) {
      clean_chain = *optional_cas_task;
    } else {
      clean_chain = pop_resource_from_task(last_task);
    }
  }

  if (clean_chain != nullptr) {
    // We got a clean chain to continue working on.
    PLS_ASSERT(last_task->depth_ == clean_chain->depth_,
               "Resources must only reside in the correct depth!");
    PLS_ASSERT(clean_chain != last_task,
               "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
    PLS_ASSERT(check_task_chain_backward(clean_chain),
               "Can only acquire clean chains for clean returns!");

    this_task->prev_ = clean_chain;
    clean_chain->next_ = this_task;

    // Walk back the chain to make its first task active.
    active_task_ = clean_chain;
    while (active_task_->prev_ != nullptr) {
      active_task_ = active_task_->prev_;
    }

    // Jump back to the continuation in the main scheduling loop, time to steal some work.
    result_cont = thread_state::get().get_main_continuation();
    return true;
  }

  // Make sure that we are the owner of this full continuation/task chain.
  last_task->next_ = this_task;
  this_task->prev_ = last_task;

  // We are the last one working on this task. Thus the sync must be finished, continue working.
  active_task_ = last_task;

  result_cont = std::move(last_task->continuation_);
  return false;
}

/**
 * Walk from start_task towards the end of the chain (following next_) and assert that every
 * successor points back to its predecessor through prev_.
 *
 * @param start_task First task to inspect; must not be nullptr.
 * @return Always true; broken links are reported through PLS_ASSERT.
 */
bool task_manager::check_task_chain_forward(task *start_task) {
  for (task *cursor = start_task; cursor->next_ != nullptr; cursor = cursor->next_) {
    PLS_ASSERT(cursor->next_->prev_ == cursor, "Chain must have correct prev/next fields for linked list!");
  }
  return true;
}

/**
 * Walk from start_task towards the beginning of the chain (following prev_) and assert that
 * every predecessor points forward to its successor through next_.
 *
 * @param start_task First task to inspect; must not be nullptr.
 * @return Always true; broken links are reported through PLS_ASSERT.
 */
bool task_manager::check_task_chain_backward(task *start_task) {
  for (task *cursor = start_task; cursor->prev_ != nullptr; cursor = cursor->prev_) {
    PLS_ASSERT(cursor->prev_->next_ == cursor, "Chain must have correct prev/next fields for linked list!");
  }
  return true;
}

/**
 * Check the consistency of the whole task chain around the active task, in both directions.
 *
 * @return true if both the backward and the forward check report success.
 */
bool task_manager::check_task_chain() {
  // Run both checks unconditionally so each one gets to fire its own assertions,
  // then report their combined result instead of discarding it.
  const bool backward_ok = check_task_chain_backward(active_task_);
  const bool forward_ok = check_task_chain_forward(active_task_);
  return backward_ok && forward_ok;
}
}
}
}