Commit 41809925 by FritzFlorian

Add simple profiling code that records the dynamically executed DAG.

Add the key points required to capture an execution DAG. Currently all data is kept in memory and printed out. In future work it would be good to store the DAGs and/or process them further.
parent 998e43af
Pipeline #1472 passed with stages in 4 minutes 24 seconds
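To set the scene, a minimal sketch of what the new profiler records (API names taken from the diff below; scheduler construction is abbreviated and assumed): each task segment becomes one dag_node, spawn() attaches a child node to the active node, and sync() closes the current node and continues recording in its next_node_.

    // Sketch only: scheduler setup is a placeholder, see the diff for the real hooks.
    scheduler.perform_work([] {
      scheduler::spawn([] { /* recorded as a child_nodes_ entry of the root node */ });
      scheduler::spawn([] { /* recorded as a second child_nodes_ entry */ });
      scheduler::sync(); // root node ends here; recording continues in its next_node_
    });
    // stop_profiler_run() then prints wall time, user time, steal counts and
    // peak stack usage for the whole run.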
@@ -39,7 +39,7 @@ add_library(pls STATIC
         include/pls/internal/scheduling/lock_free/task.h
         include/pls/internal/scheduling/lock_free/task_manager.h src/internal/scheduling/lock_free/task_manager.cpp
         include/pls/internal/scheduling/lock_free/external_trading_deque.h src/internal/scheduling/lock_free/external_trading_deque.cpp
-        include/pls/internal/scheduling/lock_free/traded_cas_field.h src/internal/scheduling/lock_free/task.cpp)
+        include/pls/internal/scheduling/lock_free/traded_cas_field.h src/internal/scheduling/lock_free/task.cpp include/pls/internal/profiling/dag_node.h include/pls/internal/profiling/profiler.h include/pls/internal/profiling/thread_stats.h)

 # Dependencies for pls
 target_link_libraries(pls Threads::Threads)
...
new file: include/pls/internal/profiling/dag_node.h

#ifndef PLS_INTERNAL_PROFILING_DAG_NODE_H_
#define PLS_INTERNAL_PROFILING_DAG_NODE_H_

#include <memory>
#include <list>
#include <algorithm>

namespace pls::internal::profiling {

struct dag_node {
  explicit dag_node(unsigned spawning_thread_id) : spawning_thread_id_{spawning_thread_id} {}

  unsigned spawning_thread_id_;
  unsigned long max_memory_{0};
  unsigned long total_runtime_{0};
  std::unique_ptr<dag_node> next_node_;  // continuation of this strand after sync()
  std::list<dag_node> child_nodes_;      // tasks spawned while this node was active

  // Maximum stack usage of any node in this sub-DAG.
  unsigned long dag_max_memory() {
    unsigned long max = max_memory_;
    if (next_node_) {
      max = std::max(max, next_node_->dag_max_memory());
    }
    for (auto &child : child_nodes_) {
      max = std::max(max, child.dag_max_memory());
    }
    return max;
  }

  // Sum of the user-code runtime over this sub-DAG.
  unsigned long dag_total_user_time() {
    unsigned long total_user_time = total_runtime_;
    if (next_node_) {
      total_user_time += next_node_->dag_total_user_time();
    }
    for (auto &child : child_nodes_) {
      total_user_time += child.dag_total_user_time();
    }
    return total_user_time;
  }
};

}
#endif //PLS_INTERNAL_PROFILING_DAG_NODE_H_
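To make the node relations concrete, a small self-contained example (numbers are made up) of how the two recursive queries aggregate over a DAG with one spawned child and one post-sync continuation:

    #include <iostream>
    #include <memory>
    #include "pls/internal/profiling/dag_node.h"

    int main() {
      using pls::internal::profiling::dag_node;
      dag_node root{0};
      root.total_runtime_ = 100;  // ns of user code before the first sync
      auto &child = root.child_nodes_.emplace_back(1);  // one spawn()
      child.total_runtime_ = 40;
      child.max_memory_ = 1024;
      root.next_node_ = std::make_unique<dag_node>(0);  // continuation after sync()
      root.next_node_->total_runtime_ = 10;
      std::cout << root.dag_total_user_time() << std::endl;  // 150
      std::cout << root.dag_max_memory() << std::endl;       // 1024
      return 0;
    }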
new file: include/pls/internal/profiling/profiler.h

#ifndef PLS_INTERNAL_PROFILING_PROFILER_H_
#define PLS_INTERNAL_PROFILING_PROFILER_H_

#define PLS_PROFILING_ENABLED true

#include <cstddef>
#include <memory>
#include <chrono>
#include <vector>
#include <iostream>

#include "dag_node.h"
#include "thread_stats.h"

namespace pls::internal::profiling {

class profiler {
  using clock = std::chrono::steady_clock;
  using resolution = std::chrono::nanoseconds;

  struct profiler_run {
    explicit profiler_run(unsigned num_threads) : start_time_{},
                                                  end_time_{},
                                                  root_node_{std::make_unique<dag_node>(0)},
                                                  per_thread_stats_(num_threads) {}

    clock::time_point start_time_;
    clock::time_point end_time_;
    std::unique_ptr<dag_node> root_node_;
    std::vector<thread_stats> per_thread_stats_;

    void print_stats(unsigned num_threads) const {
      auto run_duration = std::chrono::duration_cast<resolution>(end_time_ - start_time_).count();
      std::cout << "===========================" << std::endl;
      std::cout << "WALL TIME: " << run_duration << std::endl;
      unsigned long total_user_time = root_node_->dag_total_user_time();
      std::cout << "USER TIME: " << total_user_time << std::endl;

      unsigned long total_failed_steals = 0;
      unsigned long total_successful_steals = 0;
      unsigned long total_steal_time = 0;
      for (auto &thread_stats : per_thread_stats_) {
        total_failed_steals += thread_stats.failed_steals_;
        total_successful_steals += thread_stats.successful_steals_;
        total_steal_time += thread_stats.total_time_stealing_;
      }
      std::cout << "STEALS: (Time " << total_steal_time
                << ", Total: " << (total_successful_steals + total_failed_steals)
                << ", Success: " << total_successful_steals
                << ", Failed: " << total_failed_steals << ")" << std::endl;

      unsigned long total_measured = total_steal_time + total_user_time;
      unsigned long total_wall = run_duration * num_threads;
      std::cout << "Wall Time vs. Measured: " << 100.0 * total_measured / total_wall << std::endl;

      std::cout << "MEMORY: " << root_node_->dag_max_memory() << " bytes per stack" << std::endl;
    }
  };

 public:
  explicit profiler(unsigned num_threads) : num_threads_{num_threads},
                                            profiler_runs_() {}

  dag_node *start_profiler_run() {
    profiler_run &current_run = profiler_runs_.emplace_back(num_threads_);
    current_run.start_time_ = clock::now();
    return current_run.root_node_.get();
  }

  void stop_profiler_run() {
    current_run().end_time_ = clock::now();
    current_run().print_stats(num_threads_);
  }

  void stealing_start(unsigned thread_id) {
    auto &thread_stats = thread_stats_for(thread_id);
    thread_stats.run_on_stack([&] {
      thread_stats.stealing_start_time_ = clock::now();
      thread_stats.total_steals_++;
    });
  }

  void stealing_end(unsigned thread_id, bool success) {
    auto &thread_stats = thread_stats_for(thread_id);
    thread_stats.run_on_stack([&] {
      thread_stats.failed_steals_ += !success;
      thread_stats.successful_steals_ += success;
      auto end_time = clock::now();
      auto steal_duration =
          std::chrono::duration_cast<resolution>(end_time - thread_stats.stealing_start_time_).count();
      thread_stats.total_time_stealing_ += steal_duration;
    });
  }

  void stealing_cas_op(unsigned thread_id) {
    auto &thread_stats = thread_stats_for(thread_id);
    thread_stats.run_on_stack([&] {
      thread_stats.steal_cas_ops_++;
    });
  }

  dag_node *task_spawn_child(unsigned thread_id, dag_node *parent) {
    auto &thread_stats = thread_stats_for(thread_id);
    dag_node *result;
    thread_stats.run_on_stack([&] {
      result = &parent->child_nodes_.emplace_back(thread_id);
    });
    return result;
  }

  dag_node *task_sync(unsigned thread_id, dag_node *synced) {
    auto &thread_stats = thread_stats_for(thread_id);
    dag_node *result;
    thread_stats.run_on_stack([&] {
      synced->next_node_ = std::make_unique<dag_node>(thread_id);
      result = synced->next_node_.get();
    });
    return result;
  }

  void task_start_running(unsigned thread_id, dag_node *in_node) {
    auto &thread_stats = thread_stats_for(thread_id);
    thread_stats.run_on_stack([&] {
      thread_stats.task_run_start_time_ = clock::now();
    });
  }

  void task_stop_running(unsigned thread_id, dag_node *in_node) {
    auto &thread_stats = thread_stats_for(thread_id);
    thread_stats.run_on_stack([&] {
      auto end_time = clock::now();
      auto user_code_duration =
          std::chrono::duration_cast<resolution>(end_time - thread_stats.task_run_start_time_).count();
      in_node->total_runtime_ += user_code_duration;
    });
  }

  static constexpr char MAGIC_BYTES[] = {'A', 'B', 'A', 'B', 'A', 'B', 'A', 'B'};

  // Fill the task's stack with a magic pattern so its peak usage can be
  // read back as a watermark after the task ran.
  void task_prepare_stack_measure(unsigned thread_id, char *stack_memory, size_t stack_size) {
    auto &thread_stats = thread_stats_for(thread_id);
    thread_stats.run_on_stack([&] {
      for (size_t i = 0; i < stack_size - sizeof(MAGIC_BYTES); i += sizeof(MAGIC_BYTES)) {
        for (size_t j = 0; j < sizeof(MAGIC_BYTES); j++) {
          stack_memory[i + j] = MAGIC_BYTES[j];
        }
      }
    });
  }

  // Scan from the low end of the stack (stacks grow downward): the first
  // overwritten chunk marks the deepest stack usage of this node.
  void task_finish_stack_measure(unsigned thread_id, char *stack_memory, size_t stack_size, dag_node *in_node) {
    auto &thread_stats = thread_stats_for(thread_id);
    thread_stats.run_on_stack([&] {
      for (size_t i = 0; i < stack_size - sizeof(MAGIC_BYTES); i += sizeof(MAGIC_BYTES)) {
        bool section_clean = true;
        for (size_t j = 0; j < sizeof(MAGIC_BYTES); j++) {
          if (stack_memory[i + j] != MAGIC_BYTES[j]) {
            section_clean = false;
          }
        }
        // Conservative estimate, rounded up by one pattern width.
        in_node->max_memory_ = stack_size - i + sizeof(MAGIC_BYTES);
        if (!section_clean) {
          return;
        }
      }
    });
  }

 private:
  profiler_run &current_run() {
    return profiler_runs_[profiler_runs_.size() - 1];
  }

  thread_stats &thread_stats_for(unsigned thread_id) {
    return current_run().per_thread_stats_[thread_id];
  }

  unsigned num_threads_;
  std::vector<profiler_run> profiler_runs_;
};

}
#endif //PLS_INTERNAL_PROFILING_PROFILER_H_
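The stack measurement above is a classic watermark. The same idea in isolation, as a self-contained sketch (the real code additionally rounds the result up by one pattern width and runs on the profiler's side stack):

    #include <cstddef>
    #include <cstring>

    static constexpr char MAGIC[] = {'A', 'B', 'A', 'B', 'A', 'B', 'A', 'B'};

    // Before the task runs: fill the whole stack area with the pattern.
    void prepare(char *mem, size_t size) {
      for (size_t i = 0; i + sizeof(MAGIC) <= size; i += sizeof(MAGIC)) {
        std::memcpy(mem + i, MAGIC, sizeof(MAGIC));
      }
    }

    // After the task ran: the first dirty chunk from the bottom marks the
    // deepest stack growth, so roughly (size - i) bytes were used.
    size_t used_bytes(const char *mem, size_t size) {
      for (size_t i = 0; i + sizeof(MAGIC) <= size; i += sizeof(MAGIC)) {
        if (std::memcmp(mem + i, MAGIC, sizeof(MAGIC)) != 0) {
          return size - i;
        }
      }
      return 0; // pattern fully intact, stack (nearly) unused
    }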
new file: include/pls/internal/profiling/thread_stats.h

#ifndef PLS_INTERNAL_PROFILING_THREAD_STATS_H_
#define PLS_INTERNAL_PROFILING_THREAD_STATS_H_

#include <chrono>
#include <type_traits>

#include "context_switcher/context_switcher.h"

#include "pls/internal/base/system_details.h"
#include "pls/internal/base/stack_allocator.h"

namespace pls::internal::profiling {

struct PLS_CACHE_ALIGN thread_stats {
  static constexpr size_t STACK_SIZE = 4096 * 4;

  thread_stats() {
    stack_ = stack_allocator_.allocate_stack(STACK_SIZE);
  }

  ~thread_stats() {
    stack_allocator_.free_stack(STACK_SIZE, stack_);
  }

  // Execute the given bookkeeping function on this separate, per-thread
  // profiler stack, keeping measurement code off the measured task stacks.
  template<typename Function>
  void run_on_stack(const Function &function) {
    context_switcher::enter_context(stack_, STACK_SIZE, [function](auto cont) {
      function();
      return std::move(cont);
    });
  }

  using clock = std::chrono::steady_clock;

  unsigned long total_steals_{0};
  unsigned long successful_steals_{0};
  unsigned long failed_steals_{0};
  unsigned long steal_cas_ops_{0};
  unsigned long total_time_stealing_{0};

  clock::time_point stealing_start_time_;
  clock::time_point task_run_start_time_;

  base::mmap_stack_allocator stack_allocator_;
  char *stack_;
};

}
#endif //PLS_INTERNAL_PROFILING_THREAD_STATS_H_
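The point of run_on_stack is that profiling bookkeeping executes on this separate mmap'd stack rather than on the (watermarked) task stack it is measuring. Usage is simply:

    // Illustration only: the profiler wraps every counter update like this.
    thread_stats stats;
    stats.run_on_stack([&] {
      stats.total_steals_++; // runs on the profiler stack, not the task stack
    });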
@@ -7,6 +7,9 @@
 #include "context_switcher/continuation.h"
 #include "context_switcher/context_switcher.h"

+#include "pls/internal/profiling/profiler.h"
+#include "pls/internal/profiling/dag_node.h"
+
 namespace pls::internal::scheduling {

 /**
  * A task is the smallest unit of execution seen by the runtime system.
@@ -59,6 +62,10 @@ struct base_task {
   // Linked list for trading/memory management
   base_task *prev_;
   base_task *next_;
+
+#if PLS_PROFILING_ENABLED
+  profiling::dag_node *profiling_node_;
+#endif
 };

 }
...
@@ -42,7 +42,8 @@ class task_manager {
   base_task *pop_local_task();

   // Stealing work, automatically trades in another task
-  std::tuple<base_task *, base_task *> steal_task(thread_state &stealing_state);
+  // Return: stolen_task, traded_task, cas_success
+  std::tuple<base_task *, base_task *, bool> steal_task(thread_state &stealing_state);

   // Sync/memory management
   base_task *pop_clean_task_chain(base_task *task);
...
@@ -13,6 +13,8 @@
 #include "pls/internal/scheduling/thread_state.h"
 #include "pls/internal/scheduling/task_manager.h"

+#include "pls/internal/profiling/profiler.h"
+
 namespace pls::internal::scheduling {

 /**
  * The scheduler is the central part of the dispatching-framework.
@@ -121,6 +123,10 @@ class scheduler {
   bool terminated_;
   std::shared_ptr<base::stack_allocator> stack_allocator_;

+#if PLS_PROFILING_ENABLED
+  profiling::profiler profiler_;
+#endif
 };

 }
...
@@ -11,6 +11,8 @@
 #include "pls/internal/scheduling/base_task.h"
 #include "base_task.h"

+#include "pls/internal/profiling/dag_node.h"
+
 namespace pls::internal::scheduling {

 template<typename ALLOC>
@@ -27,7 +29,11 @@ scheduler::scheduler(unsigned int num_threads,
       main_thread_starter_function_{nullptr},
       work_section_done_{false},
       terminated_{false},
-      stack_allocator_{std::make_shared<ALLOC>(std::forward<ALLOC>(stack_allocator))} {
+      stack_allocator_{std::make_shared<ALLOC>(std::forward<ALLOC>(stack_allocator))}
+#if PLS_PROFILING_ENABLED
+    , profiler_{num_threads}
+#endif
+{
   worker_threads_.reserve(num_threads);
   task_managers_.reserve(num_threads);
@@ -63,15 +69,36 @@ class scheduler::init_function_impl : public init_function {
   explicit init_function_impl(F &function) : function_{function} {}
   void run() override {
     base_task *root_task = thread_state::get().get_active_task();
+
+#if PLS_PROFILING_ENABLED
+    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
+                                                                     root_task->profiling_node_);
+    thread_state::get().get_scheduler().profiler_.task_prepare_stack_measure(thread_state::get().get_thread_id(),
+                                                                             root_task->stack_memory_,
+                                                                             root_task->stack_size_);
+#endif
+
     root_task->run_as_task([root_task, this](auto cont) {
       root_task->is_synchronized_ = true;
       thread_state::get().main_continuation() = std::move(cont);
       function_();
       thread_state::get().get_scheduler().work_section_done_.store(true);
       PLS_ASSERT(thread_state::get().main_continuation().valid(), "Must return valid continuation from main task.");
+
+#if PLS_PROFILING_ENABLED
+      thread_state::get().get_scheduler().profiler_.task_stop_running(thread_state::get().get_thread_id(),
+                                                                      root_task->profiling_node_);
+#endif
+
       return std::move(thread_state::get().main_continuation());
     });
+
+#if PLS_PROFILING_ENABLED
+    thread_state::get().get_scheduler().profiler_.task_finish_stack_measure(thread_state::get().get_thread_id(),
+                                                                            root_task->stack_memory_,
+                                                                            root_task->stack_size_,
+                                                                            root_task->profiling_node_);
+#endif
   }
 private:
   F &function_;
@@ -83,6 +110,11 @@ void scheduler::perform_work(Function work_section) {
   init_function_impl<Function> starter_function{work_section};
   main_thread_starter_function_ = &starter_function;

+#if PLS_PROFILING_ENABLED
+  auto *root_task = thread_state_for(0).get_active_task();
+  auto *root_node = profiler_.start_profiler_run();
+  root_task->profiling_node_ = root_node;
+#endif
   work_section_done_ = false;
   if (reuse_thread_) {
     auto &my_state = thread_state_for(0);
@@ -96,6 +128,9 @@ void scheduler::perform_work(Function work_section) {
     sync_barrier_.wait(); // Trigger threads to wake up
     sync_barrier_.wait(); // Wait for threads to finish
   }
+#if PLS_PROFILING_ENABLED
+  profiler_.stop_profiler_run();
+#endif
 }

 template<typename Function>
@@ -105,6 +140,16 @@ void scheduler::spawn(Function &&lambda) {
   base_task *last_task = spawning_state.get_active_task();
   base_task *spawned_task = last_task->next_;

+#if PLS_PROFILING_ENABLED
+  // Memory and DAG nodes
+  spawning_state.get_scheduler().profiler_.task_prepare_stack_measure(spawning_state.get_thread_id(),
+                                                                      spawned_task->stack_memory_,
+                                                                      spawned_task->stack_size_);
+  auto *child_dag_node = spawning_state.get_scheduler().profiler_.task_spawn_child(spawning_state.get_thread_id(),
+                                                                                   last_task->profiling_node_);
+  spawned_task->profiling_node_ = child_dag_node;
+#endif
+
   auto continuation = spawned_task->run_as_task([last_task, spawned_task, lambda, &spawning_state](auto cont) {
     // allow stealing threads to continue the last task.
     last_task->continuation_ = std::move(cont);
@@ -115,7 +160,14 @@ void scheduler::spawn(Function &&lambda) {
     spawning_state.get_task_manager().push_local_task(last_task);

     // execute the lambda itself, which could lead to a different thread returning.
+#if PLS_PROFILING_ENABLED
+    spawning_state.get_scheduler().profiler_.task_stop_running(spawning_state.get_thread_id(),
+                                                               last_task->profiling_node_);
+    spawning_state.get_scheduler().profiler_.task_start_running(spawning_state.get_thread_id(),
+                                                                spawned_task->profiling_node_);
+#endif
     lambda();

     thread_state &syncing_state = thread_state::get();
     PLS_ASSERT(syncing_state.get_active_task() == spawned_task,
                "Task manager must always point its active task onto whats executing.");
@@ -136,13 +188,27 @@ void scheduler::spawn(Function &&lambda) {
                  "Fast path, no one can have continued working on the last task.");
       syncing_state.set_active_task(last_task);
+
+#if PLS_PROFILING_ENABLED
+      syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
+                                                                spawned_task->profiling_node_);
+#endif
       return std::move(last_task->continuation_);
     } else {
       // Slow path, the last task was stolen. This path is common to sync() events.
-      return slow_return(syncing_state);
+      auto continuation = slow_return(syncing_state);
+#if PLS_PROFILING_ENABLED
+      syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
+                                                                spawned_task->profiling_node_);
+#endif
+      return std::move(continuation);
     }
   });

+#if PLS_PROFILING_ENABLED
+  thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
+                                                                   thread_state::get().get_active_task()->profiling_node_);
+#endif
+
   if (continuation.valid()) {
     // We jumped in here from the main loop, keep track!
     thread_state::get().main_continuation() = std::move(continuation);
...
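Condensing the spawn() changes above, the profiling hooks bracket the user lambda in a fixed order (a summary of the diff, not a new API):

    // spawn(lambda):
    //   task_prepare_stack_measure(spawned_task)  // watermark the child's stack
    //   task_spawn_child(parent_node)             // new child dag_node under the parent
    //   task_stop_running(parent_node)            // parent's user-time slice ends
    //   task_start_running(child_node)            // child's user-time slice begins
    //   lambda()                                  // the measured user code
    //   task_stop_running(child_node)             // on the fast or the slow return path
    //   task_start_running(active_node)           // whichever task continues resumes timing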
@@ -39,7 +39,7 @@ base_task *task_manager::pop_local_task() {
   return deque_.pop_bot();
 }

-std::tuple<base_task *, base_task *> task_manager::steal_task(thread_state &stealing_state) {
+std::tuple<base_task *, base_task *, bool> task_manager::steal_task(thread_state &stealing_state) {
   PLS_ASSERT(stealing_state.get_active_task()->depth_ == 0, "Must only steal with clean task chain.");
   PLS_ASSERT(scheduler::check_task_chain(*stealing_state.get_active_task()), "Must only steal with clean task chain.");
@@ -52,7 +52,7 @@ std::tuple<base_task *, base_task *> task_manager::steal_task(thread_state &stea
     base_task *chain_after_stolen_task = traded_task->next_;

     // perform the actual pop operation
-    task* pop_result_task = deque_.pop_top(traded_task, peek);
+    task *pop_result_task = deque_.pop_top(traded_task, peek);
     if (pop_result_task) {
       PLS_ASSERT(stolen_task->thread_id_ != traded_task->thread_id_,
                  "It is impossible to steal an task we already own!");
@@ -62,7 +62,7 @@ std::tuple<base_task *, base_task *> task_manager::steal_task(thread_state &stea
       // update the resource stack associated with the stolen task
       stolen_task->push_task_chain(traded_task);

-      task* optional_exchanged_task = external_trading_deque::get_trade_object(stolen_task);
+      task *optional_exchanged_task = external_trading_deque::get_trade_object(stolen_task);
       if (optional_exchanged_task) {
         // All good, we pushed the task over to the stack, nothing more to do
         PLS_ASSERT(optional_exchanged_task == traded_task,
@@ -73,12 +73,12 @@ std::tuple<base_task *, base_task *> task_manager::steal_task(thread_state &stea
         stolen_task->reset_task_chain();
       }

-      return std::pair{stolen_task, chain_after_stolen_task};
+      return std::tuple{stolen_task, chain_after_stolen_task, true};
     } else {
-      return std::pair{nullptr, nullptr};
+      return std::tuple{nullptr, nullptr, false};
     }
   } else {
-    return std::pair{nullptr, nullptr};
+    return std::tuple{nullptr, nullptr, true};
   }
 }
@@ -88,7 +88,7 @@ base_task *task_manager::pop_clean_task_chain(base_task *base_task) {
   task *clean_chain = popped_task->pop_task_chain();
   if (clean_chain == nullptr) {
     // double-check if we are really last one or we only have unlucky timing
-    task* optional_cas_task = external_trading_deque::get_trade_object(popped_task);
+    task *optional_cas_task = external_trading_deque::get_trade_object(popped_task);
     if (optional_cas_task) {
       clean_chain = optional_cas_task;
     } else {
...
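The widened return type encodes three distinct outcomes; the scheduler's stealing loop in the next hunk relies on the middle one to retry while counting each CAS attempt:

    // Outcomes of steal_task(), as consumed by work_thread_work_section():
    //   {task,    chain,   true }  - steal succeeded
    //   {nullptr, nullptr, false}  - lost the CAS race; the caller retries
    //                                and records another stealing_cas_op()
    //   {nullptr, nullptr, true }  - victim deque empty; pick another victim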
@@ -48,6 +48,9 @@ void scheduler::work_thread_work_section() {
   unsigned int failed_steals = 0;
   while (!work_section_done_) {
+#if PLS_PROFILING_ENABLED
+    my_state.get_scheduler().profiler_.stealing_start(my_state.get_thread_id());
+#endif
     PLS_ASSERT(check_task_chain(*my_state.get_active_task()), "Must start stealing with a clean task chain.");

     size_t target;
@@ -56,7 +59,18 @@ void scheduler::work_thread_work_section() {
     } while (target == my_state.get_thread_id());

     thread_state &target_state = my_state.get_scheduler().thread_state_for(target);
-    auto[stolen_task, chain_after_stolen_task] = target_state.get_task_manager().steal_task(my_state);
+
+    base_task *stolen_task;
+    base_task *chain_after_stolen_task;
+    bool cas_success;
+    do {
+#if PLS_PROFILING_ENABLED
+      my_state.get_scheduler().profiler_.stealing_cas_op(my_state.get_thread_id());
+#endif
+      std::tie(stolen_task, chain_after_stolen_task, cas_success) =
+          target_state.get_task_manager().steal_task(my_state);
+    } while (!cas_success);
+
     if (stolen_task) {
       // Keep task chain consistent. We want to appear as if we are working an a branch upwards of the stolen task.
       stolen_task->next_ = chain_after_stolen_task;
@@ -70,6 +84,11 @@ void scheduler::work_thread_work_section() {
       PLS_ASSERT(stolen_task->continuation_.valid(),
                  "A task that we can steal must have a valid continuation for us to start working.");
       stolen_task->is_synchronized_ = false;
+#if PLS_PROFILING_ENABLED
+      my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), true);
+      my_state.get_scheduler().profiler_.task_start_running(my_state.get_thread_id(),
+                                                            stolen_task->profiling_node_);
+#endif
       context_switcher::switch_context(std::move(stolen_task->continuation_));
       // We will continue execution in this line when we finished the stolen work.
       failed_steals = 0;
@@ -78,6 +97,9 @@ void scheduler::work_thread_work_section() {
       if (failed_steals >= num_threads) {
         std::this_thread::yield();
       }
+#if PLS_PROFILING_ENABLED
+      my_state.get_scheduler().profiler_.stealing_end(my_state.get_thread_id(), false);
+#endif
     }
   }
 }
@@ -88,7 +110,23 @@ void scheduler::sync() {
   base_task *active_task = syncing_state.get_active_task();
   base_task *spawned_task = active_task->next_;

+#if PLS_PROFILING_ENABLED
+  syncing_state.get_scheduler().profiler_.task_finish_stack_measure(syncing_state.get_thread_id(),
+                                                                    active_task->stack_memory_,
+                                                                    active_task->stack_size_,
+                                                                    active_task->profiling_node_);
+  syncing_state.get_scheduler().profiler_.task_stop_running(syncing_state.get_thread_id(),
+                                                            active_task->profiling_node_);
+  auto *next_dag_node =
+      syncing_state.get_scheduler().profiler_.task_sync(syncing_state.get_thread_id(), active_task->profiling_node_);
+  active_task->profiling_node_ = next_dag_node;
+#endif
+
   if (active_task->is_synchronized_) {
+#if PLS_PROFILING_ENABLED
+    syncing_state.get_scheduler().profiler_.task_start_running(syncing_state.get_thread_id(),
+                                                               active_task->profiling_node_);
+#endif
     return; // We are already the sole owner of last_task
   } else {
     auto continuation =
@@ -101,6 +139,10 @@ void scheduler::sync() {
     PLS_ASSERT(!continuation.valid(),
                "We only return to a sync point, never jump to it directly."
                "This must therefore never return an unfinished fiber/continuation.");
+#if PLS_PROFILING_ENABLED
+    thread_state::get().get_scheduler().profiler_.task_start_running(thread_state::get().get_thread_id(),
+                                                                     thread_state::get().get_active_task()->profiling_node_);
+#endif
     return; // We cleanly synced to the last one finishing work on last_task
   }
 }
@@ -140,7 +182,7 @@ context_switcher::continuation scheduler::slow_return(thread_state &calling_stat
   // Jump back to the continuation in main scheduling loop.
   context_switcher::continuation result_cont = std::move(thread_state::get().main_continuation());
   PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
-  return result_cont;
+  return std::move(result_cont);
 } else {
   // Make sure that we are owner of this full continuation/task chain.
   last_task->next_ = this_task;
@@ -152,7 +194,7 @@ context_switcher::continuation scheduler::slow_return(thread_state &calling_stat
   // Jump to parent task and continue working on it.
   context_switcher::continuation result_cont = std::move(last_task->continuation_);
   PLS_ASSERT(result_cont.valid(), "Must return a valid continuation.");
-  return result_cont;
+  return std::move(result_cont);
 }
...
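For reference, stop_profiler_run() prints one block like the following per run via print_stats() (all times in nanoseconds; the numbers here are invented, assuming four threads):

    ===========================
    WALL TIME: 4210000
    USER TIME: 15830000
    STEALS: (Time 350000, Total: 182, Success: 64, Failed: 118)
    Wall Time vs. Measured: 96.1
    MEMORY: 18432 bytes per stack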