dataflow_test.cpp 3.33 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
#include <catch.hpp>
#include <array>
#include <tuple>

#include "pls/pls.h"
#include "pls/dataflow/dataflow.h"

using namespace pls;
using namespace pls::dataflow;

// Free-function pipeline step: writes twice the value of `in` into `out`.
// Used (by address) as a function_node callable in the dataflow test.
void step_1(const int &in, int &out) {
  const int doubled = 2 * in;
  out = doubled;
}

// Minimal helper type whose non-static member function is wrapped in a
// `member_function` so the dataflow test can use a member call as a step.
class member_call_test {
 public:
  // Member-function pipeline step: stores twice the value of `in` in `out`.
  // Deliberately non-const so its pointer-to-member type matches
  // `void (member_call_test::*)(const int &, int &)`.
  void step_2(const int &in, int &out) {
    const int doubled = 2 * in;
    out = doubled;
  }
};

// End-to-end test of the pls dataflow API. All sections run inside
// scheduler::perform_work:
//  - "linear pipelines": a chain of two function_nodes with multiple outputs,
//  - "member and function steps": a free-function step followed by a
//    member-function step,
//  - "non linear pipeline": a branching graph built from split, switch and
//    merge nodes, invoked twice before waiting.
TEST_CASE("dataflow functions correctly", "[dataflow/dataflow.h]") {
  // The same worker count is handed to the memory pool and to the scheduler,
  // so it is kept in a single constant.
  // NOTE(review): 2u << 12u presumably sizes per-thread memory — confirm
  // against malloc_scheduler_memory.
  constexpr unsigned int num_threads = 8;
  malloc_scheduler_memory my_scheduler_memory{num_threads, 2u << 12u};
  scheduler my_scheduler{&my_scheduler_memory, num_threads};
  my_scheduler.perform_work([]() {
    SECTION("linear pipelines") {
      // Named distinctly so they do not shadow the free function step_1.
      auto halve_and_third = [](const int &in, double &out1, double &out2) {
        out1 = static_cast<double>(in) / 2.0;
        out2 = static_cast<double>(in) / 3.0;
      };
      auto multiply = [](const double &in1, const double &in2, double &out) {
        out = in1 * in2;
      };

      graph<inputs<int>, outputs<double>> linear_graph;
      function_node<inputs<int>, outputs<double, double>, decltype(halve_and_third)> node_1{halve_and_third};
      function_node<inputs<double, double>, outputs<double>, decltype(multiply)> node_2{multiply};

      linear_graph >> node_1 >> node_2 >> linear_graph;
      linear_graph.build();

      std::tuple<double> out{};
      linear_graph.run(5, out);
      linear_graph.wait_for_all();

      // Floating-point result: compare with a tolerance instead of exact ==.
      REQUIRE(std::get<0>(out) == Approx((5 / 2.0) * (5 / 3.0)));
    }

    SECTION("member and function steps") {
      member_call_test instance;
      using member_func_type = member_function<member_call_test, void, const int &, int &>;
      member_func_type func_1{&instance, &member_call_test::step_2};

      // Free function step_1 doubles (1 -> 2), then member step_2 doubles
      // again (2 -> 4).
      graph<inputs<int>, outputs<int>> mixed_graph;
      function_node<inputs<int>, outputs<int>, void (*)(const int &, int &)> node_1{&step_1};
      function_node<inputs<int>, outputs<int>, member_func_type> node_2{func_1};

      mixed_graph >> node_1 >> node_2 >> mixed_graph;
      mixed_graph.build();

      std::tuple<int> out{};
      mixed_graph.run(1, out);
      mixed_graph.wait_for_all();

      REQUIRE(std::get<0>(out) == 4);
    }

    SECTION("non linear pipeline") {
      auto path_one = [](const int &in, int &out) {
        out = in + 1;
      };
      auto path_two = [](const int &in, int &out) {
        out = in - 1;
      };

      // Local names differ from the graph/switch_node/merge_node templates
      // to avoid shadowing those type names.
      graph<inputs<int, bool>, outputs<int>> branching_graph;
      function_node<inputs<int>, outputs<int>, decltype(path_one)> node_1{path_one};
      function_node<inputs<int>, outputs<int>, decltype(path_two)> node_2{path_two};
      switch_node<int> value_switch;
      merge_node<int> value_merge;
      split_node<bool> condition_split;

      // The boolean condition is consumed twice (by the switch and by the
      // merge), so it is split into two output ports first.
      branching_graph.input<1>() >> condition_split.value_in_port();

      // Feed switch
      branching_graph.input<0>() >> value_switch.value_in_port();
      condition_split.out_port_1() >> value_switch.condition_in_port();

      // True path: +1
      value_switch.true_out_port() >> node_1.in_port<0>();
      node_1.out_port<0>() >> value_merge.true_in_port();
      // False path: -1
      value_switch.false_out_port() >> node_2.in_port<0>();
      node_2.out_port<0>() >> value_merge.false_in_port();

      // Read merge
      condition_split.out_port_2() >> value_merge.condition_in_port();
      value_merge.value_out_port() >> branching_graph.output<0>();

      // Build and run: both invocations are started before waiting, each
      // writing into its own output tuple.
      branching_graph.build();
      std::tuple<int> out1{};
      std::tuple<int> out2{};
      branching_graph.run({0, true}, out1);
      branching_graph.run({0, false}, out2);
      branching_graph.wait_for_all();

      REQUIRE(std::get<0>(out1) == 1);
      REQUIRE(std::get<0>(out2) == -1);
    }
  });
}