scheduler.cpp 2.14 KB
Newer Older
1
#include "pls/internal/scheduling/scheduler.h"
2
#include "pls/internal/base/error_handling.h"
3 4 5 6

namespace pls {
    namespace internal {
        namespace scheduling {
            /**
             * Creates the scheduler and immediately starts num_threads worker threads.
             *
             * For every worker i, a thread_state (pointing back at this scheduler, at task
             * stack i, and carrying the index i) is written into slot i of the pre-allocated
             * memory. A thread is then started on worker_routine and handed a pointer to that
             * state. The thread handle is stored in memory_->thread_for(i) so terminate() can
             * join it later.
             *
             * sync_barrier_ has num_threads + 1 participants: each worker waits on it in
             * worker_routine, and one additional party (the thread calling terminate(), and
             * presumably whoever drives work on the scheduler) supplies the extra arrival.
             *
             * @param memory      pre-allocated per-thread storage (thread states, task stacks,
             *                    thread handles) used for workers 0 .. num_threads-1.
             * @param num_threads number of worker threads to start; exceeding MAX_THREADS is
             *                    reported through PLS_ERROR.
             */
            scheduler::scheduler(scheduler_memory* memory, const unsigned int num_threads):
                    num_threads_{num_threads},
                    memory_{memory},
                    sync_barrier_{num_threads + 1},
                    terminated_{false} {
                if (num_threads > MAX_THREADS) {
                    // NOTE(review): whether PLS_ERROR aborts is not visible in this file. If it
                    // returns, the loop below still touches slots 0 .. num_threads-1 of memory_,
                    // which may be more than was pre-allocated. Confirm PLS_ERROR is fatal.
                    PLS_ERROR("Tried to create scheduler with more OS threads than pre-allocated memory.");
                }

                for (unsigned int i = 0; i < num_threads; i++) {
                    // The state must be fully written before the thread starts.
                    // worker_routine fetches it first thing, presumably the same object that is
                    // passed to start_thread below.
                    *memory_->thread_state_for(i) = thread_state{this, memory_->task_stack_for(i), i};
                    *memory_->thread_for(i) = base::start_thread(&worker_routine, memory_->thread_state_for(i));
                }
            }

            /**
             * Shuts the scheduler down by calling terminate() with its default arguments.
             */
            scheduler::~scheduler() {
                terminate();
            }

            /**
             * Main loop executed by every worker thread.
             *
             * Each round the worker meets the other barrier participants twice:
             *  1. at the start of the round, after which it either exits (terminated_ set)
             *     or runs its root task;
             *  2. after its root task has returned, marking the end of the round.
             */
            void worker_routine() {
                auto my_state = base::this_thread::state<thread_state>();

                while (true) {
                    my_state->scheduler_->sync_barrier_.wait();
                    if (my_state->scheduler_->terminated_) {
                        return;
                    }

                    // The root task must only return when all work is done,
                    // because of this a simple call is enough to ensure the
                    // fork-join-section is done (logically joined back into our main thread).
                    my_state->root_task_->execute();

                    my_state->scheduler_->sync_barrier_.wait();
                }
            }

            /**
             * Stops all worker threads. Calls after the first one return immediately.
             *
             * terminated_ is set first, then this thread arrives at sync_barrier_. That
             * releases the workers waiting at the start of their loop, and they see the flag
             * and return from worker_routine.
             *
             * NOTE(review): this assumes every worker is idle at the top of its loop. If a
             * round is still in progress, this barrier wait would pair with the workers'
             * end-of-round wait instead.
             * NOTE(review): the terminated_ check and the assignment are separate steps, so
             * concurrent callers are not guarded against.
             *
             * @param wait_for_workers if true, join every worker thread before returning.
             */
            void scheduler::terminate(bool wait_for_workers) {
                if (terminated_) {
                    return;
                }

                // Must be written before the barrier releases the workers, since reading it is
                // the first thing they do afterwards.
                terminated_ = true;
                sync_barrier_.wait();

                if (wait_for_workers) {
                    for (unsigned int i = 0; i < num_threads_; i++) {
                        memory_->thread_for(i)->join();
                    }
                }
            }
        }
    }
}