// scheduler.cpp
#include "pls/internal/scheduling/scheduler.h"

#include <cstdlib>

namespace pls {
    namespace internal {
        namespace scheduling {
            /**
             * Creates a scheduler backed by the caller-provided `memory` and starts
             * `num_threads` worker threads.
             *
             * The sync barrier is sized for `num_threads + 1` participants: one per
             * worker plus one extra (presumably the thread driving the scheduler —
             * confirm against the header / perform_work).
             *
             * @param memory      storage for per-thread state, task stacks and thread
             *                    handles; must provide at least `num_threads` slots.
             * @param num_threads number of worker threads to start; if it exceeds
             *                    MAX_THREADS the process is terminated with exit code 1.
             */
            scheduler::scheduler(scheduler_memory* memory, const unsigned int num_threads):
                    num_threads_{num_threads},
                    memory_{memory},
                    sync_barrier_{num_threads + 1},
                    terminated_{false} {
                if (num_threads > MAX_THREADS) {
                    std::exit(1); // TODO: Exception Handling
                }

                for (unsigned int i = 0; i < num_threads; i++) {
                    // Publish worker i's state before its thread starts, so the thread
                    // can read it via base::this_thread::state in worker_routine.
                    *memory_->thread_state_for(i) = thread_state{this, memory_->task_stack_for(i), i};
                    *memory_->thread_for(i) = base::start_thread(&worker_routine, memory_->thread_state_for(i));
                }
            }

            /**
             * Shuts the scheduler down by delegating to terminate() with its default
             * argument. The default is declared in the header; NOTE(review): confirm
             * that it waits for (joins) the worker threads.
             */
            scheduler::~scheduler() {
                this->terminate();
            }

            /**
             * Body executed by every worker thread started in the scheduler constructor.
             *
             * Each iteration first meets the scheduler's sync barrier. If the scheduler
             * has been marked terminated at that point, the thread returns. Otherwise it
             * runs its root task and meets the barrier a second time before looping.
             * The per-thread state comes from base::this_thread::state (presumably the
             * pointer passed to base::start_thread — confirm).
             */
            void worker_routine() {
                auto state = base::this_thread::state<thread_state>();

                for (;;) {
                    state->scheduler_->sync_barrier_.wait();
                    // terminate() sets the flag and then joins this barrier once, so a
                    // worker released here can observe the flag and leave.
                    if (state->scheduler_->terminated_) {
                        break;
                    }

                    // The root task must only return when all work is done, so a plain
                    // call is enough to guarantee the fork-join section has finished
                    // (logically joined back into the main thread).
                    state->root_task_->execute();

                    // Second rendezvous, after the root task has returned.
                    state->scheduler_->sync_barrier_.wait();
                }
            }

            /**
             * Stops the worker threads.
             *
             * On the first call this sets the terminated flag and then waits on the sync
             * barrier once. That releases workers blocked at the top of worker_routine's
             * loop, which then see the flag and return. Later calls return immediately
             * because the flag is already set.
             *
             * @param wait_for_workers when true, join every worker thread before returning.
             */
            void scheduler::terminate(bool wait_for_workers) {
                if (!terminated_) {
                    terminated_ = true;
                    sync_barrier_.wait();

                    if (wait_for_workers) {
                        for (unsigned int worker = 0; worker < num_threads_; ++worker) {
                            memory_->thread_for(worker)->join();
                        }
                    }
                }
            }
        }
    }
}