Skip to content

Commit

Permalink
FFT plan optimization
Browse files Browse the repository at this point in the history
- Simplify bitrev table calculation;
- Factorize FFT plan on 2^k and residual (2, 2, 3, 3 -> 4, 3, 3);
- Update benchs;
  • Loading branch information
vitalsong committed Sep 3, 2024
1 parent beef40e commit 65ea339
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 39 deletions.
11 changes: 8 additions & 3 deletions benchs/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <dsplib.h>
#include <vector>

constexpr int MIN_TIME = 5;

#ifdef KISSFFT_SUPPORT

#include "kiss_fft.h"
Expand All @@ -22,13 +24,14 @@ static void BM_KISSFFT(benchmark::State& state) {
BENCHMARK(BM_KISSFFT)
->Arg(1024)
->Arg(1331)
->Arg(512 * 3)
->Arg(2048)
->Arg(4096)
->Arg(8192)
->Arg(11200)
->Arg(11202)
->Arg(16384)
->MinTime(5)
->MinTime(MIN_TIME)
->Unit(benchmark::kMicrosecond);

#endif
Expand Down Expand Up @@ -74,13 +77,14 @@ static void BM_FFTW3_DOUBLE(benchmark::State& state) {
BENCHMARK(BM_FFTW3_DOUBLE)
->Arg(1024)
->Arg(1331)
->Arg(512 * 3)
->Arg(2048)
->Arg(4096)
->Arg(8192)
->Arg(11200)
->Arg(11202)
->Arg(16384)
->MinTime(5)
->MinTime(MIN_TIME)
->Unit(benchmark::kMicrosecond);

#endif
Expand All @@ -98,11 +102,12 @@ static void BM_FFT_DSPLIB(benchmark::State& state) {
BENCHMARK(BM_FFT_DSPLIB)
->Arg(1024)
->Arg(1331)
->Arg(512 * 3)
->Arg(2048)
->Arg(4096)
->Arg(8192)
->Arg(11200)
->Arg(11202)
->Arg(16384)
->MinTime(5)
->MinTime(MIN_TIME)
->Unit(benchmark::kMicrosecond);
55 changes: 43 additions & 12 deletions lib/fft/fact-fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,25 @@
namespace dsplib {

//constructing a factorization plan
//example for n=120, factor is (2, 2, 2, 3, 5) and plan is 120 -> (8) | (15) -> (2 | (2 | 2))) | (3 | 5)
//todo: allocate a separate 2^K plan
//example for n=120, factor is (2, 2, 2, 3, 5) and plan is 120 -> (8) | (15) -> (8) | (3 | 5)
class PlanTree
{
public:
explicit PlanTree(int n)
: _n{n} {
assert(n >= 2);
DSPLIB_ASSERT(n >= 2, "FFT plan size error");

const auto fac = factor(n);
//use Pow2FFT solver
//TODO: move fft(n=2) from PrimeFFT
if ((n > 2) && ispow2(n)) {
_solver = create_fft_plan(n);
return;
}

const auto fac = _factor(n);

//use PrimeFFT solver
if (fac.size() == 1) {
_prime = true;
//it is important to use the cache because there can be several identical FFTs
_solver = create_fft_plan(n);
return;
Expand Down Expand Up @@ -55,15 +62,17 @@ class PlanTree
}

[[nodiscard]] PlanTree* q_plan() const noexcept {
assert(has_next());
return _q;
}

[[nodiscard]] PlanTree* p_plan() const noexcept {
assert(has_next());
return _p;
}

bool is_prime() const noexcept {
return _prime;
bool has_next() const noexcept {
return (_q != nullptr) && (_p != nullptr);
}

[[nodiscard]] std::shared_ptr<BaseFftPlanC> solver() const noexcept {
Expand All @@ -72,8 +81,29 @@ class PlanTree
}

private:
int _n;
bool _prime{false};
//factorization with extract 2^n component, example (2, 2, 2, 3) -> (8, 3)
static std::vector<int> _factor(int n) noexcept {
const int pn = n;
while (n % 2 == 0) {
n /= 2;
}

std::vector<int> fac;
if (n != pn) {
fac.push_back(pn / n);
}

if (n == 1) {
return fac;
}

const auto fc = factor(n);
fac.insert(fac.end(), fc.begin(), fc.end());
std::sort(fac.begin(), fac.end());
return fac;
}

const int _n;
PlanTree* _p{nullptr};
PlanTree* _q{nullptr};
std::shared_ptr<BaseFftPlanC> _solver;
Expand All @@ -93,9 +123,10 @@ void _transpose(cmplx_t* x, cmplx_t* t, int n, int m) noexcept {
void _ctfft(const PlanTree* plan, cmplx_t* x, cmplx_t* mm, const cmplx_t* tw, int ntw) {
const int n = plan->size();

if (plan->is_prime()) {
if (!plan->has_next()) {
//TODO: separate in/out pointer
plan->solver()->solve(x, x, n);
plan->solver()->solve(x, mm, n);
std::memcpy(x, mm, n * sizeof(cmplx_t));
return;
}

Expand Down Expand Up @@ -146,7 +177,7 @@ FactorFFTPlan::FactorFFTPlan(int n)

[[nodiscard]] arr_cmplx FactorFFTPlan::solve(const arr_cmplx& x) const {
DSPLIB_ASSERT(x.size() == _n, "input vector size is not equal fft size");
arr_cmplx r(x);
arr_cmplx r(x); //TODO: remove copy
_ctfft(_plan.get(), r.data(), _px.data(), _twiddle.data(), _n);
return r;
}
Expand Down
36 changes: 13 additions & 23 deletions lib/fft/pow2-fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,21 @@ namespace dsplib {

namespace {

inline int _get_bit(int a, int pos) noexcept {
return (a >> pos) & 0x1;
}

inline void _set_bit(int& a, int pos, int bit) noexcept {
a &= ~(1 << pos);
a |= (bit << pos);
}

inline int _bitrev(int a, int s) noexcept {
int r = 0;
for (int i = 0; i < ((s + 1) / 2); ++i) {
_set_bit(r, (s - i - 1), _get_bit(a, i));
_set_bit(r, i, _get_bit(a, (s - i - 1)));
}
return r;
}

//generate a half of bitrev table
//table is symmetry, table(n/2, n) == table(0, n/2)+1
std::vector<int32_t> _gen_bitrev_table(int n) noexcept {
DSPLIB_ASSUME(n % 4 == 0);
std::vector<int32_t> res(n / 2);
int h = 1;
const int s = nextpow2(n);
for (int i = 0; i < (s - 1); ++i, h *= 2) {
for (int k = 0; k < h; ++k) {
res[k] = 2 * res[k];
res[k + h] = res[k] + 1;
}
}
for (int i = 0; i < n / 2; ++i) {
res[i] = _bitrev(i, s);
res[i] *= 2;
}
return res;
}
Expand Down Expand Up @@ -73,10 +62,9 @@ void _bitreverse(const cmplx_t* restrict x, cmplx_t* restrict y, const int32_t*
DSPLIB_ASSUME(n % 2 == 0);
const int n2 = n / 2;
for (int i = 0; i < n2; ++i) {
const auto kl = bitrev[i];
const auto kr = kl + 1;
y[i] = x[kl];
y[n2 + i] = x[kr];
const auto k = bitrev[i];
y[i] = x[k];
y[n2 + i] = x[k + 1];
}
}

Expand All @@ -98,13 +86,15 @@ arr_cmplx Pow2FftPlan::solve(const arr_cmplx& x) const {
}

void Pow2FftPlan::solve(const cmplx_t* x, cmplx_t* y, int n) const {
DSPLIB_ASSERT(x != y, "Pointers must be restricted");
_fft(x, y, n);
}

int Pow2FftPlan::size() const noexcept {
return n_;
}

//TODO: add "small" implementations (2, 4, 8)
void Pow2FftPlan::_fft(const cmplx_t* restrict in, cmplx_t* restrict out, int n) const noexcept {
DSPLIB_ASSUME(n % 2 == 0);
DSPLIB_ASSUME(n >= 2);
Expand Down
1 change: 1 addition & 0 deletions lib/math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ int nextpow2(int m) {
}

bool ispow2(int m) {
//TODO: optimization
return (int(1) << nextpow2(m)) == m;
}

Expand Down
2 changes: 1 addition & 1 deletion tests/fft_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ TEST(FFT, Ifft) {

//-------------------------------------------------------------------------------------------------
TEST(FFT, Irfft) {
for (auto nfft : {512, 1024, 1000, 200}) {
for (auto nfft : {512 * 3, 1024, 1000, 200}) {
auto x = randn(nfft);
auto y = fft(x);
auto xc = ifft(y);
Expand Down

0 comments on commit 65ea339

Please sign in to comment.