Commit db8bd0b9 by FritzFlorian

Fix: change thread impelmentation to point to the correct memory region.

The local storage per thread is only a pointer, this must be assigned AFTER the local storage object has moved to it's final location (e.g. on the stack of the caller). We might even change this a little further in the future to make the location of the state object more transparent.
parent 17aee013
# List all required files here (cmake best practice to NOT automate this step!)
add_library(pls STATIC
src/library.cpp include/pls/library.h
src/internal/base/thread.cpp include/pls/internal/base/thread.h)
include/pls/internal/base/choose_threading.h
src/internal/base/spin_lock.cpp include/pls/internal/base/spin_lock.h
src/internal/base/thread.cpp include/pls/internal/base/thread.h
)
# Settings for our project...
# ...pthreads or C++ 11 threads
......@@ -46,4 +49,4 @@ target_compile_options(pls PRIVATE
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:
-Wall>
$<$<CXX_COMPILER_ID:MSVC>:
-W4>)
\ No newline at end of file
-W4>)
// Make sure exactly ONE threading library is active for
// all of our threading primitives
#ifndef PLS_CHOOSE_THREADING_H
#define PLS_CHOOSE_THREADING_H
#if defined(PLS_USING_PTHREADS) && defined(PLS_USING_CPP_THREADS)
#error "Please activate exactly one threading library (currently both are activated)!"
#endif
#if !defined(PLS_USING_PTHREADS) && !defined(PLS_USING_CPP_THREADS)
#error "Please activate exactly one threading library (currently none are activated)!"
#endif
#endif //PLS_CHOOSE_THREADING_H
#ifndef PLS_SPINLOCK_H
#define PLS_SPINLOCK_H
#include <atomic>
#include "pls/internal/base/thread.h"
namespace pls {
namespace internal {
namespace base {
class spin_lock {
std::atomic_flag flag_;
int yield_at_tries_;
public:
spin_lock(): flag_{ATOMIC_FLAG_INIT}, yield_at_tries_{1024} {};
void lock() {
int tries = 0;
while (flag_.test_and_set(std::memory_order_acquire)) {
tries++;
if (tries % yield_at_tries_ == 0) {
this_thread::yield();
}
}
}
};
}
}
}
#endif //PLS_SPINLOCK_H
......@@ -6,13 +6,16 @@
#ifndef PLS_THREAD_H
#define PLS_THREAD_H
#include <functional>
// platform specific includes
#include "pls/internal/base/choose_threading.h"
#ifdef PLS_USING_PTHREADS
#include <pthread.h>
#include <iostream>
#elif PLS_USING_CPP_THREADS
#include <thread>
#else
#error "Please configure exactly one threading library!"
#endif
namespace pls {
......@@ -20,91 +23,110 @@ namespace pls {
namespace base {
using thread_entrypoint = void();
// Thread local storage support
class this_thread {
template<typename Function, typename State>
friend class thread;
#ifdef PLS_USING_PTHREADS
pthread_key_t thr_id_key;
bool thr_id_key_created = false;
// forward declaration
template<typename T>
void* start_pthread_internal(void*);
static pthread_key_t local_storage_key_;
static bool local_storage_key_initialized_;
#endif
#ifdef PLS_USING_CPP_THREADS
thread_local void* local_thread;
static thread_local void* local_storage_;
#endif
template<typename T>
class thread {
private:
// Handle to the native thread used
public:
static void yield() {
#ifdef PLS_USING_PTHREADS
friend void* start_pthread_internal<T>(void*);
pthread_t pthread_thread_;
thread_entrypoint* entry_function_;
thread(thread_entrypoint entry_function, T local_object):
pthread_thread_(),
entry_function_(entry_function),
local_object_(local_object) {
if (!thr_id_key_created) {
pthread_key_create(&thr_id_key, nullptr);
thr_id_key_created = true;
}
pthread_create(&pthread_thread_, nullptr, start_pthread_internal<T>, (void *)(this));
}
pthread_yield();
#endif
#ifdef PLS_USING_CPP_THREADS
std::thread std_thread_;
std::this_thread::yield();
#endif
}
thread(thread_entrypoint entry_function, T local_object):
local_object_(local_object),
std_thread_([=](){
local_thread = this;
entry_function();
}) {};
template<typename T>
static T* state() {
#ifdef PLS_USING_PTHREADS
return reinterpret_cast<T*>(pthread_getspecific(local_storage_key_));
#endif
#ifdef PLS_USING_CPP_THREADS
reinterpret_cast<T*>(local_storage_);
#endif
public:
T local_object_;
/**
* Creates and starts a thread.
* NOT thread safe, best only use from one main thread managing the runtime!
*
* @param entry_function The entry function to run on the thread
* @param T local_object
*
* @return a handle to the newly created thread
*/
static thread start(thread_entrypoint entry_function, T local_object) {
return thread(entry_function, local_object);
}
void join() {
template<typename T>
static void set_state(const T& state) {
#ifdef PLS_USING_PTHREADS
pthread_join(pthread_thread_, nullptr);
*reinterpret_cast<T*>(pthread_getspecific(local_storage_key_)) = state;
#endif
#ifdef PLS_USING_CPP_THREADS
std_thread_.join();
*reinterpret_cast<T*>(local_storage_) = state;
#endif
}
};
static void yield() {
template<typename Function, typename State>
class thread {
friend class this_thread;
// Keep a copy of the function (lambda) in this object to make sure it is valid when called!
Function function_;
// Keep the local state we hold in here
State state_;
// Keep handle to native implementation
#ifdef PLS_USING_PTHREADS
pthread_yield();
pthread_t pthread_thread_;
#endif
#ifdef PLS_USING_CPP_THREADS
std::this_thread::yield();
std::thread std_thread_;
#endif
#ifdef PLS_USING_PTHREADS
static void* start_pthread_internal(void* thread_pointer) {
auto my_thread = reinterpret_cast<thread*>(thread_pointer);
pthread_setspecific(this_thread::local_storage_key_, (void*)&my_thread->state_);
my_thread->function_();
pthread_exit(nullptr);
}
static thread* get_current() {
#endif
public:
#ifdef PLS_USING_PTHREADS
return (thread*) pthread_getspecific(thr_id_key);
explicit thread(const Function& function, const State& state):
function_{function},
state_{state},
pthread_thread_{} {}
void start() {
if (!this_thread::local_storage_key_initialized_) {
pthread_key_create(&this_thread::local_storage_key_, nullptr);
this_thread::local_storage_key_initialized_ = true;
}
pthread_create(&pthread_thread_, nullptr, start_pthread_internal, (void *)(this));
}
#endif
#ifdef PLS_USING_CPP_THREADS
return (thread*) local_thread;
explicit thread(const Function& function, const State& state):
function_{function},
state_{state},
std_thread_{0} {};
void start() {
std_thread_ = std::thread([=](){
local_storage = reinterpret_cast<void*><&this->state_>;
this->function_();
});
}
#endif
public:
void join() {
#ifdef PLS_USING_PTHREADS
pthread_join(pthread_thread_, nullptr);
#endif
#ifdef PLS_USING_CPP_THREADS
std_thread_.join();
#endif
}
......@@ -116,15 +138,11 @@ namespace pls {
thread& operator=(const thread&) = delete;
};
#ifdef PLS_USING_PTHREADS
template<typename T>
void* start_pthread_internal(void* thread_pointer) {
auto* my_thread = (thread<T>*)thread_pointer;
pthread_setspecific(thr_id_key, thread_pointer);
my_thread->entry_function_();
pthread_exit(nullptr);
template<typename Function, typename State>
thread<Function, State> create_thread(const Function& function, const State& state) {
return thread<Function, State>(function, state);
}
#endif
}
}
......
#include "pls/internal/base/spin_lock.h"
namespace pls {
namespace internal {
namespace base {
// implementation in header (inlining)
}
}
}
#include "pls/internal/base/thread.h"
namespace pls {
namespace internal {
namespace base {
#ifdef PLS_USING_PTHREADS
bool this_thread::local_storage_key_initialized_ = false;
pthread_key_t this_thread::local_storage_key_;
#endif
#ifdef PLS_USING_CPP_THREADS
thread_local void* this_thread::local_storage_;
#endif
// implementation in header (C++ templating)
}
}
}
#include <catch.hpp>
#include <pls/internal/base/thread.h>
#include <string>
#include <vector>
using namespace pls::internal::base;
using namespace std;
......@@ -10,15 +10,18 @@ static bool visited;
TEST_CASE( "thread creation and joining", "[internal/base/thread.h]") {
visited = false;
auto t1 = thread<int>::start([]() { visited = true; }, 0);
auto t1 = create_thread([]() { visited = true; }, 0);
t1.start();
t1.join();
REQUIRE(visited);
}
TEST_CASE( "thread state", "[internal/base/thread.h]") {
auto t1 = thread<int>::start([]() { REQUIRE(thread<int>::get_current()->local_object_ == 1); }, 1);
auto t2 = thread<string>::start([]() { REQUIRE(thread<string>::get_current()->local_object_ == "Hello"); }, "Hello");
auto t1 = create_thread([]() { REQUIRE(*this_thread::state<int>() == 1); }, 1);
auto t2 = create_thread([]() { REQUIRE(*this_thread::state<vector<int>>() == vector<int>{1, 2}); }, vector<int>{1, 2});
t1.start();
t2.start();
t1.join();
t2.join();
}
\ No newline at end of file
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment