diff --git a/lib/pls/include/pls/internal/base/spin_lock.h b/lib/pls/include/pls/internal/base/spin_lock.h index 4105e1c..880a521 100644 --- a/lib/pls/include/pls/internal/base/spin_lock.h +++ b/lib/pls/include/pls/internal/base/spin_lock.h @@ -26,6 +26,10 @@ namespace pls { } } } + + void unlock() { + flag_.clear(std::memory_order_release); + } }; } } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 0ca48e9..3532cbe 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,5 +1,5 @@ add_executable(tests main.cpp example_tests.cpp - thread_tests.cpp) -target_link_libraries(tests catch2 pls) \ No newline at end of file + base_tests.cpp) +target_link_libraries(tests catch2 pls) diff --git a/test/base_tests.cpp b/test/base_tests.cpp new file mode 100644 index 0000000..d3c5ba9 --- /dev/null +++ b/test/base_tests.cpp @@ -0,0 +1,59 @@ +#include +#include +#include + +#include + +using namespace pls::internal::base; +using namespace std; + +static bool base_tests_visited; +static int base_tests_local_value_one; +static vector base_tests_local_value_two; + +TEST_CASE( "thread creation and joining", "[internal/base/thread.h]") { + base_tests_visited = false; + auto t1 = start_thread([]() { base_tests_visited = true; }); + t1.join(); + + REQUIRE(base_tests_visited); +} + +TEST_CASE( "thread state", "[internal/base/thread.h]") { + int state_one = 1; + vector state_two{1, 2}; + + auto t1 = start_thread([]() { base_tests_local_value_one = *this_thread::state(); }, &state_one); + auto t2 = start_thread([]() { base_tests_local_value_two = *this_thread::state>(); }, &state_two); + t1.join(); + t2.join(); + + REQUIRE(base_tests_local_value_one == 1); + REQUIRE(base_tests_local_value_two == vector{1, 2}); +} + +TEST_CASE( "spinlock protects concurrent counter", "[internal/base/spinlock.h]") { + constexpr int num_iterations = 1000000; + int shared_counter = 0; + spin_lock lock{}; + + auto t1 = start_thread([&] () { + for (int i = 0; i < num_iterations; i++) { + lock.lock(); + shared_counter++; + lock.unlock(); + } + }); + auto t2 = start_thread([&] () { + for (int i = 0; i < num_iterations; i++) { + lock.lock(); + shared_counter--; + lock.unlock(); + } + }); + + t1.join(); + t2.join(); + + REQUIRE(shared_counter == 0); +} diff --git a/test/thread_tests.cpp b/test/thread_tests.cpp deleted file mode 100644 index 03809fe..0000000 --- a/test/thread_tests.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include -#include - -#include - -using namespace pls::internal::base; -using namespace std; - -static bool visited; -static int local_value_1; -static vector local_value_two; - -TEST_CASE( "thread creation and joining", "[internal/base/thread.h]") { - visited = false; - auto t1 = start_thread([]() { visited = true; }); - t1.join(); - - REQUIRE(visited); -} - -TEST_CASE( "thread state", "[internal/base/thread.h]") { - int state_one = 1; - vector state_two{1, 2}; - - auto t1 = start_thread([]() { local_value_1 = *this_thread::state(); }, &state_one); - auto t2 = start_thread([]() { local_value_two = *this_thread::state>(); }, &state_two); - t1.join(); - t2.join(); - - REQUIRE(local_value_1 == 1); - REQUIRE(local_value_two == vector{1, 2}); -}