diff --git a/lib/pls/include/pls/dataflow/internal/function_node_impl.h b/lib/pls/include/pls/dataflow/internal/function_node_impl.h new file mode 100644 index 0000000..3f2d3b0 --- /dev/null +++ b/lib/pls/include/pls/dataflow/internal/function_node_impl.h @@ -0,0 +1,85 @@ + +#ifndef PLS_DATAFLOW_INTERNAL_FUNCTION_NODE_IMPL_H_ +#define PLS_DATAFLOW_INTERNAL_FUNCTION_NODE_IMPL_H_ + +namespace pls { +namespace dataflow { +namespace internal { + +template +template +void function_node, pls::dataflow::outputs, F>:: +token_pushed(token token) { + auto current_memory = get_invocation(token.invocation()); + + std::get(current_memory->input_buffer_) = token; + auto remaining_inputs = --(current_memory->inputs_missing_); + if (remaining_inputs == 0) { + execute_function(current_memory, token.invocation()); + current_memory->inputs_missing_ = num_in_ports; + } +} + +////////////////////////////////////////////////////////////////// +// Helpers for actually calling the work lambda +////////////////////////////////////////////////////////////////// +template +void function_node, pls::dataflow::outputs, F>:: +execute_function(invocation_memory *invocation_memory, invocation_info invocation_info) { + auto lambda = [&]() { + input_tuple &inputs = invocation_memory->input_buffer_; + output_tuple outputs{}; + + execute_function_internal(inputs, typename sequence_gen<1 + sizeof...(I)>::type(), + outputs, typename sequence_gen<1 + sizeof...(O)>::type(), invocation_info); + }; + // TODO: maybe replace this with 'continuation' style invocation + pls::scheduler::spawn_child_and_wait>(lambda); +} + +template +template +struct function_node, pls::dataflow::outputs, F>:: +propagate_output { + propagate_output(multi_out_port_type &, output_tuple &) {} + void propagate() {} +}; +template +template +struct function_node, pls::dataflow::outputs, F>:: +propagate_output { + multi_out_port_type &out_port_; + output_tuple &output_tuple_; + + propagate_output(multi_out_port_type &out_port, output_tuple &output_tuple) : out_port_{out_port}, + output_tuple_{output_tuple} {} + + void propagate() { + out_port_.template get().push_token(std::get(output_tuple_)); + propagate_output{out_port_, output_tuple_}.propagate(); + } +}; + +template +template +void function_node, pls::dataflow::outputs, F>:: +set_invocation_info(token &token, invocation_info invocation_info) { + token.set_invocation(invocation_info); +} + +template +template +void function_node, pls::dataflow::outputs, F>:: +execute_function_internal(input_tuple &inputs, sequence, + output_tuple &outputs, sequence, + invocation_info invocation_info) { + set_invocation_info(std::get(outputs)..., invocation_info); + function_(std::get(inputs).value()..., std::get(outputs).value()...); + propagate_output<0, O0, O...>{out_port_, outputs}.propagate(); +} + +} +} +} + +#endif //PLS_DATAFLOW_INTERNAL_FUNCTION_NODE_IMPL_H_ diff --git a/lib/pls/include/pls/dataflow/internal/graph_impl.h b/lib/pls/include/pls/dataflow/internal/graph_impl.h new file mode 100644 index 0000000..ccdc04c --- /dev/null +++ b/lib/pls/include/pls/dataflow/internal/graph_impl.h @@ -0,0 +1,131 @@ + +#ifndef PLS_DATAFLOW_INTERNAL_GRAPH_IMPL_H_ +#define PLS_DATAFLOW_INTERNAL_GRAPH_IMPL_H_ + +namespace pls { +namespace dataflow { +namespace internal { + +template +template +void graph, pls::dataflow::outputs>:: +token_pushed(token token) { + auto invocation = get_invocation(token.invocation()); + std::get(*invocation->output_buffer_) = token.value(); +} + +template +void graph, pls::dataflow::outputs>:: +build() { + PLS_ASSERT(build_state_ == build_state::fresh, "Must only build a dataflow graph once!") + PLS_ASSERT(is_fully_connected(), "Must fully connect all inputs/outputs inside a dataflow graph!") + + build_state_ = build_state::building; + node_list_start_ = node_list_current_ = this; + memory_index_ = 0; + num_nodes_ = 1; + for (int i = 0; i < num_in_ports; i++) { + build_recursive(inputs_.next_node_at(i)); + } + build_state_ = build_state::built; +} + +template +void graph, pls::dataflow::outputs>:: +build_recursive(node *node) { + if (node->build_state_ != build_state::fresh) { + return; // Already visited + } + node->build_state_ = build_state::building; + PLS_ASSERT(node->is_fully_connected(), "Must fully connect dataflow graph nodes!") + + add_node(node); + for (int i = 0; i < node->num_successors(); i++) { + build_recursive(node->successor_at(i)); + } + + node->build_state_ = build_state::built; +} + +template +void graph, pls::dataflow::outputs>:: +add_node(node *new_node) { + new_node->memory_index_ = num_nodes_++; + + if (node_list_current_ == nullptr) { + node_list_current_ = new_node; + node_list_start_ = new_node; + } else { + node_list_current_->direct_successor_ = new_node; + node_list_current_ = new_node; + } +} + +template +template +struct graph, pls::dataflow::outputs>:: +feed_inputs { + feed_inputs(inputs_type &, value_input_tuple &, invocation_info &) {} + void run() {} +}; + +template +template +struct graph, pls::dataflow::outputs>:: +feed_inputs { + inputs_type &inputs_; + value_input_tuple &input_values_; + invocation_info &invocation_; + + feed_inputs(inputs_type &inputs, + value_input_tuple &input_values, + invocation_info &invocation) : inputs_{inputs}, + input_values_{input_values}, + invocation_{invocation} {} + + void run() { + inputs_.template get().push_token(token{std::get(input_values_), invocation_}); + feed_inputs{inputs_, input_values_, invocation_}.run(); + } +}; + +template +class graph, pls::dataflow::outputs>::run_graph_task : public pls::task { + graph *self_; + value_input_tuple input_; + value_output_tuple *output_; + + // Buffers for actual execution + invocation_info invocation_; + + public: + run_graph_task(graph *self, value_input_tuple &input, value_output_tuple *output) : self_{self}, + input_{input}, + output_{output}, + invocation_{nullptr} { + void **buffers; + buffers = reinterpret_cast(allocate_memory(self_->num_nodes_ * sizeof(void *))); + + node *iterator = self_->node_list_start_; + for (int i = 0; i < self_->num_nodes_; i++) { + auto required_size = iterator->instance_buffer_size(); + buffers[i] = allocate_memory(required_size); + iterator->init_instance_buffer(buffers[i]); + + iterator = iterator->direct_successor_; + } + invocation_ = invocation_info{buffers}; + + self_->get_invocation(invocation_)->output_buffer_ = output; + } + + void execute_internal() override { + feed_inputs<0, I0, I...>{self_->inputs_, input_, invocation_}.run(); + } +}; + +} +} +} + +#endif //PLS_DATAFLOW_INTERNAL_GRAPH_IMPL_H_