From 7491b28f8689fb05d9a8e49894cdbe65f490e110 Mon Sep 17 00:00:00 2001 From: Alejandro Estringana Ruiz Date: Wed, 17 Jul 2024 14:58:13 +0200 Subject: [PATCH] Improve sampler (#2761) --- appsec/src/helper/client.cpp | 14 ++----- appsec/src/helper/sampler.hpp | 58 ++++------------------------ appsec/tests/helper/sampler_test.cpp | 31 ++------------- appsec/tests/helper/service_test.cpp | 12 +++--- 4 files changed, 21 insertions(+), 94 deletions(-) diff --git a/appsec/src/helper/client.cpp b/appsec/src/helper/client.cpp index 92d13ea854..bb6497c221 100644 --- a/appsec/src/helper/client.cpp +++ b/appsec/src/helper/client.cpp @@ -416,16 +416,10 @@ bool client::handle_command(network::request_shutdown::request &command) auto free_ctx = defer([this]() { this->context_.reset(); }); auto sampler = service_->get_schema_sampler(); - std::optional scope; - if (sampler) { - scope = sampler->get(); - if (scope.has_value()) { - parameter context_processor = parameter::map(); - context_processor.add( - "extract-schema", parameter::as_boolean(true)); - command.data.add( - "waf.context.processor", std::move(context_processor)); - } + if (sampler && sampler->picked()) { + parameter context_processor = parameter::map(); + context_processor.add("extract-schema", parameter::as_boolean(true)); + command.data.add("waf.context.processor", std::move(context_processor)); } auto response = publish(command); diff --git a/appsec/src/helper/sampler.hpp b/appsec/src/helper/sampler.hpp index 4a3dc2cc8c..bcc76fd5af 100644 --- a/appsec/src/helper/sampler.hpp +++ b/appsec/src/helper/sampler.hpp @@ -9,9 +9,7 @@ #include #include #include -#include #include -#include namespace dds { static const double min_rate = 0.0001; @@ -29,63 +27,21 @@ class sampler { } // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers) } - class scope { - public: - explicit scope(std::atomic &concurrent) : concurrent_(&concurrent) - { - concurrent_->store(true, std::memory_order_relaxed); - } - - scope(const scope &) = delete; - scope &operator=(const scope &) = delete; - scope(scope &&oth) noexcept - { - concurrent_ = oth.concurrent_; - oth.concurrent_ = nullptr; - } - scope &operator=(scope &&oth) - { - concurrent_ = oth.concurrent_; - oth.concurrent_ = nullptr; - - return *this; - } - - ~scope() - { - if (concurrent_ != nullptr) { - concurrent_->store(false, std::memory_order_relaxed); - } - } - - protected: - std::atomic *concurrent_; - }; - std::optional get() + bool picked() { - const std::lock_guard lock_guard(mtx_); - - std::optional result = std::nullopt; - - if (!concurrent_ && floor(request_ * sample_rate_) != - floor((request_ + 1) * sample_rate_)) { - result = {scope{concurrent_}}; - } - - if (request_ < std::numeric_limits::max()) { - request_++; - } else { - request_ = 1; + if (sample_rate_ == 1) { + return true; } - return result; + auto old_request = request_.fetch_add(1, std::memory_order_relaxed); + return floor(old_request * sample_rate_) != + floor((request_)*sample_rate_); } protected: - unsigned request_{1}; + std::atomic request_{0}; double sample_rate_; - std::atomic concurrent_{false}; std::mutex mtx_; }; } // namespace dds diff --git a/appsec/tests/helper/sampler_test.cpp b/appsec/tests/helper/sampler_test.cpp index eb2c90d8a8..7f7054b7ca 100644 --- a/appsec/tests/helper/sampler_test.cpp +++ b/appsec/tests/helper/sampler_test.cpp @@ -15,7 +15,7 @@ class sampler : public dds::sampler { public: sampler(double sample_rate) : dds::sampler(sample_rate) {} void set_request(unsigned int i) { request_ = i; } - auto get_request() { return request_; } + unsigned int get_request() { return request_; } }; } // namespace mock @@ -25,8 +25,7 @@ std::atomic picked = 0; void count_picked(dds::sampler &sampler, int iterations) { for (int i = 0; i < iterations; i++) { - auto is_pick = sampler.get(); - if (is_pick != std::nullopt) { + if (sampler.picked()) { picked++; } } @@ -198,29 +197,7 @@ TEST(SamplerTest, TestOverflow) { mock::sampler s(0); s.set_request(UINT_MAX); - s.get(); - EXPECT_EQ(1, s.get_request()); -} - -TEST(ScopeTest, TestConcurrent) -{ - std::atomic concurrent = false; - { - auto s = sampler::scope(std::ref(concurrent)); - EXPECT_TRUE(concurrent); - } - EXPECT_FALSE(concurrent); -} - -TEST(ScopeTest, TestItDoesNotPickTokenUntilScopeReleased) -{ - sampler sampler(1); - auto is_pick = sampler.get(); - EXPECT_TRUE(is_pick != std::nullopt); - is_pick = sampler.get(); - EXPECT_FALSE(is_pick != std::nullopt); - is_pick.reset(); - is_pick = sampler.get(); - EXPECT_TRUE(is_pick != std::nullopt); + s.picked(); + EXPECT_EQ(0, s.get_request()); } } // namespace dds diff --git a/appsec/tests/helper/service_test.cpp b/appsec/tests/helper/service_test.cpp index 5af59db440..9fd6bfc340 100644 --- a/appsec/tests/helper/service_test.cpp +++ b/appsec/tests/helper/service_test.cpp @@ -57,7 +57,7 @@ TEST(ServiceTest, ServicePickSchemaExtractionSamples) auto s = service( engine, service_config, nullptr, {true, all_requests_are_picked}); - EXPECT_TRUE(s.get_schema_sampler()->get().has_value()); + EXPECT_TRUE(s.get_schema_sampler()->picked()); } { // Constructor. It does not pick based on rate @@ -65,7 +65,7 @@ TEST(ServiceTest, ServicePickSchemaExtractionSamples) auto s = service( engine, service_config, nullptr, {true, no_request_is_picked}); - EXPECT_FALSE(s.get_schema_sampler()->get().has_value()); + EXPECT_FALSE(s.get_schema_sampler()->picked()); } { // Constructor. It does not pick if disabled @@ -74,7 +74,7 @@ TEST(ServiceTest, ServicePickSchemaExtractionSamples) auto s = service(engine, service_config, nullptr, {schema_extraction_disabled, all_requests_are_picked}); - EXPECT_FALSE(s.get_schema_sampler()->get().has_value()); + EXPECT_FALSE(s.get_schema_sampler()->picked()); } { // Static constructor. It picks based on rate @@ -83,7 +83,7 @@ TEST(ServiceTest, ServicePickSchemaExtractionSamples) auto service = service::from_settings( service_identifier(sid), engine_settings, {}, meta, metrics, false); - EXPECT_TRUE(service->get_schema_sampler()->get().has_value()); + EXPECT_TRUE(service->get_schema_sampler()->picked()); } { // Static constructor. It does not pick based on rate @@ -92,7 +92,7 @@ TEST(ServiceTest, ServicePickSchemaExtractionSamples) auto service = service::from_settings( service_identifier(sid), engine_settings, {}, meta, metrics, false); - EXPECT_FALSE(service->get_schema_sampler()->get().has_value()); + EXPECT_FALSE(service->get_schema_sampler()->picked()); } { // Static constructor. It does not pick if disabled @@ -101,7 +101,7 @@ TEST(ServiceTest, ServicePickSchemaExtractionSamples) auto service = service::from_settings( service_identifier(sid), engine_settings, {}, meta, metrics, false); - EXPECT_FALSE(service->get_schema_sampler()->get().has_value()); + EXPECT_FALSE(service->get_schema_sampler()->picked()); } }