diff --git a/CMakeLists.txt b/CMakeLists.txt index 941c725..0a66e63 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,7 +7,7 @@ set(CMAKE_CXX_STANDARD 11) # seperate library and test/example executable output paths. set(EXECUTABLE_OUTPUT_PATH ${CMAKE_BINARY_DIR}/bin) -set(LIBRARY_OUTPUT_PATH ${CMAKE_BINARY_DIR}/lib) +set(LIBRARY_OUTPUT_PATH ${CMAKE_BINARY_DIR}/lib) # specific setup code is located in individual files. include(cmake/DisabelInSource.cmake) @@ -34,11 +34,12 @@ add_subdirectory(app/playground) add_subdirectory(app/test_for_new) add_subdirectory(app/invoke_parallel) add_subdirectory(app/benchmark_fft) +add_subdirectory(app/benchmark_unbalanced) # Add optional tests option(PACKAGE_TESTS "Build the tests" ON) -if(PACKAGE_TESTS) +if (PACKAGE_TESTS) enable_testing() add_subdirectory(test) add_test(NAME AllTests COMMAND tests) -endif() +endif () diff --git a/app/benchmark_fft/main.cpp b/app/benchmark_fft/main.cpp index 54abdd1..f6ed20e 100644 --- a/app/benchmark_fft/main.cpp +++ b/app/benchmark_fft/main.cpp @@ -8,7 +8,7 @@ static constexpr int CUTOFF = 16; static constexpr int NUM_ITERATIONS = 1000; -static constexpr int INPUT_SIZE = 2064; +static constexpr int INPUT_SIZE = 8192; typedef std::vector> complex_vector; void divide(complex_vector::iterator data, int n) { diff --git a/app/benchmark_unbalanced/CMakeLists.txt b/app/benchmark_unbalanced/CMakeLists.txt new file mode 100644 index 0000000..00c95ab --- /dev/null +++ b/app/benchmark_unbalanced/CMakeLists.txt @@ -0,0 +1,5 @@ +add_executable(benchmark_unbalanced main.cpp node.h node.cpp picosha2.h) +target_link_libraries(benchmark_unbalanced pls) +if (EASY_PROFILER) + target_link_libraries(benchmark_unbalanced easy_profiler) +endif () diff --git a/app/benchmark_unbalanced/LICENSE_PICOSA2 b/app/benchmark_unbalanced/LICENSE_PICOSA2 new file mode 100644 index 0000000..4e22100 --- /dev/null +++ b/app/benchmark_unbalanced/LICENSE_PICOSA2 @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 okdshin + 
+Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/app/benchmark_unbalanced/main.cpp b/app/benchmark_unbalanced/main.cpp new file mode 100644 index 0000000..27cdaf3 --- /dev/null +++ b/app/benchmark_unbalanced/main.cpp @@ -0,0 +1,78 @@ +#include +#include +#include + +#include "node.h" + +const int SEED = 42; +const int ROOT_CHILDREN = 140; +const double Q = 0.124875; +const int NORMAL_CHILDREN = 8; + +const int NUM_NODES = 71069; + +int count_child_nodes(uts::node &node) { + int child_count = 1; + std::vector children = node.spawn_child_nodes(); + + if (children.empty()) { + return child_count; + } + + auto current_task = pls::fork_join_sub_task::current(); + std::vector results(children.size()); + for (size_t i = 0; i < children.size(); i++) { + size_t index = i; + auto lambda = [&, index] { results[index] = count_child_nodes(children[index]); }; + pls::fork_join_lambda_by_value sub_task(lambda); + current_task->spawn_child(sub_task); + } + current_task->wait_for_all(); + for (auto result : results) { + child_count += result; + } + + return child_count; +} + +int unbalanced_tree_search(int seed, int root_children, double q, int normal_children) { + static auto id = pls::unique_id::create(42); + int result; + + auto lambda = [&] { + uts::node root(seed, root_children, q, normal_children); + result = count_child_nodes(root); + }; + pls::fork_join_lambda_by_reference task(lambda); + pls::fork_join_lambda_by_reference sub_task(lambda); + pls::fork_join_task root_task{&sub_task, id}; + pls::scheduler::execute_task(root_task); + + return result; +} + +int main() { + PROFILE_ENABLE + pls::internal::helpers::run_mini_benchmark([&] { + unbalanced_tree_search(SEED, ROOT_CHILDREN, Q, NORMAL_CHILDREN); + }, 8, 4000); + + PROFILE_SAVE("test_profile.prof") +} + +//int main() { +// PROFILE_ENABLE +// pls::malloc_scheduler_memory my_scheduler_memory{8, 2u << 18}; +// pls::scheduler scheduler{&my_scheduler_memory, 8}; +// +// scheduler.perform_work([&] { +// PROFILE_MAIN_THREAD +// for 
(int i = 0; i < 10; i++) { +// PROFILE_WORK_BLOCK("Top Level") +// int result = unbalanced_tree_search(SEED, ROOT_CHILDREN, Q, NORMAL_CHILDREN); +// std::cout << result << std::endl; +// } +// }); +// +// PROFILE_SAVE("test_profile.prof") +//} diff --git a/app/benchmark_unbalanced/node.cpp b/app/benchmark_unbalanced/node.cpp new file mode 100644 index 0000000..1cb931e --- /dev/null +++ b/app/benchmark_unbalanced/node.cpp @@ -0,0 +1,28 @@ +#include "node.h" + +namespace uts { +node_state node::generate_child_state(uint32_t index) { + node_state result; + + picosha2::hash256_one_by_one hasher; + hasher.process(state_.begin(), state_.end()); + auto index_begin = reinterpret_cast(&index); + hasher.process(index_begin, index_begin + 4); + hasher.finish(); + hasher.get_hash_bytes(result.begin(), result.end()); + + return result; +} + +double node::get_state_random() { + int32_t state_random_integer; + uint32_t b = ((uint32_t) state_[16] << 24) | + ((uint32_t) state_[17] << 16) | + ((uint32_t) state_[18] << 8) | + ((uint32_t) state_[19] << 0); + b = b & 0x7fffffff; // Mask out negative values + state_random_integer = static_cast(b); + + return (double) state_random_integer / (double) INT32_MAX; +} +} diff --git a/app/benchmark_unbalanced/node.h b/app/benchmark_unbalanced/node.h new file mode 100644 index 0000000..5111059 --- /dev/null +++ b/app/benchmark_unbalanced/node.h @@ -0,0 +1,73 @@ + +#ifndef UTS_NODE_H +#define UTS_NODE_H + +#include +#include +#include + +#include "picosha2.h" + +namespace uts { +using node_state = std::array; + +/** + * Node of an unbalanced binomial tree (https://www.cs.unc.edu/~olivier/LCPC06.pdf). + * To build up the tree recursively call spawn_child_nodes on each node until leaves are reached. + * The tree is not built up directly in memory, but rather by the recursive calls. + */ +class node { + // The state is used to allow a deterministic tree construction using sha256 hashes. 
+ node_state state_; + + // Set this to a positive number for the root node to start the tree with a specific size + int root_children_; + + // general branching factors + double q_; + int b_; + + // Private constructor for children + node(node_state state, double q, int b) : state_{state}, root_children_{-1}, q_{q}, b_{b} {} + + std::array generate_child_state(uint32_t index); + double get_state_random(); + + public: + node(int seed, int root_children, double q, int b) : state_({{}}), root_children_{root_children}, q_{q}, b_{b} { + for (int i = 0; i < 16; i++) { + state_[i] = 0; + } + state_[16] = static_cast(0xFF & (seed >> 24)); + state_[17] = static_cast(0xFF & (seed >> 16)); + state_[18] = static_cast(0xFF & (seed >> 8)); + state_[19] = static_cast(0xFF & (seed >> 0)); + + picosha2::hash256_one_by_one hasher; + hasher.process(state_.begin(), state_.end()); + hasher.finish(); + hasher.get_hash_bytes(state_.begin(), state_.end()); + } + + std::vector spawn_child_nodes() { + double state_random = get_state_random(); + int num_children; + if (root_children_ > 0) { + num_children = root_children_; // Root always spawns children + } else if (state_random < q_) { + num_children = b_; + } else { + num_children = 0; + } + + std::vector result; + for (int i = 0; i < num_children; i++) { + result.push_back(node(generate_child_state(i), q_, b_)); + } + + return result; + } +}; +} + +#endif //UTS_NODE_H diff --git a/app/benchmark_unbalanced/picosha2.h b/app/benchmark_unbalanced/picosha2.h new file mode 100644 index 0000000..bc00c74 --- /dev/null +++ b/app/benchmark_unbalanced/picosha2.h @@ -0,0 +1,377 @@ +/* +The MIT License (MIT) + +Copyright (C) 2017 okdshin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 
+copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ +#ifndef PICOSHA2_H +#define PICOSHA2_H +// picosha2:20140213 + +#ifndef PICOSHA2_BUFFER_SIZE_FOR_INPUT_ITERATOR +#define PICOSHA2_BUFFER_SIZE_FOR_INPUT_ITERATOR \ + 1048576 //=1024*1024: default is 1MB memory +#endif + +#include +#include +#include +#include +#include +#include +namespace picosha2 { +typedef unsigned long word_t; +typedef unsigned char byte_t; + +static const size_t k_digest_size = 32; + +namespace detail { +inline byte_t mask_8bit(byte_t x) { return x & 0xff; } + +inline word_t mask_32bit(word_t x) { return x & 0xffffffff; } + +const word_t add_constant[64] = { + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, + 0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, + 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, + 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, + 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, + 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, + 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, + 
0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, + 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2}; + +const word_t initial_message_digest[8] = {0x6a09e667, 0xbb67ae85, 0x3c6ef372, + 0xa54ff53a, 0x510e527f, 0x9b05688c, + 0x1f83d9ab, 0x5be0cd19}; + +inline word_t ch(word_t x, word_t y, word_t z) { return (x & y) ^ ((~x) & z); } + +inline word_t maj(word_t x, word_t y, word_t z) { + return (x & y) ^ (x & z) ^ (y & z); +} + +inline word_t rotr(word_t x, std::size_t n) { + assert(n < 32); + return mask_32bit((x >> n) | (x << (32 - n))); +} + +inline word_t bsig0(word_t x) { return rotr(x, 2) ^ rotr(x, 13) ^ rotr(x, 22); } + +inline word_t bsig1(word_t x) { return rotr(x, 6) ^ rotr(x, 11) ^ rotr(x, 25); } + +inline word_t shr(word_t x, std::size_t n) { + assert(n < 32); + return x >> n; +} + +inline word_t ssig0(word_t x) { return rotr(x, 7) ^ rotr(x, 18) ^ shr(x, 3); } + +inline word_t ssig1(word_t x) { return rotr(x, 17) ^ rotr(x, 19) ^ shr(x, 10); } + +template +void hash256_block(RaIter1 message_digest, RaIter2 first, RaIter2 last) { + assert(first + 64 == last); + static_cast(last); // for avoiding unused-variable warning + word_t w[64]; + std::fill(w, w + 64, 0); + for (std::size_t i = 0; i < 16; ++i) { + w[i] = (static_cast(mask_8bit(*(first + i * 4))) << 24) | + (static_cast(mask_8bit(*(first + i * 4 + 1))) << 16) | + (static_cast(mask_8bit(*(first + i * 4 + 2))) << 8) | + (static_cast(mask_8bit(*(first + i * 4 + 3)))); + } + for (std::size_t i = 16; i < 64; ++i) { + w[i] = mask_32bit(ssig1(w[i - 2]) + w[i - 7] + ssig0(w[i - 15]) + + w[i - 16]); + } + + word_t a = *message_digest; + word_t b = *(message_digest + 1); + word_t c = *(message_digest + 2); + word_t d = *(message_digest + 3); + word_t e = *(message_digest + 4); + word_t f = *(message_digest + 5); + word_t g = *(message_digest + 6); + word_t h = *(message_digest + 7); + + for (std::size_t i = 0; i < 64; ++i) { + word_t temp1 = h + bsig1(e) + ch(e, f, g) + add_constant[i] + w[i]; 
+ word_t temp2 = bsig0(a) + maj(a, b, c); + h = g; + g = f; + f = e; + e = mask_32bit(d + temp1); + d = c; + c = b; + b = a; + a = mask_32bit(temp1 + temp2); + } + *message_digest += a; + *(message_digest + 1) += b; + *(message_digest + 2) += c; + *(message_digest + 3) += d; + *(message_digest + 4) += e; + *(message_digest + 5) += f; + *(message_digest + 6) += g; + *(message_digest + 7) += h; + for (std::size_t i = 0; i < 8; ++i) { + *(message_digest + i) = mask_32bit(*(message_digest + i)); + } +} + +} // namespace detail + +template +void output_hex(InIter first, InIter last, std::ostream& os) { + os.setf(std::ios::hex, std::ios::basefield); + while (first != last) { + os.width(2); + os.fill('0'); + os << static_cast(*first); + ++first; + } + os.setf(std::ios::dec, std::ios::basefield); +} + +template +void bytes_to_hex_string(InIter first, InIter last, std::string& hex_str) { + std::ostringstream oss; + output_hex(first, last, oss); + hex_str.assign(oss.str()); +} + +template +void bytes_to_hex_string(const InContainer& bytes, std::string& hex_str) { + bytes_to_hex_string(bytes.begin(), bytes.end(), hex_str); +} + +template +std::string bytes_to_hex_string(InIter first, InIter last) { + std::string hex_str; + bytes_to_hex_string(first, last, hex_str); + return hex_str; +} + +template +std::string bytes_to_hex_string(const InContainer& bytes) { + std::string hex_str; + bytes_to_hex_string(bytes, hex_str); + return hex_str; +} + +class hash256_one_by_one { + public: + hash256_one_by_one() { init(); } + + void init() { + buffer_.clear(); + std::fill(data_length_digits_, data_length_digits_ + 4, 0); + std::copy(detail::initial_message_digest, + detail::initial_message_digest + 8, h_); + } + + template + void process(RaIter first, RaIter last) { + add_to_data_length(static_cast(std::distance(first, last))); + std::copy(first, last, std::back_inserter(buffer_)); + std::size_t i = 0; + for (; i + 64 <= buffer_.size(); i += 64) { + detail::hash256_block(h_, 
buffer_.begin() + i, + buffer_.begin() + i + 64); + } + buffer_.erase(buffer_.begin(), buffer_.begin() + i); + } + + void finish() { + byte_t temp[64]; + std::fill(temp, temp + 64, 0); + std::size_t remains = buffer_.size(); + std::copy(buffer_.begin(), buffer_.end(), temp); + temp[remains] = 0x80; + + if (remains > 55) { + std::fill(temp + remains + 1, temp + 64, 0); + detail::hash256_block(h_, temp, temp + 64); + std::fill(temp, temp + 64 - 4, 0); + } else { + std::fill(temp + remains + 1, temp + 64 - 4, 0); + } + + write_data_bit_length(&(temp[56])); + detail::hash256_block(h_, temp, temp + 64); + } + + template + void get_hash_bytes(OutIter first, OutIter last) const { + for (const word_t* iter = h_; iter != h_ + 8; ++iter) { + for (std::size_t i = 0; i < 4 && first != last; ++i) { + *(first++) = detail::mask_8bit( + static_cast((*iter >> (24 - 8 * i)))); + } + } + } + + private: + void add_to_data_length(word_t n) { + word_t carry = 0; + data_length_digits_[0] += n; + for (std::size_t i = 0; i < 4; ++i) { + data_length_digits_[i] += carry; + if (data_length_digits_[i] >= 65536u) { + carry = data_length_digits_[i] >> 16; + data_length_digits_[i] &= 65535u; + } else { + break; + } + } + } + void write_data_bit_length(byte_t* begin) { + word_t data_bit_length_digits[4]; + std::copy(data_length_digits_, data_length_digits_ + 4, + data_bit_length_digits); + + // convert byte length to bit length (multiply 8 or shift 3 times left) + word_t carry = 0; + for (std::size_t i = 0; i < 4; ++i) { + word_t before_val = data_bit_length_digits[i]; + data_bit_length_digits[i] <<= 3; + data_bit_length_digits[i] |= carry; + data_bit_length_digits[i] &= 65535u; + carry = (before_val >> (16 - 3)) & 65535u; + } + + // write data_bit_length + for (int i = 3; i >= 0; --i) { + (*begin++) = static_cast(data_bit_length_digits[i] >> 8); + (*begin++) = static_cast(data_bit_length_digits[i]); + } + } + std::vector buffer_; + word_t data_length_digits_[4]; // as 64bit integer (16bit x 4 
integer) + word_t h_[8]; +}; + +inline void get_hash_hex_string(const hash256_one_by_one& hasher, + std::string& hex_str) { + byte_t hash[k_digest_size]; + hasher.get_hash_bytes(hash, hash + k_digest_size); + return bytes_to_hex_string(hash, hash + k_digest_size, hex_str); +} + +inline std::string get_hash_hex_string(const hash256_one_by_one& hasher) { + std::string hex_str; + get_hash_hex_string(hasher, hex_str); + return hex_str; +} + +namespace impl { +template +void hash256_impl(RaIter first, RaIter last, OutIter first2, OutIter last2, int, + std::random_access_iterator_tag) { + hash256_one_by_one hasher; + // hasher.init(); + hasher.process(first, last); + hasher.finish(); + hasher.get_hash_bytes(first2, last2); +} + +template +void hash256_impl(InputIter first, InputIter last, OutIter first2, + OutIter last2, int buffer_size, std::input_iterator_tag) { + std::vector buffer(buffer_size); + hash256_one_by_one hasher; + // hasher.init(); + while (first != last) { + int size = buffer_size; + for (int i = 0; i != buffer_size; ++i, ++first) { + if (first == last) { + size = i; + break; + } + buffer[i] = *first; + } + hasher.process(buffer.begin(), buffer.begin() + size); + } + hasher.finish(); + hasher.get_hash_bytes(first2, last2); +} +} + +template +void hash256(InIter first, InIter last, OutIter first2, OutIter last2, + int buffer_size = PICOSHA2_BUFFER_SIZE_FOR_INPUT_ITERATOR) { + picosha2::impl::hash256_impl( + first, last, first2, last2, buffer_size, + typename std::iterator_traits::iterator_category()); +} + +template +void hash256(InIter first, InIter last, OutContainer& dst) { + hash256(first, last, dst.begin(), dst.end()); +} + +template +void hash256(const InContainer& src, OutIter first, OutIter last) { + hash256(src.begin(), src.end(), first, last); +} + +template +void hash256(const InContainer& src, OutContainer& dst) { + hash256(src.begin(), src.end(), dst.begin(), dst.end()); +} + +template +void hash256_hex_string(InIter first, InIter last, 
std::string& hex_str) { + byte_t hashed[k_digest_size]; + hash256(first, last, hashed, hashed + k_digest_size); + std::ostringstream oss; + output_hex(hashed, hashed + k_digest_size, oss); + hex_str.assign(oss.str()); +} + +template +std::string hash256_hex_string(InIter first, InIter last) { + std::string hex_str; + hash256_hex_string(first, last, hex_str); + return hex_str; +} + +inline void hash256_hex_string(const std::string& src, std::string& hex_str) { + hash256_hex_string(src.begin(), src.end(), hex_str); +} + +template +void hash256_hex_string(const InContainer& src, std::string& hex_str) { + hash256_hex_string(src.begin(), src.end(), hex_str); +} + +template +std::string hash256_hex_string(const InContainer& src) { + return hash256_hex_string(src.begin(), src.end()); +} +templatevoid hash256(std::ifstream& f, OutIter first, OutIter last){ + hash256(std::istreambuf_iterator(f), std::istreambuf_iterator(), first,last); + +} +}// namespace picosha2 +#endif // PICOSHA2_H diff --git a/app/invoke_parallel/main.cpp b/app/invoke_parallel/main.cpp index 1d5e32e..4382168 100644 --- a/app/invoke_parallel/main.cpp +++ b/app/invoke_parallel/main.cpp @@ -6,7 +6,7 @@ #include static constexpr int CUTOFF = 16; -static constexpr int INPUT_SIZE = 2064; +static constexpr int INPUT_SIZE = 8192; typedef std::vector> complex_vector; void divide(complex_vector::iterator data, int n) { diff --git a/lib/pls/include/pls/algorithms/invoke_parallel_impl.h b/lib/pls/include/pls/algorithms/invoke_parallel_impl.h index 2fd05ce..c2ab998 100644 --- a/lib/pls/include/pls/algorithms/invoke_parallel_impl.h +++ b/lib/pls/include/pls/algorithms/invoke_parallel_impl.h @@ -2,6 +2,7 @@ #ifndef PLS_INVOKE_PARALLEL_IMPL_H #define PLS_INVOKE_PARALLEL_IMPL_H +#include #include "pls/internal/scheduling/fork_join_task.h" #include "pls/internal/scheduling/scheduler.h" #include "pls/internal/helpers/unique_id.h" @@ -19,8 +20,7 @@ inline void run_body(const Body &internal_body, const abstract_task::id 
&id) { // if not we will spawn it as a new 'fork-join-style' task. auto current_task = scheduler::current_task(); if (current_task->unique_id() == id) { - auto current_sub_task = reinterpret_cast(current_task)->currently_executing(); - internal_body(current_sub_task); + internal_body(); } else { fork_join_lambda_by_reference root_body(&internal_body); fork_join_task root_task{&root_body, id}; @@ -37,8 +37,7 @@ void invoke_parallel(const Function1 &function1, const Function2 &function2) { static abstract_task::id id = unique_id::create(); auto internal_body = [&](fork_join_sub_task *this_task) { - auto sub_task_body_2 = [&](fork_join_sub_task *) { function2(); }; - auto sub_task_2 = fork_join_lambda_by_reference(&sub_task_body_2); + auto sub_task_2 = fork_join_lambda_by_reference(function2); this_task->spawn_child(sub_task_2); function1(); // Execute first function 'inline' without spawning a sub_task object @@ -54,16 +53,15 @@ void invoke_parallel(const Function1 &function1, const Function2 &function2, con using namespace ::pls::internal::helpers; static abstract_task::id id = unique_id::create(); - auto internal_body = [&](fork_join_sub_task *this_task) { - auto sub_task_body_2 = [&](fork_join_sub_task *) { function2(); }; - auto sub_task_2 = fork_join_lambda_by_reference(&sub_task_body_2); - auto sub_task_body_3 = [&](fork_join_sub_task *) { function3(); }; - auto sub_task_3 = fork_join_lambda_by_reference(&sub_task_body_3); + auto internal_body = [&]() { + auto current_task = fork_join_sub_task::current(); + auto sub_task_2 = fork_join_lambda_by_reference(function2); + auto sub_task_3 = fork_join_lambda_by_reference(function3); - this_task->spawn_child(sub_task_2); - this_task->spawn_child(sub_task_3); + current_task->spawn_child(sub_task_2); + current_task->spawn_child(sub_task_3); function1(); // Execute first function 'inline' without spawning a sub_task object - this_task->wait_for_all(); + current_task->wait_for_all(); }; internal::run_body(internal_body, 
id); diff --git a/lib/pls/include/pls/internal/scheduling/fork_join_task.h b/lib/pls/include/pls/internal/scheduling/fork_join_task.h index 51f5c09..33a6c2f 100644 --- a/lib/pls/include/pls/internal/scheduling/fork_join_task.h +++ b/lib/pls/include/pls/internal/scheduling/fork_join_task.h @@ -43,20 +43,21 @@ class fork_join_sub_task { void spawn_child(T &sub_task); void wait_for_all(); + static fork_join_sub_task *current(); private: void execute(); }; template class fork_join_lambda_by_reference : public fork_join_sub_task { - const Function *function_; + const Function &function_; public: - explicit fork_join_lambda_by_reference(const Function *function) : fork_join_sub_task{}, function_{function} {}; + explicit fork_join_lambda_by_reference(const Function &function) : fork_join_sub_task{}, function_{function} {}; protected: void execute_internal() override { - (*function_)(this); + function_(); } }; @@ -69,7 +70,7 @@ class fork_join_lambda_by_value : public fork_join_sub_task { protected: void execute_internal() override { - function_(this); + function_(); } }; diff --git a/lib/pls/include/pls/pls.h b/lib/pls/include/pls/pls.h index b75317b..b7a218b 100644 --- a/lib/pls/include/pls/pls.h +++ b/lib/pls/include/pls/pls.h @@ -18,6 +18,8 @@ using task_id = internal::scheduling::abstract_task::id; using unique_id = internal::helpers::unique_id; using internal::scheduling::fork_join_sub_task; +using internal::scheduling::fork_join_lambda_by_reference; +using internal::scheduling::fork_join_lambda_by_value; using internal::scheduling::fork_join_task; using algorithm::invoke_parallel; diff --git a/lib/pls/src/internal/scheduling/fork_join_task.cpp b/lib/pls/src/internal/scheduling/fork_join_task.cpp index c7c46d2..f6d9435 100644 --- a/lib/pls/src/internal/scheduling/fork_join_task.cpp +++ b/lib/pls/src/internal/scheduling/fork_join_task.cpp @@ -66,6 +66,10 @@ fork_join_sub_task *fork_join_task::get_stolen_sub_task() { return deque_.pop_head(); } +fork_join_sub_task 
*fork_join_sub_task::current() { + return dynamic_cast(scheduler::current_task())->currently_executing(); +} + bool fork_join_task::internal_stealing(abstract_task *other_task) { PROFILE_STEALING("fork_join_task::internal_stealin") auto cast_other_task = reinterpret_cast(other_task);