// main.cpp
#include "pls/internal/scheduling/scheduler.h"
#include "pls/internal/scheduling/parallel_result.h"
#include "pls/internal/scheduling/scheduler_memory.h"
using namespace pls::internal::scheduling;

#include <atomic>
#include <chrono>
#include <cmath>
#include <complex>
#include <iostream>
#include <vector>

// Global counter whose value main() prints after each benchmark run.
// Nothing in this file currently increments it: the only increment
// (in fft_normal) is commented out.
std::atomic<unsigned long> count{0};

// fft() stops creating parallel tasks for sub-problems of at most this many
// elements and finishes them with the serial fft_normal() instead.
static constexpr int CUTOFF = 16;

// Number of samples in the benchmark input signal.
static constexpr int INPUT_SIZE = 8192;

// Buffer type that all FFT routines in this file operate on.
using complex_vector = std::vector<std::complex<double>>;

/**
 * Reorders the n elements starting at `data` in place so that the elements
 * from even indices form the first half and the elements from odd indices
 * form the second half, each group keeping its original relative order.
 *
 * @param data Iterator to the first element of the range to reorder.
 * @param n    Number of elements in the range. Expected to be even; for an
 *             odd n the final element is left where it is.
 */
void divide(complex_vector::iterator data, int n) {
  const int half = n / 2;

  // The odd-indexed elements must be saved first: compacting the even-indexed
  // elements into the front of the range overwrites them.
  complex_vector odd_elements(half);
  for (int i = 0; i < half; i++) {
    odd_elements[i] = data[i * 2 + 1];
  }

  for (int i = 0; i < half; i++) {
    data[i] = data[i * 2];
  }

  for (int i = 0; i < half; i++) {
    data[i + half] = odd_elements[i];
  }
}

/**
 * Butterfly step of the FFT: merges the transforms stored in the two halves
 * of the range into one transform of length n, in place.
 *
 * For every i in [0, n/2) the pair (data[i], data[i + n/2]) is replaced by
 * (even + w * odd, even - w * odd) with w = exp(-2 * pi * i / n * j).
 *
 * @param data Iterator to the first element of the range; the first half is
 *             read as the "even" part, the second half as the "odd" part.
 * @param n    Number of elements in the range.
 */
void combine(complex_vector::iterator data, int n) {
  const int half = n / 2;

  for (int i = 0; i < half; i++) {
    const std::complex<double> even = data[i];
    const std::complex<double> odd = data[i + half];

    // w is the "twiddle factor".
    // It could be cached, but the serial and the parallel variant both go
    // through this same routine, so recomputing it does not skew the
    // performance comparison.
    const std::complex<double> w = std::exp(std::complex<double>(0, -2. * M_PI * i / n));
    const std::complex<double> twiddled_odd = w * odd;

    data[i] = even + twiddled_odd;
    data[i + half] = even - twiddled_odd;
  }
}

/**
 * Serial, recursive, in-place FFT over the n elements starting at `data`.
 *
 * Each level splits the range with divide(), transforms both halves
 * recursively and merges them with combine().
 *
 * @param data Iterator to the first element of the range to transform.
 * @param n    Number of elements in the range; main() calls this with a
 *             power of two. Ranges with fewer than two elements are left as is.
 */
void fft_normal(complex_vector::iterator data, int n) {
  if (n < 2) {
    // Disabled instrumentation: uncomment to count how often the recursion
    // bottoms out (main() prints the global counter).
    // count++;
    return;
  }

  const int half = n / 2;
  divide(data, n);
  fft_normal(data, half);
  fft_normal(data + half, half);
  combine(data, n);
}

/**
 * Task-parallel, recursive, in-place FFT over the n elements starting at `data`.
 *
 * Ranges larger than CUTOFF hand their two halves to scheduler::par and merge
 * them in the `.then` continuation. Ranges of at most CUTOFF elements are
 * finished serially with fft_normal().
 *
 * NOTE(review): scheduler::par / .then are assumed to run both callables and
 * invoke the continuation only after both have completed - confirm against
 * the scheduler API.
 *
 * @param data Iterator to the first element of the range to transform.
 * @param n    Number of elements in the range; main() calls this with a
 *             power of two.
 * @return     A parallel_result carrying 0; the value is only a placeholder,
 *             the transform itself is written into the range.
 */
parallel_result<short> fft(complex_vector::iterator data, int n) {
  if (n < 2) {
    return 0;
  }

  divide(data, n);
  if (n <= CUTOFF) {
    fft_normal(data, n / 2);
    fft_normal(data + n / 2, n / 2);
    // The two transformed halves still have to be merged. Without this step
    // the result differs from fft_normal() and the parallel variant does less
    // work than the serial one it is benchmarked against.
    combine(data, n);
    return 0;
  } else {
    return scheduler::par([=]() {
      return fft(data, n / 2);
    }, [=]() {
      return fft(data + n / 2, n / 2);
    }).then([=](int, int) {
      combine(data, n);
      return 0;
    });
  }
}

/**
 * Builds the benchmark input: a real-valued time series that is the sum of
 * sine waves at a fixed set of known frequencies.
 *
 * Applying an FFT to this series should reveal exactly these frequencies.
 *
 * @param input_size Number of samples to generate.
 * @return           Vector of input_size samples; the imaginary part of every
 *                   sample is 0.
 */
complex_vector prepare_input(int input_size) {
  const std::vector<double> known_frequencies{2, 11, 52, 88, 256};
  complex_vector data(input_size);

  for (int i = 0; i < input_size; i++) {
    // Sample i is the sum of all component sine waves at position i.
    double sample = 0.0;
    for (const double frequency : known_frequencies) {
      sample += std::sin(2 * M_PI * frequency * i / input_size);
    }
    data[i] = std::complex<double>(sample, 0.0);
  }

  return data;
}

// Number of timed repetitions of each FFT variant in main().
static constexpr int NUM_ITERATIONS = 1000;
// Thread count handed both to the scheduler memory and to the scheduler.
constexpr size_t NUM_THREADS = 1;

// Remaining template arguments of static_scheduler_memory.
// NOTE(review): meanings are inferred from the names only (task slots and
// per-task stack size, continuation slots and per-continuation size) -
// confirm against scheduler_memory.h, in particular the stack size of 0.
constexpr size_t NUM_TASKS = 64;
constexpr size_t MAX_TASK_STACK_SIZE = 0;

constexpr size_t NUM_CONTS = 64;
constexpr size_t MAX_CONT_SIZE = 192;

int main() {
105
  complex_vector initial_input = prepare_input(INPUT_SIZE);
106

107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
  static_scheduler_memory<NUM_THREADS,
                          NUM_TASKS,
                          MAX_TASK_STACK_SIZE,
                          NUM_CONTS,
                          MAX_CONT_SIZE> static_scheduler_memory;

  scheduler scheduler{static_scheduler_memory, NUM_THREADS};

  count.store(0);
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < NUM_ITERATIONS; i++) {
    complex_vector input_2(initial_input);
    scheduler.perform_work([&]() {
      return scheduler::par([&]() {
        return fft(input_2.begin(), INPUT_SIZE);
      }, []() {
        return parallel_result<int>{0};
      }).then([](int, int) {
        return 0;
      });
    });
  }
  auto end = std::chrono::steady_clock::now();
  std::cout << "Count: " << count.load() << std::endl;
  std::cout << "Framework:  " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()
            << std::endl;

  count.store(0);
  start = std::chrono::steady_clock::now();
  for (int i = 0; i < NUM_ITERATIONS; i++) {
    complex_vector input_1(initial_input);
    fft_normal(input_1.begin(), INPUT_SIZE);
  }
  end = std::chrono::steady_clock::now();
  std::cout << "Count: " << count.load() << std::endl;
  std::cout << "Normal:     " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()
            << std::endl;
144

145
  return 0;
146
}