// task_manager.cpp
#include "pls/internal/scheduling/task_manager.h"

#include "pls/internal/scheduling/task.h"
#include "pls/internal/scheduling/thread_state.h"

namespace pls {
namespace internal {
namespace scheduling {

// Sets up this thread's task storage: every task gets its own slice of the
// pre-allocated static stack space and the tasks are wired into a doubly
// linked chain (prev_/next_) ordered by array index.
//
// NOTE(review): the loop only runs to num_tasks - 1 and only links next_ up
// to index num_tasks - 2, so the last array slot is never initialized or
// chained. Presumably an intentional spare/trade slot -- TODO confirm.
task_manager::task_manager(task *tasks,
                           data_structures::aligned_stack static_stack_space,
                           size_t num_tasks,
                           size_t stack_size,
                           external_trading_deque &deque) : num_tasks_{num_tasks},
                                                            this_thread_tasks_{tasks},
                                                            active_task_{&tasks[0]},
                                                            deque_{deque} {
  for (size_t i = 0; i < num_tasks - 1; i++) {
    // Each task owns stack_size bytes of the shared stack space.
    // init(memory, size, i, 0): i presumably the task's depth_, 0 presumably
    // a placeholder thread id -- TODO confirm against task::init.
    tasks[i].init(static_stack_space.push_bytes(stack_size), stack_size, i, 0);
    if (i > 0) {
      tasks[i].prev_ = &tasks[i - 1];
    }
    if (i < num_tasks - 2) {
      tasks[i].next_ = &tasks[i + 1];
    }
  }
}

// Resolves a (thread id, depth) pair to the matching task object by walking
// through the global scheduler state of the current thread.
static task *find_task(unsigned id, unsigned depth) {
  auto &target_thread_state = thread_state::get().get_scheduler().thread_state_for(id);
  return target_thread_state.get_task_manager().get_this_thread_task(depth);
}

// Tries to steal the top task of this manager's deque on behalf of
// stealing_task_manager. On success the stolen task is spliced into the
// stealer's chain and the task offered in trade is returned; on failure
// (empty deque or lost race) nullptr is returned and the stealer's chain is
// left in its original, clean state.
task *task_manager::steal_task(task_manager &stealing_task_manager) {
  PLS_ASSERT(stealing_task_manager.active_task_->depth_ == 0, "Must only steal with clean task chain.");
  PLS_ASSERT(stealing_task_manager.check_task_chain(), "Must only steal with clean task chain.");

  auto top_entry = deque_.peek_top();
  if (!top_entry.top_task_) {
    return nullptr;
  }

  // Locate the task we offer in exchange: the task in our own chain that
  // sits at the same depth as the task we want to steal.
  task *task_to_steal = *top_entry.top_task_;
  task *offered_task = stealing_task_manager.active_task_;
  for (unsigned level = 0; level < task_to_steal->depth_; level++) {
    offered_task = offered_task->next_;
  }

  // Keep a reference to the rest of the chain that stays with us...
  task *retained_chain = offered_task->next_;
  // ...and 'unchain' the offered task (to help us find bugs).
  offered_task->next_ = nullptr;

  // Perform the actual pop/trade operation against the victim's deque.
  auto popped_task = deque_.pop_top(offered_task, top_entry.top_pointer_);
  if (!popped_task) {
    // The steal failed: re-link what we have broken so our chain is back in
    // its old, clean state.
    offered_task->next_ = retained_chain;
    return nullptr;
  }

  PLS_ASSERT(task_to_steal->thread_id_ != offered_task->thread_id_,
             "It is impossible to steal an task we already own!");
  PLS_ASSERT(*popped_task == task_to_steal,
             "We must only steal the task that we peeked at!");

  // The steal was a success: link the chain so we own the stolen part.
  task_to_steal->next_ = retained_chain;
  retained_chain->prev_ = task_to_steal;
  stealing_task_manager.active_task_ = task_to_steal;

  return offered_task;
}

// Pushes a spare task chain onto the lock free resource stack hanging off
// target_task. The stack root stores thread ids shifted by one (0 means
// empty) rather than pointers, and carries a stamp that is bumped on every
// update to avoid ABA problems.
void task_manager::push_resource_on_task(task *target_task, task *spare_task_chain) {
  PLS_ASSERT(target_task->thread_id_ != spare_task_chain->thread_id_,
             "Makes no sense to push task onto itself, as it is not clean by definition.");
  PLS_ASSERT(target_task->depth_ == spare_task_chain->depth_, "Must only push tasks with correct depth.");

  data_structures::stamped_integer old_root;
  data_structures::stamped_integer new_root;
  do {
    old_root = target_task->resource_stack_root_.load();
    // New root: bumped stamp, pointing at the chain we push.
    new_root.stamp = old_root.stamp + 1;
    new_root.value = spare_task_chain->thread_id_ + 1;

    if (old_root.value == 0) {
      // Stack is empty: the pushed chain has no successor.
      spare_task_chain->resource_stack_next_.store(nullptr);
    } else {
      // Stack already has a top entry: resolve its task and chain it below us.
      auto *old_top_task = find_task(old_root.value - 1, target_task->depth_);
      spare_task_chain->resource_stack_next_.store(old_top_task);
    }

  } while (!target_task->resource_stack_root_.compare_exchange_strong(old_root, new_root));
}

// Pops the top task chain from target_task's lock free resource stack.
// Returns nullptr when the stack is empty. Mirrors push_resource_on_task:
// the root holds thread ids + 1 (0 == empty) plus an ABA stamp.
task *task_manager::pop_resource_from_task(task *target_task) {
  data_structures::stamped_integer old_root;
  data_structures::stamped_integer new_root;
  task *popped_task;
  do {
    old_root = target_task->resource_stack_root_.load();
    if (old_root.value == 0) {
      // Empty stack, nothing to pop.
      return nullptr;
    } else {
      // Found a top entry: resolve it and try to swing the root over to its
      // successor (or to 'empty' when there is none).
      auto *top_task = find_task(old_root.value - 1, target_task->depth_);
      auto *successor_task = top_task->resource_stack_next_.load();

      new_root.stamp = old_root.stamp + 1;
      new_root.value = successor_task != nullptr ? successor_task->thread_id_ + 1 : 0;

      popped_task = top_task;
    }
  } while (!target_task->resource_stack_root_.compare_exchange_strong(old_root, new_root));

  PLS_ASSERT(check_task_chain_backward(popped_task), "Must only pop proper task chains.");
  popped_task->resource_stack_next_.store(nullptr);
  return popped_task;
}

void task_manager::sync() {
127 128 129
  auto *spawning_task_manager = this;
  auto *last_task = spawning_task_manager->active_task_;
  auto *spawned_task = spawning_task_manager->active_task_->next_;
130

131
  auto continuation = spawned_task->run_as_task([=](context_switcher::continuation cont) {
132
    last_task->continuation_ = std::move(cont);
133
    spawning_task_manager->active_task_ = spawned_task;
134 135

    context_switcher::continuation result_cont;
136
    if (spawning_task_manager->try_clean_return(result_cont)) {
137 138 139 140 141 142 143 144
      // We return back to the main scheduling loop
      return result_cont;
    } else {
      // We finish up the last task
      return result_cont;
    }
  });

145 146 147
  PLS_ASSERT(!continuation.valid(),
             "We only return to a sync point, never jump to it directly."
             "This must therefore never return an unfinished fiber/continuation.");
148 149 150 151 152 153
}

// Attempts to leave the current task after a sync. If we can acquire a clean
// replacement chain for the parent (i.e. someone else still works on it) we
// swap chains and return true with the main-loop continuation; otherwise we
// are the last worker, the sync is done, and we return false with the
// parent task's stored continuation.
bool task_manager::try_clean_return(context_switcher::continuation &result_cont) {
  task *current_task = active_task_;
  task *parent_task = active_task_->prev_;

  PLS_ASSERT(parent_task != nullptr,
             "Must never try to return from a task at level 0 (no last task), as we must have a target to return to.");

  // Try to get a clean resource chain to go back to the main stealing loop.
  task *replacement_chain = pop_resource_from_task(parent_task);
  if (replacement_chain == nullptr) {
    // Double-check whether we really are the last one or only had unlucky timing.
    auto traded_object = external_trading_deque::get_trade_object(parent_task);
    if (traded_object) {
      replacement_chain = *traded_object;
    } else {
      replacement_chain = pop_resource_from_task(parent_task);
    }
  }

  if (replacement_chain != nullptr) {
    // We got a clean chain to continue working on.
    PLS_ASSERT(parent_task->depth_ == replacement_chain->depth_,
               "Resources must only reside in the correct depth!");
    PLS_ASSERT(replacement_chain != parent_task,
               "We want to swap out the last task and its chain to use a clean one, thus they must differ.");
    PLS_ASSERT(check_task_chain_backward(replacement_chain),
               "Can only acquire clean chains for clean returns!");
    current_task->prev_ = replacement_chain;
    replacement_chain->next_ = current_task;

    // Walk back the acquired chain to make its first task the active one.
    active_task_ = replacement_chain;
    while (active_task_->prev_ != nullptr) {
      active_task_ = active_task_->prev_;
    }

    // Jump back to the continuation in the main scheduling loop, time to steal some work.
    result_cont = std::move(thread_state::get().main_continuation());
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
    return true;
  } else {
    // Make sure that we are owner of this full continuation/task chain.
    parent_task->next_ = current_task;
    current_task->prev_ = parent_task;

    // We are the last one working on this task. Thus the sync must be finished, continue working.
    active_task_ = parent_task;

    result_cont = std::move(parent_task->continuation_);
    PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
    return false;
  }
}

// Checks the chain in next_ direction starting at start_task: every
// successor must point back at its predecessor via prev_.
bool task_manager::check_task_chain_forward(task *start_task) {
  for (task *current = start_task; current->next_ != nullptr; current = current->next_) {
    if (current->next_->prev_ != current) {
      return false;
    }
  }
  return true;
}

// Checks the chain in prev_ direction starting at start_task: every
// predecessor must point forward at its successor via next_.
bool task_manager::check_task_chain_backward(task *start_task) {
  for (task *current = start_task; current->prev_ != nullptr; current = current->prev_) {
    if (current->prev_->next_ != current) {
      return false;
    }
  }
  return true;
}

bool task_manager::check_task_chain() {
225
  return check_task_chain_backward(active_task_) && check_task_chain_forward(active_task_);
226 227
}

}
}
}