# Patch summary (commentary only; patch tools skip text that precedes the first "diff --git" header).
#
# benchs/fft.cpp
#   * Adds a file-level constant MIN_TIME = 5 and passes it to every ->MinTime(...) call
#     in place of the literal 5.
#   * Adds a 512 * 3 point case to the BM_KISSFFT, BM_FFTW3_DOUBLE and BM_FFT_DSPLIB benchmarks.
# lib/fft/fact-fft.cpp
#   * PlanTree constructor: the size check becomes DSPLIB_ASSERT; any n > 2 for which ispow2(n)
#     holds gets its solver immediately and returns without building child plans.
#   * Other sizes are factorised by the new private _factor(), which divides out every factor
#     of two, stores their product as one entry, appends factor() of the remainder and sorts.
#   * is_prime()/_prime are replaced by has_next(), which is true only when both _p and _q are
#     non-null; p_plan() and q_plan() now assert has_next().
#   * _ctfft: a plan without children solves from x into the scratch buffer mm and copies the
#     result back to x with std::memcpy, instead of calling solve(x, x, n).
# lib/fft/pow2-fft.cpp
#   * Removes _get_bit/_set_bit/_bitrev; _gen_bitrev_table now fills the half table by repeated
#     doubling (res[k] = 2 * res[k]; res[k + h] = res[k] + 1) followed by a final multiply by 2.
#   * _bitreverse reads x[k] and x[k + 1] from a single table entry k.
#   * Pow2FftPlan::solve(const cmplx_t*, cmplx_t*, int) asserts x != y.
# lib/math.cpp
#   * Only adds a TODO comment inside ispow2.
# tests/fft_test.cpp
#   * TEST(FFT, Irfft) iterates over 512 * 3 in place of 512.
#
# NOTE(review): this text was recovered from a copy whose line breaks and indentation had been
#   collapsed. Line breaks, indentation and blank context lines below were reconstructed so that
#   every hunk matches its @@ line counts; verify against the repository before applying.
# NOTE(review): names that normally carry angle-bracket arguments appear without them (bare
#   "#include", "std::vector", "std::shared_ptr", "create_fft_plan(n)"); presumably the bracketed
#   text was stripped by the same extraction - restore it from the repository.
# NOTE(review): the _factor comment gives the example (2, 2, 2, 3) -> (8, 3), but the function
#   sorts its result, so if factor(3) yields (3) the returned order would be (3, 8) - confirm.
# NOTE(review): std::memcpy and std::sort are introduced in fact-fft.cpp but no #include change is
#   visible in this patch; confirm <cstring> and <algorithm> are already included there.
# NOTE(review): ispow2 still relies on nextpow2; its behaviour for m <= 0 is not visible here.
#
diff --git a/benchs/fft.cpp b/benchs/fft.cpp
index 398a3b0..05b9802 100644
--- a/benchs/fft.cpp
+++ b/benchs/fft.cpp
@@ -3,6 +3,8 @@
 #include
 #include
 
+constexpr int MIN_TIME = 5;
+
 #ifdef KISSFFT_SUPPORT
 
 #include "kiss_fft.h"
@@ -22,13 +24,14 @@ static void BM_KISSFFT(benchmark::State& state) {
 BENCHMARK(BM_KISSFFT)
     ->Arg(1024)
     ->Arg(1331)
+    ->Arg(512 * 3)
     ->Arg(2048)
     ->Arg(4096)
     ->Arg(8192)
     ->Arg(11200)
     ->Arg(11202)
     ->Arg(16384)
-    ->MinTime(5)
+    ->MinTime(MIN_TIME)
     ->Unit(benchmark::kMicrosecond);
 
 #endif
@@ -74,13 +77,14 @@ static void BM_FFTW3_DOUBLE(benchmark::State& state) {
 BENCHMARK(BM_FFTW3_DOUBLE)
     ->Arg(1024)
     ->Arg(1331)
+    ->Arg(512 * 3)
     ->Arg(2048)
     ->Arg(4096)
     ->Arg(8192)
     ->Arg(11200)
     ->Arg(11202)
     ->Arg(16384)
-    ->MinTime(5)
+    ->MinTime(MIN_TIME)
     ->Unit(benchmark::kMicrosecond);
 
 #endif
@@ -98,11 +102,12 @@ static void BM_FFT_DSPLIB(benchmark::State& state) {
 BENCHMARK(BM_FFT_DSPLIB)
     ->Arg(1024)
     ->Arg(1331)
+    ->Arg(512 * 3)
     ->Arg(2048)
     ->Arg(4096)
     ->Arg(8192)
     ->Arg(11200)
     ->Arg(11202)
     ->Arg(16384)
-    ->MinTime(5)
+    ->MinTime(MIN_TIME)
     ->Unit(benchmark::kMicrosecond);
diff --git a/lib/fft/fact-fft.cpp b/lib/fft/fact-fft.cpp
index 6e938e2..2288e41 100644
--- a/lib/fft/fact-fft.cpp
+++ b/lib/fft/fact-fft.cpp
@@ -10,18 +10,25 @@
 namespace dsplib {
 
 //constructing a factorization plan
-//example for n=120, factor is (2, 2, 2, 3, 5) and plan is 120 -> (8) | (15) -> (2 | (2 | 2))) | (3 | 5)
-//todo: allocate a separate 2^K plan
+//example for n=120, factor is (2, 2, 2, 3, 5) and plan is 120 -> (8) | (15) -> (8) | (3 | 5)
 class PlanTree
 {
 public:
     explicit PlanTree(int n)
       : _n{n} {
-        assert(n >= 2);
+        DSPLIB_ASSERT(n >= 2, "FFT plan size error");
 
-        const auto fac = factor(n);
+        //use Pow2FFT solver
+        //TODO: move fft(n=2) from PrimeFFT
+        if ((n > 2) && ispow2(n)) {
+            _solver = create_fft_plan(n);
+            return;
+        }
+
+        const auto fac = _factor(n);
+
+        //use PrimeFFT solver
         if (fac.size() == 1) {
-            _prime = true;
             //it is important to use the cache because there can be several identical FFTs
             _solver = create_fft_plan(n);
             return;
@@ -55,15 +62,17 @@ class PlanTree
     }
 
     [[nodiscard]] PlanTree* q_plan() const noexcept {
+        assert(has_next());
         return _q;
     }
 
     [[nodiscard]] PlanTree* p_plan() const noexcept {
+        assert(has_next());
         return _p;
     }
 
-    bool is_prime() const noexcept {
-        return _prime;
+    bool has_next() const noexcept {
+        return (_q != nullptr) && (_p != nullptr);
     }
 
     [[nodiscard]] std::shared_ptr solver() const noexcept {
@@ -72,8 +81,29 @@ class PlanTree
     }
 
 private:
-    int _n;
-    bool _prime{false};
+    //factorization with extract 2^n component, example (2, 2, 2, 3) -> (8, 3)
+    static std::vector _factor(int n) noexcept {
+        const int pn = n;
+        while (n % 2 == 0) {
+            n /= 2;
+        }
+
+        std::vector fac;
+        if (n != pn) {
+            fac.push_back(pn / n);
+        }
+
+        if (n == 1) {
+            return fac;
+        }
+
+        const auto fc = factor(n);
+        fac.insert(fac.end(), fc.begin(), fc.end());
+        std::sort(fac.begin(), fac.end());
+        return fac;
+    }
+
+    const int _n;
     PlanTree* _p{nullptr};
     PlanTree* _q{nullptr};
     std::shared_ptr _solver;
@@ -93,9 +123,10 @@ void _transpose(cmplx_t* x, cmplx_t* t, int n, int m) noexcept {
 void _ctfft(const PlanTree* plan, cmplx_t* x, cmplx_t* mm, const cmplx_t* tw, int ntw) {
     const int n = plan->size();
 
-    if (plan->is_prime()) {
+    if (!plan->has_next()) {
         //TODO: separate in/out pointer
-        plan->solver()->solve(x, x, n);
+        plan->solver()->solve(x, mm, n);
+        std::memcpy(x, mm, n * sizeof(cmplx_t));
         return;
     }
 
@@ -146,7 +177,7 @@ FactorFFTPlan::FactorFFTPlan(int n)
 
 [[nodiscard]] arr_cmplx FactorFFTPlan::solve(const arr_cmplx& x) const {
     DSPLIB_ASSERT(x.size() == _n, "input vector size is not equal fft size");
-    arr_cmplx r(x);
+    arr_cmplx r(x); //TODO: remove copy
     _ctfft(_plan.get(), r.data(), _px.data(), _twiddle.data(), _n);
     return r;
 }
diff --git a/lib/fft/pow2-fft.cpp b/lib/fft/pow2-fft.cpp
index f933dc7..373e743 100644
--- a/lib/fft/pow2-fft.cpp
+++ b/lib/fft/pow2-fft.cpp
@@ -12,32 +12,21 @@ namespace dsplib {
 
 namespace {
 
-inline int _get_bit(int a, int pos) noexcept {
-    return (a >> pos) & 0x1;
-}
-
-inline void _set_bit(int& a, int pos, int bit) noexcept {
-    a &= ~(1 << pos);
-    a |= (bit << pos);
-}
-
-inline int _bitrev(int a, int s) noexcept {
-    int r = 0;
-    for (int i = 0; i < ((s + 1) / 2); ++i) {
-        _set_bit(r, (s - i - 1), _get_bit(a, i));
-        _set_bit(r, i, _get_bit(a, (s - i - 1)));
-    }
-    return r;
-}
-
 //generate a half of bitrev table
 //table is symmetry, table(n/2, n) == table(0, n/2)+1
 std::vector _gen_bitrev_table(int n) noexcept {
     DSPLIB_ASSUME(n % 4 == 0);
     std::vector res(n / 2);
+    int h = 1;
     const int s = nextpow2(n);
+    for (int i = 0; i < (s - 1); ++i, h *= 2) {
+        for (int k = 0; k < h; ++k) {
+            res[k] = 2 * res[k];
+            res[k + h] = res[k] + 1;
+        }
+    }
     for (int i = 0; i < n / 2; ++i) {
-        res[i] = _bitrev(i, s);
+        res[i] *= 2;
     }
     return res;
 }
@@ -73,10 +62,9 @@ void _bitreverse(const cmplx_t* restrict x, cmplx_t* restrict y, const int32_t*
     DSPLIB_ASSUME(n % 2 == 0);
     const int n2 = n / 2;
     for (int i = 0; i < n2; ++i) {
-        const auto kl = bitrev[i];
-        const auto kr = kl + 1;
-        y[i] = x[kl];
-        y[n2 + i] = x[kr];
+        const auto k = bitrev[i];
+        y[i] = x[k];
+        y[n2 + i] = x[k + 1];
     }
 }
 
@@ -98,6 +86,7 @@ arr_cmplx Pow2FftPlan::solve(const arr_cmplx& x) const {
 }
 
 void Pow2FftPlan::solve(const cmplx_t* x, cmplx_t* y, int n) const {
+    DSPLIB_ASSERT(x != y, "Pointers must be restricted");
     _fft(x, y, n);
 }
 
@@ -105,6 +94,7 @@ int Pow2FftPlan::size() const noexcept {
     return n_;
 }
 
+//TODO: add "small" implementations (2, 4, 8)
 void Pow2FftPlan::_fft(const cmplx_t* restrict in, cmplx_t* restrict out, int n) const noexcept {
     DSPLIB_ASSUME(n % 2 == 0);
     DSPLIB_ASSUME(n >= 2);
diff --git a/lib/math.cpp b/lib/math.cpp
index 43b8b3c..0a53897 100644
--- a/lib/math.cpp
+++ b/lib/math.cpp
@@ -311,6 +311,7 @@ int nextpow2(int m) {
 }
 
 bool ispow2(int m) {
+    //TODO: optimization
     return (int(1) << nextpow2(m)) == m;
 }
 
diff --git a/tests/fft_test.cpp b/tests/fft_test.cpp
index 9acf147..5fb3e11 100644
--- a/tests/fft_test.cpp
+++ b/tests/fft_test.cpp
@@ -138,7 +138,7 @@ TEST(FFT, Ifft) {
 
 //-------------------------------------------------------------------------------------------------
 TEST(FFT, Irfft) {
-    for (auto nfft : {512, 1024, 1000, 200}) {
+    for (auto nfft : {512 * 3, 1024, 1000, 200}) {
         auto x = randn(nfft);
         auto y = fft(x);
         auto xc = ifft(y);