Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fft table storage size optimization #12

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 51 additions & 21 deletions lib/dft-tables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "datacache.h"

#include <dsplib/math.h>
#include <dsplib/throw.h>

#include <cmath>
#include <vector>
Expand All @@ -11,42 +12,71 @@ namespace dsplib {
namespace tables {

//TODO: optionally disable caching
static datacache<size_t, dft_ptr> g_dft_cache;
static datacache<size_t, fft2tb_ptr> g_dft_cache;
static datacache<size_t, bitrev_ptr> g_bitrev_cache;

//-------------------------------------------------------------------------------------------------
static dft_ptr _gen_dft_table(size_t size) {
auto tb = std::make_shared<std::vector<cmplx_t>>(size);
auto data = tb->data();
fft2tb_ptr fft2tb::alloc(size_t n) {
if (n != (1L << nextpow2(n))) {
DSPLIB_THROW("fft size is not power of 2");
}

real_t p;
for (size_t i = 0; i < size; ++i) {
p = i / real_t(size);
data[i].re = std::cos(2 * pi * p);
data[i].im = -std::sin(2 * pi * p);
if (g_dft_cache.cached(n)) {
return g_dft_cache.get(n);
}

return tb;
auto ptr = fft2tb_ptr(new fft2tb(n));
g_dft_cache.update(n, ptr);
return ptr;
}

//-------------------------------------------------------------------------------------------------
const dft_ptr dft_table(size_t size) {
if (g_dft_cache.cached(size)) {
return g_dft_cache.get(size);
}
// Forwards to g_dft_cache.reset(n) for the table keyed by size n.
// NOTE(review): presumably this evicts the cached table for n; confirm against datacache.
void fft2tb::reset(size_t n) {
g_dft_cache.reset(n);
}

g_dft_cache.update(size, _gen_dft_table(size));
return g_dft_cache.get(size);
//-------------------------------------------------------------------------------------------------
// Reports whether g_dft_cache currently holds an entry keyed by size n.
bool fft2tb::is_cached(size_t n) {
    const bool present = g_dft_cache.cached(n);
    return present;
}

//-------------------------------------------------------------------------------------------------
void dft_clear(size_t size) {
g_dft_cache.reset(size);
// Expands the stored quarter-period cosine table into the full _n-point table
//   r[k] = exp(-1i * 2 * pi * k / n) = cos(2*pi*k/n) - 1i * sin(2*pi*k/n).
//
// Returns a newly built array of _n complex values.
arr_cmplx fft2tb::unpack() const noexcept {
    arr_cmplx r(_n);

    // Sizes below 4 have an empty quarter table (_n4 == 0). The constructor only
    // assert()s n >= 4, so with assertions disabled such a table can reach this
    // point. The general expansion below is not valid for it: for _n == 1 the
    // bound `_n2 - 1` wraps around (unsigned) and the mirror loop writes out of
    // bounds; for _n == 2 the real part would come out as all zeros.
    // Build these tiny tables directly: exp(0) = 1 and exp(-1i * pi) = -1.
    if (_n < 4) {
        for (size_t i = 0; i < _n; i++) {
            r[i].re = (i == 0) ? 1 : -1;
            r[i].im = 0;
        }
        return r;
    }

    //real
    // first quarter [0, n/4): taken straight from the stored table
    for (size_t i = 0; i < _n4; i++) {
        r[i].re = _cos_tb[i];
    }
    // cos(pi/2) == 0
    r[_n4].re = 0;
    // second quarter (n/4, n/2]: cos(pi - x) == -cos(x)
    for (size_t i = 0; i < _n4; i++) {
        r[_n4 + 1 + i].re = -_cos_tb[_n4 - i - 1];
    }
    // second half (n/2, n): cos(2*pi - x) == cos(x), mirrored around n/2
    for (size_t i = 0; i < _n2 - 1; i++) {
        r[_n2 + 1 + i].re = r[_n2 - i - 1].re;
    }

    //imag
    // -sin(x) == cos(x + pi/2): read the real part a quarter period ahead;
    // n is a power of two, so `& (n - 1)` performs the wrap-around modulo n
    const uint32_t ns = (_n - 1);
    for (size_t i = 0; i < _n; i++) {
        r[i].im = r[(i + _n4) & ns].re;
    }

    return r;
}

//-------------------------------------------------------------------------------------------------
bool dft_cached(size_t size) {
return g_dft_cache.cached(size);
// Builds the compressed table for size n: only the first quarter period of
// cos(2 * pi * i / n), i in [0, n/4), is stored.
// Preconditions (checked with assert only): n >= 4 and n is a power of two.
fft2tb::fft2tb(uint32_t n) noexcept
  : _n{n}
  , _n2{n / 2}
  , _n4{n / 4}
  , _cos_tb(_n4) {
    assert(n >= 4);
    assert(n == (1L << nextpow2(n)));

    const auto two_pi = 2 * pi;
    const real_t step = 1 / real_t(_n);

    // same evaluation order as (2 * pi) * index * step
    size_t k = 0;
    for (auto& c : _cos_tb) {
        c = std::cos(two_pi * k * step);
        ++k;
    }
}

//-------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -89,7 +119,7 @@ static bitrev_ptr _gen_bitrev_table(size_t size) {
}

//-------------------------------------------------------------------------------------------------
const bitrev_ptr bitrev_table(size_t size) {
bitrev_ptr bitrev_table(size_t size) {
if (g_bitrev_cache.cached(size)) {
return g_bitrev_cache.get(size);
}
Expand Down
43 changes: 22 additions & 21 deletions lib/dft-tables.h
Original file line number Diff line number Diff line change
@@ -1,38 +1,39 @@
#pragma once

#include <dsplib/types.h>
#include <dsplib/array.h>

#include <vector>
#include <memory>
#include <stdint.h>
#include <cstdint>

namespace dsplib {
namespace tables {

using dft_ptr = std::shared_ptr<std::vector<cmplx_t>>;
class fft2tb;
using fft2tb_ptr = std::shared_ptr<fft2tb>;

//wrapper for table exp(-1i * 2 * pi * i / n) compressed to 1/4
class fft2tb
{
public:
static fft2tb_ptr alloc(size_t n);
static void reset(size_t n);
static bool is_cached(size_t n);

/*!
* \brief Get (or generate) a table for calculating DFT
* \param n DFT base
* \return Table pointer
*/
const dft_ptr dft_table(size_t n);
arr_cmplx unpack() const noexcept;

/*!
* \brief Clear table from cache
* \param n DFT base
*/
void dft_clear(size_t n);
private:
explicit fft2tb(uint32_t n) noexcept;

/*!
* \brief Check if table cached
* \param n DFT base
* \return Cached
*/
bool dft_cached(size_t n);
const uint32_t _n;
const uint32_t _n2;
const uint32_t _n4;
std::vector<real_t> _cos_tb;
};

//bit-reverse table
using bitrev_ptr = std::shared_ptr<std::vector<int32_t>>;
const bitrev_ptr bitrev_table(size_t n);
bitrev_ptr bitrev_table(size_t n);
bool bitrev_cached(size_t n);
void bitrev_clear(size_t n);

Expand Down
6 changes: 3 additions & 3 deletions lib/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ class fft_plan_impl
const int n2 = 1L << nextpow2(n);
if (n == n2) {
//n == 2^K
auto brev = tables::bitrev_table(n);
auto coeff = tables::dft_table(n);
const auto brev = tables::bitrev_table(n);
const auto coeff = tables::fft2tb::alloc(n)->unpack();
solve = [brev, coeff, n](const arr_cmplx& x) {
arr_cmplx r = x;
_fft2(r.data(), coeff->data(), brev->data(), n);
_fft2(r.data(), coeff.data(), brev->data(), n);
return r;
};
} else {
Expand Down
2 changes: 1 addition & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ CPMAddPackage(NAME googletest

file(GLOB_RECURSE SOURCES "*.cpp" "*.h")
add_executable(${PROJECT_NAME} ${SOURCES})
target_include_directories(${PROJECT_NAME} PUBLIC ${CMAKE_CURRENT_LIST_DIR})
target_include_directories(${PROJECT_NAME} PUBLIC ${CMAKE_CURRENT_LIST_DIR} "${CMAKE_SOURCE_DIR}/lib")
target_link_libraries(${PROJECT_NAME} PUBLIC dsplib gtest)
28 changes: 21 additions & 7 deletions tests/fft_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include "tests_common.h"

#include <dft-tables.h>

//-------------------------------------------------------------------------------------------------
TEST(MathTest, FftReal) {
TEST(FFT, FftReal) {
using namespace dsplib;
int idx = 10;
int nfft = 512;
Expand All @@ -15,7 +17,7 @@ TEST(MathTest, FftReal) {
}

//-------------------------------------------------------------------------------------------------
TEST(MathTest, FftCmplx) {
TEST(FFT, FftCmplx) {
using namespace dsplib;
int idx = 10;
int nfft = 512;
Expand All @@ -28,7 +30,7 @@ TEST(MathTest, FftCmplx) {
}

//-------------------------------------------------------------------------------------------------
TEST(MathTest, Ifft) {
TEST(FFT, Ifft) {
using namespace dsplib;

{
Expand All @@ -49,7 +51,7 @@ TEST(MathTest, Ifft) {
}

//-------------------------------------------------------------------------------------------------
TEST(MathTest, Czt) {
TEST(FFT, Czt) {
using namespace dsplib;
arr_cmplx dft_ref = {6.00000000000000 + 0.00000000000000i, -1.50000000000000 + 0.866025403784439i,
-1.50000000000000 - 0.866025403784439i};
Expand All @@ -61,7 +63,7 @@ TEST(MathTest, Czt) {
}

//-------------------------------------------------------------------------------------------------
TEST(MathTest, CztICzt) {
TEST(FFT, CztICzt) {
using namespace dsplib;
for (size_t i = 0; i < 1000; i++) {
int n = randi({16, 2000});
Expand All @@ -73,7 +75,7 @@ TEST(MathTest, CztICzt) {
}

//-------------------------------------------------------------------------------------------------
TEST(MathTest, CztFft2) {
TEST(FFT, CztFft2) {
using namespace dsplib;
for (size_t i = 0; i < 1000; i++) {
int n = randi({16, 2000});
Expand All @@ -87,7 +89,7 @@ TEST(MathTest, CztFft2) {
}

//-------------------------------------------------------------------------------------------------
TEST(MathTest, CztIFft2) {
TEST(FFT, CztIFft2) {
using namespace dsplib;
for (size_t i = 0; i < 1000; i++) {
int n = randi({16, 2000});
Expand All @@ -99,3 +101,15 @@ TEST(MathTest, CztIFft2) {
ASSERT_EQ_ARR_CMPLX(y1, y2);
}
}

//-------------------------------------------------------------------------------------------------
// Checks that the unpacked quarter-wave table matches the directly computed
// exp(-1i * 2 * pi * k / nfft) for every power-of-two size from 4 to 8192.
TEST(FFT, Fft2Table) {
    using namespace dsplib;
    for (int nfft = 4; nfft <= 8192; nfft *= 2) {
        const auto table = tables::fft2tb::alloc(nfft);
        auto x1 = table->unpack();
        auto x2 = expj(-2 * dsplib::pi * range(nfft) / nfft);
        ASSERT_EQ_ARR_CMPLX(x1, x2);
    }
}