// scheduling_tests.cpp: tests for fork-join task scheduling.
#include <catch.hpp>

#include <pls/pls.h>

using namespace pls;

7
class once_sub_task: public fork_join_sub_task {
8 9 10 11 12 13 14 15 16 17 18 19 20
    std::atomic<int>* counter_;
    int children_;

protected:
    void execute_internal() override {
        (*counter_)++;
        for (int i = 0; i < children_; i++) {
            spawn_child(once_sub_task(counter_, children_ - 1));
        }
    }

public:
    explicit once_sub_task(std::atomic<int>* counter, int children):
21
        fork_join_sub_task(),
22 23 24 25
        counter_{counter},
        children_{children} {}
};

26
class force_steal_sub_task: public fork_join_sub_task {
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
    std::atomic<int>* parent_counter_;
    std::atomic<int>* overall_counter_;

protected:
    void execute_internal() override {
        (*overall_counter_)--;
        if (overall_counter_->load() > 0) {
            std::atomic<int> counter{1};
            spawn_child(force_steal_sub_task(&counter, overall_counter_));
            while (counter.load() > 0)
                ; // Spin...
        }

        (*parent_counter_)--;
    }

public:
    explicit force_steal_sub_task(std::atomic<int>* parent_counter, std::atomic<int>* overall_counter):
45
            fork_join_sub_task(),
46 47 48 49
            parent_counter_{parent_counter},
            overall_counter_{overall_counter} {}
};

50
TEST_CASE( "tbb task are scheduled correctly", "[internal/scheduling/fork_join_task.h]") {
51 52 53 54 55 56 57 58 59 60
    static static_scheduler_memory<8, 2 << 12> my_scheduler_memory;

    SECTION("tasks are executed exactly once") {
        scheduler my_scheduler{&my_scheduler_memory, 2};
        int start_counter = 4;
        int total_tasks = 1 + 4 + 4 * 3 + 4 * 3 * 2 + 4 * 3 * 2 * 1;
        std::atomic<int> counter{0};

        my_scheduler.perform_work([&] (){
            once_sub_task sub_task{&counter, start_counter};
61
            fork_join_task task{&sub_task, task_id{42}};
62 63 64 65 66 67 68 69 70 71 72 73
            scheduler::execute_task(task);
        });

        REQUIRE(counter.load() == total_tasks);
        my_scheduler.terminate(true);
    }

    SECTION("tasks can be stolen") {
        scheduler my_scheduler{&my_scheduler_memory, 8};
        my_scheduler.perform_work([&] (){
            std::atomic<int> dummy_parent{1}, overall_counter{8};
            force_steal_sub_task sub_task{&dummy_parent, &overall_counter};
74
            fork_join_task task{&sub_task, task_id{42}};
75 76 77 78 79
            scheduler::execute_task(task);
        });
        my_scheduler.terminate(true);
    }
}