From 7c637562d4b0f6136adb2353580dc9a818ba4ad3 Mon Sep 17 00:00:00 2001 From: FritzFlorian Date: Wed, 31 Jul 2019 12:48:05 +0200 Subject: [PATCH] Add tests for dataflow API. --- lib/pls/include/pls/dataflow/internal/function_node.h | 3 --- lib/pls/include/pls/dataflow/internal/function_node_impl.h | 20 +++++++------------- lib/pls/include/pls/dataflow/internal/graph.h | 8 ++++++-- test/CMakeLists.txt | 3 ++- test/algorithm_test.cpp | 2 +- test/base_tests.cpp | 6 +++--- test/data_structures_test.cpp | 9 ++++----- test/dataflow_test.cpp | 114 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ test/scheduling_tests.cpp | 2 +- 9 files changed, 138 insertions(+), 29 deletions(-) create mode 100644 test/dataflow_test.cpp diff --git a/lib/pls/include/pls/dataflow/internal/function_node.h b/lib/pls/include/pls/dataflow/internal/function_node.h index 8005470..d7490b7 100644 --- a/lib/pls/include/pls/dataflow/internal/function_node.h +++ b/lib/pls/include/pls/dataflow/internal/function_node.h @@ -120,9 +120,6 @@ class function_node, outputs, F> : public node { template struct propagate_output; - template - void set_invocation_info(token &token, invocation_info invocation_info); - template void execute_function_internal(input_tuple &inputs, sequence, output_tuple &outputs, sequence, diff --git a/lib/pls/include/pls/dataflow/internal/function_node_impl.h b/lib/pls/include/pls/dataflow/internal/function_node_impl.h index 36f083b..c5e058e 100644 --- a/lib/pls/include/pls/dataflow/internal/function_node_impl.h +++ b/lib/pls/include/pls/dataflow/internal/function_node_impl.h @@ -59,7 +59,7 @@ template template struct function_node, outputs, F>:: propagate_output { - propagate_output(multi_out_port_type &, output_tuple &) {} + propagate_output(multi_out_port_type &, output_tuple &, invocation_info&) {} void propagate() {} }; template @@ -68,32 +68,26 @@ struct function_node, outputs, F>:: propagate_output { multi_out_port_type &out_port_; output_tuple &output_tuple_; + invocation_info &invocation_info_; - propagate_output(multi_out_port_type &out_port, output_tuple &output_tuple) : out_port_{out_port}, - output_tuple_{output_tuple} {} + propagate_output(multi_out_port_type &out_port, output_tuple &output_tuple, invocation_info& invocation_info) : + out_port_{out_port}, output_tuple_{output_tuple}, invocation_info_{invocation_info} {} void propagate() { + std::get(output_tuple_).set_invocation(invocation_info_); out_port_.template get().push_token(std::get(output_tuple_)); - propagate_output{out_port_, output_tuple_}.propagate(); + propagate_output{out_port_, output_tuple_, invocation_info_}.propagate(); } }; template -template -void function_node, outputs, F>:: -set_invocation_info(token &token, invocation_info invocation_info) { - token.set_invocation(invocation_info); -} - -template template void function_node, outputs, F>:: execute_function_internal(input_tuple &inputs, sequence, output_tuple &outputs, sequence, invocation_info invocation_info) { - set_invocation_info(std::get(outputs)..., invocation_info); function_(std::get(inputs).value()..., std::get(outputs).value()...); - propagate_output<0, O0, O...>{out_port_, outputs}.propagate(); + propagate_output<0, O0, O...>{out_port_, outputs, invocation_info}.propagate(); } } diff --git a/lib/pls/include/pls/dataflow/internal/graph.h b/lib/pls/include/pls/dataflow/internal/graph.h index 75436ad..0004ed3 100644 --- a/lib/pls/include/pls/dataflow/internal/graph.h +++ b/lib/pls/include/pls/dataflow/internal/graph.h @@ -62,7 +62,7 @@ class graph, outputs> : public node { return inputs_.template get(); } - inputs_type& input_ports() { + inputs_type &input_ports() { return inputs_; } @@ -71,7 +71,7 @@ class graph, outputs> : public node { return outputs_.template get(); } - outputs_type& output_ports() { + outputs_type &output_ports() { return outputs_; } @@ -79,6 +79,10 @@ class graph, outputs> : public node { function_node, outputs, FUNC> &operator>>(function_node, outputs, FUNC> &other_node); + void wait_for_all() { + pls::scheduler::wait_for_all(); + } + template void token_pushed(token token); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f9af66a..9023079 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -2,5 +2,6 @@ add_executable(tests main.cpp data_structures_test.cpp scheduling_tests.cpp - algorithm_test.cpp) + algorithm_test.cpp + dataflow_test.cpp) target_link_libraries(tests catch2 pls) diff --git a/test/algorithm_test.cpp b/test/algorithm_test.cpp index 2909a37..beb0cfe 100644 --- a/test/algorithm_test.cpp +++ b/test/algorithm_test.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include "pls/pls.h" using namespace pls; diff --git a/test/base_tests.cpp b/test/base_tests.cpp index 61435d5..ba3789e 100644 --- a/test/base_tests.cpp +++ b/test/base_tests.cpp @@ -1,7 +1,7 @@ #include -#include -#include -#include +#include "pls/internal/base/thread.h" +#include "pls/internal/base/spin_lock.h" +#include "pls/internal/base/system_details.h" #include #include diff --git a/test/data_structures_test.cpp b/test/data_structures_test.cpp index 2327fb9..4117139 100644 --- a/test/data_structures_test.cpp +++ b/test/data_structures_test.cpp @@ -1,12 +1,11 @@ #include -#include +#include "pls/internal/base/system_details.h" -#include -#include -#include +#include "pls/internal/data_structures/aligned_stack.h" +#include "pls/internal/data_structures/locking_deque.h" +#include "pls/internal/data_structures/work_stealing_deque.h" -#include #include using namespace pls::internal::data_structures; diff --git a/test/dataflow_test.cpp b/test/dataflow_test.cpp new file mode 100644 index 0000000..114c93d --- /dev/null +++ b/test/dataflow_test.cpp @@ -0,0 +1,114 @@ +#include +#include +#include + +#include "pls/pls.h" +#include "pls/dataflow/dataflow.h" + +using namespace pls; +using namespace pls::dataflow; + +void step_1(const int &in, int &out) { + out = in * 2; +} + +class member_call_test { + public: + void step_2(const int &in, int &out) { + out = in * 2; + } +}; + +TEST_CASE("dataflow functions correctly", "[dataflow/dataflow.h]") { + malloc_scheduler_memory my_scheduler_memory{8, 2u << 12u}; + scheduler my_scheduler{&my_scheduler_memory, 8}; + my_scheduler.perform_work([]() { + SECTION("linear pipelines") { + auto step_1 = [](const int &in, double &out1, double &out2) { + out1 = (double) in / 2.0; + out2 = (double) in / 3.0; + }; + auto step_2 = [](const double &in1, const double &in2, double &out) { + out = in1 * in2; + }; + + graph, outputs> linear_graph; + function_node, outputs, decltype(step_1)> node_1{step_1}; + function_node, outputs, decltype(step_2)> node_2{step_2}; + + linear_graph >> node_1 >> node_2 >> linear_graph; + linear_graph.build(); + + std::tuple out{}; + linear_graph.run(5, out); + linear_graph.wait_for_all(); + + REQUIRE(std::get<0>(out) == (5 / 2.0) * (5 / 3.0)); + } + + SECTION("member and function steps") { + member_call_test instance; + using member_func_type = member_function; + member_func_type func_1{&instance, &member_call_test::step_2}; + + graph, outputs> graph; + function_node, outputs, void (*)(const int &, int &)> node_1{&step_1}; + function_node, outputs, member_func_type> node_2{func_1}; + + graph >> node_1 >> node_2 >> graph; + graph.build(); + + std::tuple out{}; + graph.run(1, out); + graph.wait_for_all(); + + REQUIRE(std::get<0>(out) == 4); + } + + SECTION("non linear pipeline") { + auto path_one = [](const int &in, int &out) { + out = in + 1; + }; + auto path_two = [](const int &in, int &out) { + out = in - 1; + }; + + graph, outputs> graph; + function_node, outputs, decltype(path_one)> node_1{path_one}; + function_node, outputs, decltype(path_two)> node_2{path_two}; + switch_node switch_node; + merge_node merge_node; + split_node split; + + // Split up boolean signal + graph.input<1>() >> split.value_in_port(); + + // Feed switch + graph.input<0>() >> switch_node.value_in_port(); + split.out_port_1() >> switch_node.condition_in_port(); + + // True path + switch_node.true_out_port() >> node_1.in_port<0>(); + node_1.out_port<0>() >> merge_node.true_in_port(); + // False path + switch_node.false_out_port() >> node_2.in_port<0>(); + node_2.out_port<0>() >> merge_node.false_in_port(); + + // Read Merge + split.out_port_2() >> merge_node.condition_in_port(); + merge_node.value_out_port() >> graph.output<0>(); + + + // Build and run + graph.build(); + std::tuple out1{}, out2{}; + graph.run({0, true}, out1); + graph.run({0, false}, out2); + graph.wait_for_all(); + + REQUIRE(std::get<0>(out1) == 1); + REQUIRE(std::get<0>(out2) == -1); + } + + }); +} diff --git a/test/scheduling_tests.cpp b/test/scheduling_tests.cpp index 1334d6e..2cce39a 100644 --- a/test/scheduling_tests.cpp +++ b/test/scheduling_tests.cpp @@ -1,6 +1,6 @@ #include -#include +#include "pls/pls.h" using namespace pls; -- libgit2 0.26.0