Commit 7c637562 by FritzFlorian

Add tests for dataflow API.

parent 8d8f0ac2
Pipeline #1290 passed with stages
in 4 minutes 4 seconds
...@@ -120,9 +120,6 @@ class function_node<inputs<I0, I...>, outputs<O0, O...>, F> : public node { ...@@ -120,9 +120,6 @@ class function_node<inputs<I0, I...>, outputs<O0, O...>, F> : public node {
template<int N, typename ...OT> template<int N, typename ...OT>
struct propagate_output; struct propagate_output;
template<typename T>
void set_invocation_info(token<T> &token, invocation_info invocation_info);
template<int ...IS, int ...OS> template<int ...IS, int ...OS>
void execute_function_internal(input_tuple &inputs, sequence<IS...>, void execute_function_internal(input_tuple &inputs, sequence<IS...>,
output_tuple &outputs, sequence<OS...>, output_tuple &outputs, sequence<OS...>,
......
...@@ -59,7 +59,7 @@ template<typename I0, typename ...I, typename O0, typename ...O, typename F> ...@@ -59,7 +59,7 @@ template<typename I0, typename ...I, typename O0, typename ...O, typename F>
template<int N, typename ...OT> template<int N, typename ...OT>
struct function_node<inputs<I0, I...>, outputs<O0, O...>, F>:: struct function_node<inputs<I0, I...>, outputs<O0, O...>, F>::
propagate_output { propagate_output {
propagate_output(multi_out_port_type &, output_tuple &) {} propagate_output(multi_out_port_type &, output_tuple &, invocation_info&) {}
void propagate() {} void propagate() {}
}; };
template<typename I0, typename ...I, typename O0, typename ...O, typename F> template<typename I0, typename ...I, typename O0, typename ...O, typename F>
...@@ -68,32 +68,26 @@ struct function_node<inputs<I0, I...>, outputs<O0, O...>, F>:: ...@@ -68,32 +68,26 @@ struct function_node<inputs<I0, I...>, outputs<O0, O...>, F>::
propagate_output<N, OT1, OT...> { propagate_output<N, OT1, OT...> {
multi_out_port_type &out_port_; multi_out_port_type &out_port_;
output_tuple &output_tuple_; output_tuple &output_tuple_;
invocation_info &invocation_info_;
propagate_output(multi_out_port_type &out_port, output_tuple &output_tuple) : out_port_{out_port}, propagate_output(multi_out_port_type &out_port, output_tuple &output_tuple, invocation_info& invocation_info) :
output_tuple_{output_tuple} {} out_port_{out_port}, output_tuple_{output_tuple}, invocation_info_{invocation_info} {}
void propagate() { void propagate() {
std::get<N>(output_tuple_).set_invocation(invocation_info_);
out_port_.template get<N>().push_token(std::get<N>(output_tuple_)); out_port_.template get<N>().push_token(std::get<N>(output_tuple_));
propagate_output<N + 1, OT...>{out_port_, output_tuple_}.propagate(); propagate_output<N + 1, OT...>{out_port_, output_tuple_, invocation_info_}.propagate();
} }
}; };
template<typename I0, typename ...I, typename O0, typename ...O, typename F> template<typename I0, typename ...I, typename O0, typename ...O, typename F>
template<typename T>
void function_node<inputs<I0, I...>, outputs<O0, O...>, F>::
set_invocation_info(token<T> &token, invocation_info invocation_info) {
token.set_invocation(invocation_info);
}
template<typename I0, typename ...I, typename O0, typename ...O, typename F>
template<int ...IS, int ...OS> template<int ...IS, int ...OS>
void function_node<inputs<I0, I...>, outputs<O0, O...>, F>:: void function_node<inputs<I0, I...>, outputs<O0, O...>, F>::
execute_function_internal(input_tuple &inputs, sequence<IS...>, execute_function_internal(input_tuple &inputs, sequence<IS...>,
output_tuple &outputs, sequence<OS...>, output_tuple &outputs, sequence<OS...>,
invocation_info invocation_info) { invocation_info invocation_info) {
set_invocation_info(std::get<OS>(outputs)..., invocation_info);
function_(std::get<IS>(inputs).value()..., std::get<OS>(outputs).value()...); function_(std::get<IS>(inputs).value()..., std::get<OS>(outputs).value()...);
propagate_output<0, O0, O...>{out_port_, outputs}.propagate(); propagate_output<0, O0, O...>{out_port_, outputs, invocation_info}.propagate();
} }
} }
......
...@@ -62,7 +62,7 @@ class graph<inputs<I0, I...>, outputs<O0, O...>> : public node { ...@@ -62,7 +62,7 @@ class graph<inputs<I0, I...>, outputs<O0, O...>> : public node {
return inputs_.template get<POS>(); return inputs_.template get<POS>();
} }
inputs_type& input_ports() { inputs_type &input_ports() {
return inputs_; return inputs_;
} }
...@@ -71,7 +71,7 @@ class graph<inputs<I0, I...>, outputs<O0, O...>> : public node { ...@@ -71,7 +71,7 @@ class graph<inputs<I0, I...>, outputs<O0, O...>> : public node {
return outputs_.template get<POS>(); return outputs_.template get<POS>();
} }
outputs_type& output_ports() { outputs_type &output_ports() {
return outputs_; return outputs_;
} }
...@@ -79,6 +79,10 @@ class graph<inputs<I0, I...>, outputs<O0, O...>> : public node { ...@@ -79,6 +79,10 @@ class graph<inputs<I0, I...>, outputs<O0, O...>> : public node {
function_node<inputs<I0, I...>, outputs<OS...>, FUNC> function_node<inputs<I0, I...>, outputs<OS...>, FUNC>
&operator>>(function_node<inputs<I0, I...>, outputs<OS...>, FUNC> &other_node); &operator>>(function_node<inputs<I0, I...>, outputs<OS...>, FUNC> &other_node);
void wait_for_all() {
pls::scheduler::wait_for_all();
}
template<int POS, typename T> template<int POS, typename T>
void token_pushed(token<T> token); void token_pushed(token<T> token);
......
...@@ -2,5 +2,6 @@ add_executable(tests ...@@ -2,5 +2,6 @@ add_executable(tests
main.cpp main.cpp
data_structures_test.cpp data_structures_test.cpp
scheduling_tests.cpp scheduling_tests.cpp
algorithm_test.cpp) algorithm_test.cpp
dataflow_test.cpp)
target_link_libraries(tests catch2 pls) target_link_libraries(tests catch2 pls)
#include <catch.hpp> #include <catch.hpp>
#include <array> #include <array>
#include <pls/pls.h> #include "pls/pls.h"
using namespace pls; using namespace pls;
......
#include <catch.hpp> #include <catch.hpp>
#include <pls/internal/base/thread.h> #include "pls/internal/base/thread.h"
#include <pls/internal/base/spin_lock.h> #include "pls/internal/base/spin_lock.h"
#include <pls/internal/base/system_details.h> #include "pls/internal/base/system_details.h"
#include <vector> #include <vector>
#include <mutex> #include <mutex>
......
#include <catch.hpp> #include <catch.hpp>
#include <pls/internal/base/system_details.h> #include "pls/internal/base/system_details.h"
#include <pls/internal/data_structures/aligned_stack.h> #include "pls/internal/data_structures/aligned_stack.h"
#include <pls/internal/data_structures/locking_deque.h> #include "pls/internal/data_structures/locking_deque.h"
#include <pls/internal/data_structures/work_stealing_deque.h> #include "pls/internal/data_structures/work_stealing_deque.h"
#include <vector>
#include <mutex> #include <mutex>
using namespace pls::internal::data_structures; using namespace pls::internal::data_structures;
......
#include <catch.hpp>
#include <array>
#include <tuple>
#include "pls/pls.h"
#include "pls/dataflow/dataflow.h"
using namespace pls;
using namespace pls::dataflow;
void step_1(const int &in, int &out) {
out = in * 2;
}
class member_call_test {
public:
void step_2(const int &in, int &out) {
out = in * 2;
}
};
TEST_CASE("dataflow functions correctly", "[dataflow/dataflow.h]") {
malloc_scheduler_memory my_scheduler_memory{8, 2u << 12u};
scheduler my_scheduler{&my_scheduler_memory, 8};
my_scheduler.perform_work([]() {
SECTION("linear pipelines") {
auto step_1 = [](const int &in, double &out1, double &out2) {
out1 = (double) in / 2.0;
out2 = (double) in / 3.0;
};
auto step_2 = [](const double &in1, const double &in2, double &out) {
out = in1 * in2;
};
graph<inputs<int>, outputs<double>> linear_graph;
function_node<inputs<int>, outputs<double, double>, decltype(step_1)> node_1{step_1};
function_node<inputs<double, double>, outputs<double>, decltype(step_2)> node_2{step_2};
linear_graph >> node_1 >> node_2 >> linear_graph;
linear_graph.build();
std::tuple<double> out{};
linear_graph.run(5, out);
linear_graph.wait_for_all();
REQUIRE(std::get<0>(out) == (5 / 2.0) * (5 / 3.0));
}
SECTION("member and function steps") {
member_call_test instance;
using member_func_type = member_function<member_call_test, void, const int &, int &>;
member_func_type func_1{&instance, &member_call_test::step_2};
graph<inputs<int>, outputs<int>> graph;
function_node<inputs<int>, outputs<int>, void (*)(const int &, int &)> node_1{&step_1};
function_node<inputs<int>, outputs<int>, member_func_type> node_2{func_1};
graph >> node_1 >> node_2 >> graph;
graph.build();
std::tuple<int> out{};
graph.run(1, out);
graph.wait_for_all();
REQUIRE(std::get<0>(out) == 4);
}
SECTION("non linear pipeline") {
auto path_one = [](const int &in, int &out) {
out = in + 1;
};
auto path_two = [](const int &in, int &out) {
out = in - 1;
};
graph<inputs<int, bool>, outputs<int>> graph;
function_node<inputs<int>, outputs<int>, decltype(path_one)> node_1{path_one};
function_node<inputs<int>, outputs<int>, decltype(path_two)> node_2{path_two};
switch_node<int> switch_node;
merge_node<int> merge_node;
split_node<bool> split;
// Split up boolean signal
graph.input<1>() >> split.value_in_port();
// Feed switch
graph.input<0>() >> switch_node.value_in_port();
split.out_port_1() >> switch_node.condition_in_port();
// True path
switch_node.true_out_port() >> node_1.in_port<0>();
node_1.out_port<0>() >> merge_node.true_in_port();
// False path
switch_node.false_out_port() >> node_2.in_port<0>();
node_2.out_port<0>() >> merge_node.false_in_port();
// Read Merge
split.out_port_2() >> merge_node.condition_in_port();
merge_node.value_out_port() >> graph.output<0>();
// Build and run
graph.build();
std::tuple<int> out1{}, out2{};
graph.run({0, true}, out1);
graph.run({0, false}, out2);
graph.wait_for_all();
REQUIRE(std::get<0>(out1) == 1);
REQUIRE(std::get<0>(out2) == -1);
}
});
}
#include <catch.hpp> #include <catch.hpp>
#include <pls/pls.h> #include "pls/pls.h"
using namespace pls; using namespace pls;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment