Commit d5b66aba by FritzFlorian

Change FFT benchmark to use statically allocated temporary arrays.

parent 89b6e3cb
Pipeline #1406 failed with stages
in 41 seconds
#include "pls/internal/scheduling/scheduler.h" #include "pls/internal/scheduling/scheduler.h"
#include "pls/internal/scheduling/static_scheduler_memory.h" #include "pls/internal/scheduling/static_scheduler_memory.h"
#include "pls/internal/helpers/profiler.h"
using namespace pls::internal::scheduling; using namespace pls::internal::scheduling;
#include <iostream>
#include <complex>
#include <vector>
#include "benchmark_runner.h" #include "benchmark_runner.h"
#include "benchmark_base/fft.h" #include "benchmark_base/fft.h"
using namespace comparison_benchmarks::base; using namespace comparison_benchmarks::base;
void conquer(fft::complex_vector::iterator data, int n) { void pls_conquer(fft::complex_vector::iterator data, fft::complex_vector::iterator swap_array, int n) {
if (n < 2) { if (n < 2) {
return; return;
} }
fft::divide(data, n); fft::divide(data, swap_array, n);
if (n <= fft::RECURSIVE_CUTOFF) { if (n <= fft::RECURSIVE_CUTOFF) {
fft::conquer(data, n / 2); fft::conquer(data, swap_array, n / 2);
fft::conquer(data + n / 2, n / 2); fft::conquer(data + n / 2, swap_array + n / 2, n / 2);
} else { } else {
scheduler::spawn([data, n]() { scheduler::spawn([data, n, swap_array]() {
conquer(data, n / 2); pls_conquer(data, swap_array, n / 2);
}); });
scheduler::spawn([data, n]() { scheduler::spawn([data, n, swap_array]() {
conquer(data + n / 2, n / 2); pls_conquer(data + n / 2, swap_array + n / 2, n / 2);
}); });
scheduler::sync(); scheduler::sync();
} }
...@@ -37,11 +32,7 @@ void conquer(fft::complex_vector::iterator data, int n) { ...@@ -37,11 +32,7 @@ void conquer(fft::complex_vector::iterator data, int n) {
constexpr int MAX_NUM_THREADS = 8; constexpr int MAX_NUM_THREADS = 8;
constexpr int MAX_NUM_TASKS = 32; constexpr int MAX_NUM_TASKS = 32;
constexpr int MAX_STACK_SIZE = 1024 * 32; constexpr int MAX_STACK_SIZE = 1024 * 8;
static_scheduler_memory<MAX_NUM_THREADS,
MAX_NUM_TASKS,
MAX_STACK_SIZE> global_scheduler_memory;
int main(int argc, char **argv) { int main(int argc, char **argv) {
int num_threads; int num_threads;
...@@ -53,12 +44,16 @@ int main(int argc, char **argv) { ...@@ -53,12 +44,16 @@ int main(int argc, char **argv) {
benchmark_runner runner{full_directory, test_name}; benchmark_runner runner{full_directory, test_name};
fft::complex_vector data = fft::generate_input(); fft::complex_vector data = fft::generate_input();
fft::complex_vector swap_array(data.size());
static_scheduler_memory<MAX_NUM_THREADS,
MAX_NUM_TASKS,
MAX_STACK_SIZE> global_scheduler_memory;
scheduler scheduler{global_scheduler_memory, (unsigned) num_threads}; scheduler scheduler{global_scheduler_memory, (unsigned) num_threads};
runner.run_iterations(fft::NUM_ITERATIONS, [&]() { runner.run_iterations(fft::NUM_ITERATIONS, [&]() {
scheduler.perform_work([&]() { scheduler.perform_work([&]() {
conquer(data.begin(), fft::SIZE);; pls_conquer(data.begin(), swap_array.begin(), fft::SIZE);;
}); });
}, fft::NUM_WARMUP_ITERATIONS); }, fft::NUM_WARMUP_ITERATIONS);
runner.commit_results(true); runner.commit_results(true);
......
...@@ -18,8 +18,8 @@ typedef std::vector<std::complex<double>> complex_vector; ...@@ -18,8 +18,8 @@ typedef std::vector<std::complex<double>> complex_vector;
complex_vector generate_input(); complex_vector generate_input();
void divide(complex_vector::iterator data, int n); void divide(complex_vector::iterator data, complex_vector::iterator swap_array, int n);
void conquer(complex_vector::iterator data, int n); void conquer(complex_vector::iterator data, complex_vector::iterator swap_array, int n);
void combine(complex_vector::iterator data, int n); void combine(complex_vector::iterator data, int n);
} }
......
...@@ -19,8 +19,7 @@ complex_vector generate_input() { ...@@ -19,8 +19,7 @@ complex_vector generate_input() {
return data; return data;
} }
void divide(complex_vector::iterator data, int n) { void divide(complex_vector::iterator data, complex_vector::iterator tmp_odd_elements, int n) {
complex_vector tmp_odd_elements(n / 2);
for (int i = 0; i < n / 2; i++) { for (int i = 0; i < n / 2; i++) {
tmp_odd_elements[i] = data[i * 2 + 1]; tmp_odd_elements[i] = data[i * 2 + 1];
} }
...@@ -47,14 +46,14 @@ void combine(complex_vector::iterator data, int n) { ...@@ -47,14 +46,14 @@ void combine(complex_vector::iterator data, int n) {
} }
} }
void conquer(complex_vector::iterator data, int n) { void conquer(complex_vector::iterator data, complex_vector::iterator swap_array, int n) {
if (n < 2) { if (n < 2) {
return; return;
} }
divide(data, n); divide(data, swap_array, n);
conquer(data, n / 2); conquer(data, swap_array, n / 2);
conquer(data + n / 2, n / 2); conquer(data + n / 2, swap_array + n / 2, n / 2);
combine(data, n); combine(data, n);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment