task_manager.cpp 8.38 KB
Newer Older
1 2
#include <tuple>

3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
#include "pls/internal/scheduling/task_manager.h"

#include "pls/internal/scheduling/task.h"
#include "pls/internal/scheduling/thread_state.h"

namespace pls {
namespace internal {
namespace scheduling {

// Carves one fixed-size stack of `stack_size` bytes out of `static_stack_space`
// for each task and wires the tasks into a doubly linked chain (prev_/next_),
// ordered by depth (task i is initialized with depth i). The first task in the
// chain becomes the active task.
//
// NOTE(review): the loop runs only to num_tasks - 1, so tasks[num_tasks - 1]
// is never initialized, and next_ links stop at index num_tasks - 3 ->
// num_tasks - 2. Presumably the last slot is deliberately reserved/unused —
// confirm against task::init and the callers that size the task array.
task_manager::task_manager(task *tasks,
                           data_structures::aligned_stack static_stack_space,
                           size_t num_tasks,
                           size_t stack_size,
                           external_trading_deque &deque) : num_tasks_{num_tasks},
                                                            this_thread_tasks_{tasks},
                                                            active_task_{&tasks[0]},
                                                            deque_{deque} {
  for (size_t i = 0; i < num_tasks - 1; i++) {
    tasks[i].init(static_stack_space.push_bytes(stack_size), stack_size, i, 0);
    if (i > 0) {
      tasks[i].prev_ = &tasks[i - 1];
    }
    if (i < num_tasks - 2) {
      tasks[i].next_ = &tasks[i + 1];
    }
  }
}

// Looks up the task residing at `depth` in the chain owned by the thread
// with the given `id`, going through that thread's task manager.
static task *find_task(unsigned id, unsigned depth) {
  auto &owning_state = thread_state::get().get_scheduler().thread_state_for(id);
  return owning_state.get_task_manager().get_this_thread_task(depth);
}

35
bool task_manager::steal_task(task_manager &stealing_task_manager) {
36 37 38
  PLS_ASSERT(stealing_task_manager.active_task_->depth_ == 0, "Must only steal with clean task chain.");

  auto peek = deque_.peek_top();
39 40
  auto optional_target_task = peek.top_task_;
  auto target_top = peek.top_pointer_;
41 42

  if (optional_target_task) {
43 44 45
    PLS_ASSERT(stealing_task_manager.check_task_chain(), "We are stealing, must not have a bad chain here!");

    // search for the task we want to trade in
46 47 48 49 50 51
    task *target_task = *optional_target_task;
    task *traded_task = stealing_task_manager.active_task_;
    for (unsigned i = 0; i < target_task->depth_; i++) {
      traded_task = traded_task->next_;
    }

52 53 54 55 56
    // keep a reference to the rest of the task chain that we keep
    task *next_own_task = traded_task->next_;
    // 'unchain' the traded tasks (to help us find bugs only)
    traded_task->next_ = nullptr;

57 58
    auto optional_result_task = deque_.pop_top(traded_task, target_top);
    if (optional_result_task) {
59 60 61 62
      PLS_ASSERT(target_task->thread_id_ != traded_task->thread_id_,
                 "It is impossible to steal an task we already own!");
      PLS_ASSERT(*optional_result_task == target_task,
                 "We must only steal the task that we peeked at!");
63 64 65 66 67 68
      // the steal was a success, link the chain so we own the stolen part
      target_task->next_ = next_own_task;
      next_own_task->prev_ = target_task;
      stealing_task_manager.set_active_task(target_task);

      return true;
69
    } else {
70 71 72 73
      // the steal failed, reset our chain to its old, clean state (re-link what we have broken)
      traded_task->next_ = next_own_task;

      return false;
74 75
    }
  } else {
76
    return false;
77
  }
78 79
}

80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
// Pushes `spare_task_chain` onto `target_task`'s resource stack using a
// Treiber-style lock-free CAS loop.
//
// Encoding of resource_stack_root_ (a stamped integer):
//   value == 0          -> stack is empty
//   value == id + 1     -> top entry is owned by thread `id`
// The stamp is incremented on every successful update to mitigate ABA.
// Entries are linked through resource_stack_next_ and re-located from the
// (thread id, depth) pair via find_task().
void task_manager::push_resource_on_task(task *target_task, task *spare_task_chain) {
  PLS_ASSERT(check_task_chain_backward(spare_task_chain), "Must only push proper task chains.");
  PLS_ASSERT(target_task->thread_id_ != spare_task_chain->thread_id_,
             "Makes no sense to push task onto itself, as it is not clean by definition.");
  PLS_ASSERT(target_task->depth_ == spare_task_chain->depth_, "Must only push tasks with correct depth.");

  data_structures::stamped_integer current_root;
  data_structures::stamped_integer target_root;
  do {
    current_root = target_task->resource_stack_root_.load();
    target_root.stamp = current_root.stamp + 1;
    target_root.value = spare_task_chain->thread_id_ + 1;

    if (current_root.value == 0) {
      // Empty, simply push in with no successor
      spare_task_chain->resource_stack_next_ = nullptr;
    } else {
      // Already an entry. Find its corresponding task and set it as our successor.
      auto *current_root_task = find_task(current_root.value - 1, target_task->depth_);
      spare_task_chain->resource_stack_next_ = current_root_task;
    }

  } while (!target_task->resource_stack_root_.compare_exchange_strong(current_root, target_root));
}

105 106 107
// Pops the top entry from `target_task`'s resource stack (lock-free CAS loop,
// counterpart to push_resource_on_task). Returns the popped task chain, or
// nullptr if the stack is empty.
//
// NOTE(review): reading current_root_task->resource_stack_next_ before the
// CAS is the classic lock-free-pop hazard — it relies on the stamp in
// resource_stack_root_ to catch concurrent pops/pushes (ABA). Confirm the
// stamp width is sufficient for the expected contention.
task *task_manager::pop_resource_from_task(task *target_task) {
  data_structures::stamped_integer current_root;
  data_structures::stamped_integer target_root;
  task *output_task;
  do {
    current_root = target_task->resource_stack_root_.load();
    target_root.stamp = current_root.stamp + 1;

    if (current_root.value == 0) {
      // Empty...
      return nullptr;
    } else {
      // Found something, try to pop it
      auto *current_root_task = find_task(current_root.value - 1, target_task->depth_);
      // new root becomes the current top's successor (0 encodes 'empty')
      auto *next_stack_task = current_root_task->resource_stack_next_;
      target_root.value = next_stack_task != nullptr ? next_stack_task->thread_id_ + 1 : 0;

      output_task = current_root_task;
    }
  } while (!target_task->resource_stack_root_.compare_exchange_strong(current_root, target_root));

  PLS_ASSERT(check_task_chain_backward(output_task), "Must only pop proper task chains.");
  return output_task;
}

130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
// Suspends the currently active task and switches execution onto the next
// task in the chain. The suspended task's continuation is stored on the task
// itself so that the last participant to reach the sync point can resume it
// (see try_clean_return()).
void task_manager::sync() {
  auto continuation = active_task_->next_->run_as_task([this](context_switcher::continuation cont) {
    auto *last_task = active_task_;
    auto *this_task = active_task_->next_;

    // remember how to resume the suspended task, then make its child active
    last_task->continuation_ = std::move(cont);
    active_task_ = this_task;

    context_switcher::continuation result_cont;
    if (try_clean_return(result_cont)) {
      // We return back to the main scheduling loop
      active_task_->clean_ = true;
      return result_cont;
    } else {
      // We finish up the last task
      active_task_->clean_ = false;
      return result_cont;
    }
  });

  if (continuation.valid()) {
    // We jumped in here from the main loop, keep track!
    thread_state::get().set_main_continuation(std::move(continuation));
  }
}

// Decides what happens after a task finished at a sync point.
//
// Returns true and fills `result_cont` with the main scheduling loop's
// continuation when this thread should go back to stealing (either the whole
// computation is done, or we acquired a clean spare chain to trade with).
// Returns false and fills `result_cont` with the parent task's stored
// continuation when we are the last worker at the sync and must keep working.
bool task_manager::try_clean_return(context_switcher::continuation &result_cont) {
  task *this_task = active_task_;
  task *last_task = active_task_->prev_;

  if (last_task == nullptr) {
    // We finished the final task of the computation, return to the scheduling loop.
    result_cont = thread_state::get().get_main_continuation();
    return true;
  }

  // Try to get a clean resource chain to go back to the main stealing loop
  task *clean_chain = pop_resource_from_task(last_task);
  if (clean_chain == nullptr) {
    // double-check if we are really last one or we only have unlucky timing
    auto optional_cas_task = external_trading_deque::get_trade_object(last_task);
    if (optional_cas_task) {
      clean_chain = *optional_cas_task;
    } else {
      // the trade object is gone, so a resource may have arrived meanwhile
      clean_chain = pop_resource_from_task(last_task);
    }
  }

  if (clean_chain != nullptr) {
    // We got a clean chain to continue working on.
    PLS_ASSERT(last_task->depth_ == clean_chain->depth_,
               "Resources must only reside in the correct depth!");
    PLS_ASSERT(check_task_chain_backward(clean_chain), "Can only aquire clean chains for clean returns!");
    this_task->prev_ = clean_chain;
    clean_chain->next_ = this_task;
    // Walk back chain to make first task active
    active_task_ = clean_chain;
    while (active_task_->prev_ != nullptr) {
      active_task_ = active_task_->prev_;
    }

    PLS_ASSERT(check_task_chain(), "We just aquired a clean chain...");

    // jump back to continuation in main scheduling loop, time to steal some work
    result_cont = thread_state::get().get_main_continuation();
    return true;
  } else {
    // We are the last one working on this task. Thus the sync must be finished, continue working.
    active_task_ = last_task;

    // Make sure that we are owners of this full continuation/task chain.
    active_task_->next_ = this_task;
    this_task->prev_ = active_task_;

    result_cont = std::move(last_task->continuation_);
    return false;
  }
}

// Debug helper: walks the chain forward from `start_task` and asserts that
// every next_ element points back at its predecessor via prev_.
// Always returns true; inconsistencies trip the assertion instead.
bool task_manager::check_task_chain_forward(task *start_task) {
  for (task *current = start_task; current->next_ != nullptr; current = current->next_) {
    PLS_ASSERT(current->next_->prev_ == current, "Chain must have correct prev/next fields for linked list!");
  }
  return true;
}

// Debug helper: walks the chain backward from `start_task` and asserts that
// every prev_ element points forward at its successor via next_.
// Always returns true; inconsistencies trip the assertion instead.
bool task_manager::check_task_chain_backward(task *start_task) {
  for (task *current = start_task; current->prev_ != nullptr; current = current->prev_) {
    PLS_ASSERT(current->prev_->next_ == current, "Chain must have correct prev/next fields for linked list!");
  }
  return true;
}

// Debug helper: validates the active task's chain in both directions.
// Both sub-checks always return true (they assert on violation), so this
// always returns true as well.
bool task_manager::check_task_chain() {
  return check_task_chain_backward(active_task_) && check_task_chain_forward(active_task_);
}

232 233 234
}
}
}