Commit 6ce5e247 by FritzFlorian

Split graph and function_node impls into separate files.

parent 7239d842
Pipeline #1284 failed with stages
in 25 seconds
......@@ -49,7 +49,7 @@ add_library(pls STATIC
include/pls/internal/scheduling/scheduler_impl.h
include/pls/internal/scheduling/task.h src/internal/scheduling/task.cpp
include/pls/internal/scheduling/scheduler_memory.h src/internal/scheduling/scheduler_memory.cpp
include/pls/internal/scheduling/lambda_task.h include/pls/internal/helpers/seqence.h include/pls/dataflow/internal/build_state.h)
include/pls/internal/scheduling/lambda_task.h include/pls/internal/helpers/seqence.h include/pls/dataflow/internal/build_state.h include/pls/dataflow/internal/function_node_impl.h include/pls/dataflow/internal/graph_impl.h)
# Add everything in `./include` to be in the include path of this project
target_include_directories(pls
PUBLIC
......
......@@ -48,6 +48,8 @@ class function_node<pls::dataflow::inputs<I0, I...>, pls::dataflow::outputs<O0,
};
public:
explicit function_node(F function) : in_port_{this, this}, out_port_{}, function_{function} {}
template<int POS>
using in_port_at = typename multi_in_port_type::template in_port_type_at<POS>;
template<int POS>
......@@ -64,67 +66,8 @@ class function_node<pls::dataflow::inputs<I0, I...>, pls::dataflow::outputs<O0,
}
template<int POS, typename T>
void token_pushed(token<T> token) {
auto current_memory = get_invocation<invocation_memory>(token.invocation());
std::get<POS>(current_memory->input_buffer_) = token;
auto remaining_inputs = --(current_memory->inputs_missing_);
if (remaining_inputs == 0) {
execute_function(current_memory, token.invocation());
current_memory->inputs_missing_ = num_in_ports;
}
}
void execute_function(invocation_memory *invocation_memory, invocation_info invocation_info) {
auto lambda = [&]() {
input_tuple &inputs = invocation_memory->input_buffer_;
output_tuple outputs;
execute_function_internal(inputs, typename sequence_gen<1 + sizeof...(I)>::type(),
outputs, typename sequence_gen<1 + sizeof...(O)>::type(), invocation_info);
};
// TODO: maybe replace this with 'continuation' style invocation
pls::scheduler::spawn_child_and_wait<pls::lambda_task_by_reference<decltype(lambda)>>(lambda);
}
template<typename T>
void set_invocation_info(token<T> &token, invocation_info invocation_info) {
token.set_invocation(invocation_info);
}
template<int N, typename ...OT>
struct propagate_output {
propagate_output(multi_out_port_type &, output_tuple &) {}
void propagate() {}
};
template<int N, typename OT1, typename ...OT>
struct propagate_output<N, OT1, OT...> {
multi_out_port_type &out_port_;
output_tuple &output_tuple_;
propagate_output(multi_out_port_type &out_port, output_tuple &output_tuple) : out_port_{out_port},
output_tuple_{output_tuple} {}
void propagate() {
out_port_.template get<N>().push_token(std::get<N>(output_tuple_));
propagate_output<N + 1, OT...>{out_port_, output_tuple_}.propagate();
}
};
template<int ...IS, int ...OS>
void execute_function_internal(input_tuple &inputs,
sequence<IS...>,
output_tuple &outputs,
sequence<OS...>,
invocation_info invocation_info) {
set_invocation_info(std::get<OS>(outputs)..., invocation_info);
function_(std::get<IS>(inputs).value()..., std::get<OS>(outputs).value()...);
propagate_output<0, O0, O...>{out_port_, outputs}.propagate();
}
void token_pushed(token<T> token);
explicit function_node(F function) : in_port_{this, this}, out_port_{}, function_{function} {}
//////////////////////////////////////////////////////////////////
// Overrides for generic node functionality (building the graph)
//////////////////////////////////////////////////////////////////
int num_successors() const override {
return num_out_ports;
}
......@@ -149,10 +92,27 @@ class function_node<pls::dataflow::inputs<I0, I...>, pls::dataflow::outputs<O0,
multi_out_port_type out_port_;
F function_;
//////////////////////////////////////////////////////////////////
// Helpers for actually calling the work lambda
//////////////////////////////////////////////////////////////////
void execute_function(invocation_memory *invocation_memory, invocation_info invocation_info);
template<int N, typename ...OT>
struct propagate_output;
template<typename T>
void set_invocation_info(token<T> &token, invocation_info invocation_info);
template<int ...IS, int ...OS>
void execute_function_internal(input_tuple &inputs, sequence<IS...>,
output_tuple &outputs, sequence<OS...>,
invocation_info invocation_info);
};
}
}
}
#include "function_node_impl.h"
#endif //PLS_DATAFLOW_INTERNAL_NODE_H_
......@@ -35,10 +35,10 @@ class graph<pls::dataflow::inputs<I0, I...>, pls::dataflow::outputs<O0, O...>> :
// Input-Output value tuples
using value_input_tuple = std::tuple<I0, I...>;
using input_tuple = std::tuple<token < I0>, token<I>...>;
using input_tuple = std::tuple<token<I0>, token<I>...>;
using value_output_tuple = std::tuple<O0, O...>;
using output_tuple = std::tuple<token < O0>, token<O>...>;
using output_tuple = std::tuple<token<O0>, token<O>...>;
static constexpr int num_in_ports = std::tuple_size<input_tuple>();
static constexpr int num_out_ports = std::tuple_size<output_tuple>();
......@@ -64,121 +64,16 @@ class graph<pls::dataflow::inputs<I0, I...>, pls::dataflow::outputs<O0, O...>> :
}
template<int POS, typename T>
void token_pushed(token<T> token) {
auto invocation = get_invocation<invocation_memory>(token.invocation());
std::get<POS>(*invocation->output_buffer_) = token.value();
}
void token_pushed(token<T> token);
graph() : inputs_{}, outputs_{this, this} {}
/////////////////////////////////////////////////////////////////
// Graph building
/////////////////////////////////////////////////////////////////
void build() {
PLS_ASSERT(build_state_ == build_state::fresh, "Must only build a dataflow graph once!")
PLS_ASSERT(is_fully_connected(), "Must fully connect all inputs/outputs inside a dataflow graph!")
build_state_ = build_state::building;
node_list_start_ = node_list_current_ = this;
memory_index_ = 0;
num_nodes_ = 1;
for (int i = 0; i < num_in_ports; i++) {
build_recursive(inputs_.next_node_at(i));
}
build_state_ = build_state::built;
}
void build_recursive(node *node) {
if (node->build_state_ != build_state::fresh) {
return; // Already visited
}
node->build_state_ = build_state::building;
PLS_ASSERT(node->is_fully_connected(), "Must fully connect dataflow graph nodes!")
add_node(node);
for (int i = 0; i < node->num_successors(); i++) {
build_recursive(node->successor_at(i));
}
node->build_state_ = build_state::built;
}
void add_node(node *new_node) {
new_node->memory_index_ = num_nodes_++;
if (node_list_current_ == nullptr) {
node_list_current_ = new_node;
node_list_start_ = new_node;
} else {
node_list_current_->direct_successor_ = new_node;
node_list_current_ = new_node;
}
}
template<int N, typename ...IT>
struct feed_inputs {
feed_inputs(inputs_type &, value_input_tuple &, invocation_info &) {}
void run() {}
};
template<int N, typename IT1, typename ...IT>
struct feed_inputs<N, IT1, IT...> {
inputs_type &inputs_;
value_input_tuple &input_values_;
invocation_info &invocation_;
feed_inputs(inputs_type &inputs,
value_input_tuple &input_values,
invocation_info &invocation) : inputs_{inputs},
input_values_{input_values},
invocation_{invocation} {}
void run() {
inputs_.template get<N>().push_token(token<IT1>{std::get<N>(input_values_), invocation_});
feed_inputs<N + 1, IT...>{inputs_, input_values_, invocation_}.run();
}
};
class run_graph_task : public pls::task {
graph *self_;
value_input_tuple input_;
value_output_tuple *output_;
// Buffers for actual execution
invocation_info invocation_;
public:
run_graph_task(graph *self, value_input_tuple &input, value_output_tuple *output) : self_{self},
input_{input},
output_{output},
invocation_{nullptr} {
void **buffers;
buffers = reinterpret_cast<void **>(allocate_memory(self_->num_nodes_ * sizeof(void *)));
node *iterator = self_->node_list_start_;
for (int i = 0; i < self_->num_nodes_; i++) {
auto required_size = iterator->instance_buffer_size();
buffers[i] = allocate_memory(required_size);
iterator->init_instance_buffer(buffers[i]);
iterator = iterator->direct_successor_;
}
invocation_ = invocation_info{buffers};
self_->get_invocation<invocation_memory>(invocation_)->output_buffer_ = output;
}
void execute_internal() override {
feed_inputs<0, I0, I...>{self_->inputs_, input_, invocation_}.run();
}
};
void build();
void run(value_input_tuple input, value_output_tuple &output) {
pls::scheduler::spawn_child<run_graph_task>(this, input, &output);
}
//////////////////////////////////////////////////////////////////
// Overrides for generic node functionality (building the graph)
//////////////////////////////////////////////////////////////////
int num_successors() const override {
return 0;
}
......@@ -205,10 +100,19 @@ class graph<pls::dataflow::inputs<I0, I...>, pls::dataflow::outputs<O0, O...>> :
node *node_list_start_{nullptr};
node *node_list_current_{nullptr};
int num_nodes_{0};
// Internals required for building and running
void build_recursive(node *node);
void add_node(node *new_node);
template<int N, typename ...IT>
struct feed_inputs;
class run_graph_task;
};
}
}
}
#include "graph_impl.h"
#endif //PLS_DATAFLOW_INTERNAL_GRAPH_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment