diff --git a/envoy/stream_info/filter_state.h b/envoy/stream_info/filter_state.h index 7e7cfd7fdef0..f43ffe00eb3b 100644 --- a/envoy/stream_info/filter_state.h +++ b/envoy/stream_info/filter_state.h @@ -87,6 +87,8 @@ class FilterState { class Object { public: + using FieldType = absl::variant; + virtual ~Object() = default; /** @@ -102,21 +104,17 @@ class FilterState { * This method can be used to get an unstructured serialization result. */ virtual absl::optional serializeAsString() const { return absl::nullopt; } - }; - - /** - * Generic reflection support for the filter state objects. - */ - class ObjectReflection { - public: - virtual ~ObjectReflection() = default; - using FieldType = absl::variant; + /** + * @return bool true if the object supports field access. False if the object does not support + * field access. Default implementation returns false. + */ + virtual bool hasFieldSupport() const { return false; } /** - * @return FieldType a field value for a field name. + * @return FieldType a single state property or field value for a name. */ - virtual FieldType getField(absl::string_view) const PURE; + virtual FieldType getField(absl::string_view) const { return absl::monostate{}; } }; /** @@ -134,12 +132,6 @@ class FilterState { * is malformed. */ virtual std::unique_ptr createFromBytes(absl::string_view data) const PURE; - - /** - * @return std::unique_ptr for the runtime object - * Note that the reflection object is a view and should not outlive the object. - */ - virtual std::unique_ptr reflect(const Object*) const { return nullptr; } }; struct FilterObject { diff --git a/source/common/formatter/stream_info_formatter.cc b/source/common/formatter/stream_info_formatter.cc index c7f532db0098..5ff679a731fd 100644 --- a/source/common/formatter/stream_info_formatter.cc +++ b/source/common/formatter/stream_info_formatter.cc @@ -192,7 +192,6 @@ FilterStateFormatter::FilterStateFormatter(absl::string_view key, absl::optional if (!field_name.empty()) { format_ = FilterStateFormat::Field; field_name_ = std::string(field_name); - factory_ = Registry::FactoryRegistry::getFactory(key); } else if (serialize_as_string) { format_ = FilterStateFormat::String; } else { @@ -264,14 +263,7 @@ FilterStateFormatter::format(const StreamInfo::StreamInfo& stream_info) const { #endif } case FilterStateFormat::Field: { - if (!factory_) { - return absl::nullopt; - } - const auto reflection = factory_->reflect(state); - if (!reflection) { - return absl::nullopt; - } - auto field_value = reflection->getField(field_name_); + auto field_value = state->getField(field_name_); auto string_value = absl::visit(StringFieldVisitor(), field_value); if (!string_value) { return absl::nullopt; @@ -315,14 +307,7 @@ FilterStateFormatter::formatValue(const StreamInfo::StreamInfo& stream_info) con return SubstitutionFormatUtils::unspecifiedValue(); } case FilterStateFormat::Field: { - if (!factory_) { - return SubstitutionFormatUtils::unspecifiedValue(); - } - const auto reflection = factory_->reflect(state); - if (!reflection) { - return SubstitutionFormatUtils::unspecifiedValue(); - } - auto field_value = reflection->getField(field_name_); + auto field_value = state->getField(field_name_); auto string_value = absl::visit(StringFieldVisitor(), field_value); if (!string_value) { return SubstitutionFormatUtils::unspecifiedValue(); diff --git a/source/common/formatter/stream_info_formatter.h b/source/common/formatter/stream_info_formatter.h index 0cc80e0d911c..10f4288c9d19 100644 --- a/source/common/formatter/stream_info_formatter.h +++ b/source/common/formatter/stream_info_formatter.h @@ -110,7 +110,6 @@ class FilterStateFormatter : public StreamInfoFormatterProvider { const bool is_upstream_; FilterStateFormat format_; std::string field_name_; - StreamInfo::FilterState::ObjectFactory* factory_; }; class CommonDurationFormatter : public StreamInfoFormatterProvider { diff --git a/source/common/network/filter_state_dst_address.cc b/source/common/network/filter_state_dst_address.cc index b164573e5f7a..231f3e0b1c60 100644 --- a/source/common/network/filter_state_dst_address.cc +++ b/source/common/network/filter_state_dst_address.cc @@ -9,39 +9,25 @@ absl::optional AddressObject::hash() const { return HashUtil::xxHash64(address_->asStringView()); } -class AddressObjectReflection : public StreamInfo::FilterState::ObjectReflection { -public: - AddressObjectReflection(const AddressObject* object) : object_(object) {} - FieldType getField(absl::string_view field_name) const override { - const auto* ip = object_->address_->ip(); - if (!ip) { - return {}; - } - if (field_name == "ip") { - return ip->addressAsString(); - } else if (field_name == "port") { - return int64_t(ip->port()); - } +StreamInfo::FilterState::Object::FieldType +AddressObject::getField(absl::string_view field_name) const { + const auto* ip = address_->ip(); + if (!ip) { return {}; } - -private: - const AddressObject* object_; -}; + if (field_name == "ip") { + return ip->addressAsString(); + } else if (field_name == "port") { + return int64_t(ip->port()); + } + return {}; +} std::unique_ptr BaseAddressObjectFactory::createFromBytes(absl::string_view data) const { const auto address = Utility::parseInternetAddressAndPortNoThrow(std::string(data)); return address ? std::make_unique(address) : nullptr; } -std::unique_ptr -BaseAddressObjectFactory::reflect(const StreamInfo::FilterState::Object* data) const { - const auto* object = dynamic_cast(data); - if (object) { - return std::make_unique(object); - } - return nullptr; -} } // namespace Network } // namespace Envoy diff --git a/source/common/network/filter_state_dst_address.h b/source/common/network/filter_state_dst_address.h index ec6c565fd689..b35cb25b9fc0 100644 --- a/source/common/network/filter_state_dst_address.h +++ b/source/common/network/filter_state_dst_address.h @@ -17,13 +17,15 @@ class AddressObject : public StreamInfo::FilterState::Object, public Hashable { absl::optional serializeAsString() const override { return address_ ? absl::make_optional(address_->asString()) : absl::nullopt; } + bool hasFieldSupport() const override { return true; } + FieldType getField(absl::string_view field_name) const override; + // Implements hashing interface because the value is applied once per upstream connection. // Multiple streams sharing the upstream connection must have the same address object. absl::optional hash() const override; private: const Network::Address::InstanceConstSharedPtr address_; - friend class AddressObjectReflection; }; /** @@ -33,8 +35,6 @@ class BaseAddressObjectFactory : public StreamInfo::FilterState::ObjectFactory { public: std::unique_ptr createFromBytes(absl::string_view data) const override; - std::unique_ptr - reflect(const StreamInfo::FilterState::Object* data) const override; }; } // namespace Network diff --git a/source/extensions/filters/common/expr/context.cc b/source/extensions/filters/common/expr/context.cc index 10438dff4df2..a13a39656af3 100644 --- a/source/extensions/filters/common/expr/context.cc +++ b/source/extensions/filters/common/expr/context.cc @@ -300,13 +300,12 @@ absl::optional PeerWrapper::operator[](CelValue key) const { class FilterStateObjectWrapper : public google::api::expr::runtime::CelMap { public: - FilterStateObjectWrapper(const StreamInfo::FilterState::ObjectReflection* reflection) - : reflection_(reflection) {} + FilterStateObjectWrapper(const StreamInfo::FilterState::Object* object) : object_(object) {} absl::optional operator[](CelValue key) const override { - if (reflection_ == nullptr || !key.IsString()) { + if (object_ == nullptr || !key.IsString()) { return {}; } - auto field_value = reflection_->getField(key.StringOrDie().value()); + auto field_value = object_->getField(key.StringOrDie().value()); return absl::visit(Visitor{}, field_value); } // Default stubs. @@ -325,7 +324,7 @@ class FilterStateObjectWrapper : public google::api::expr::runtime::CelMap { } absl::optional operator()(absl::monostate) { return {}; } }; - const StreamInfo::FilterState::ObjectReflection* reflection_; + const StreamInfo::FilterState::Object* object_; }; absl::optional FilterStateWrapper::operator[](CelValue key) const { @@ -339,17 +338,11 @@ absl::optional FilterStateWrapper::operator[](CelValue key) const { if (cel_state) { return cel_state->exprValue(&arena_, false); } else if (object != nullptr) { - // Attempt to find the reflection object. - auto factory = - Registry::FactoryRegistry::getFactory(value); - if (factory) { - auto reflection = factory->reflect(object); - if (reflection) { - auto* raw_reflection = reflection.release(); - arena_.Own(raw_reflection); - return CelValue::CreateMap( - ProtobufWkt::Arena::Create(&arena_, raw_reflection)); - } + // TODO(wbpcode): the implementation of cannot handle the case where the object has provided + // field support, but callers only want to access the whole object. + if (object->hasFieldSupport()) { + return CelValue::CreateMap( + ProtobufWkt::Arena::Create(&arena_, object)); } absl::optional serialized = object->serializeAsString(); if (serialized.has_value()) { diff --git a/test/common/formatter/substitution_formatter_test.cc b/test/common/formatter/substitution_formatter_test.cc index 41b3e2fcbac4..11bfd2135c6d 100644 --- a/test/common/formatter/substitution_formatter_test.cc +++ b/test/common/formatter/substitution_formatter_test.cc @@ -123,18 +123,10 @@ class TestSerializedStringFilterState : public StreamInfo::FilterState::Object { message->set_value(raw_string_ + " By TYPED"); return message; } - -private: - std::string raw_string_; - friend class TestSerializedStringReflection; -}; - -class TestSerializedStringReflection : public StreamInfo::FilterState::ObjectReflection { -public: - TestSerializedStringReflection(const TestSerializedStringFilterState* data) : data_(data) {} + bool hasFieldSupport() const override { return true; } FieldType getField(absl::string_view field_name) const override { if (field_name == "test_field") { - return data_->raw_string_; + return raw_string_; } else if (field_name == "test_num") { return 137; } @@ -142,25 +134,9 @@ class TestSerializedStringReflection : public StreamInfo::FilterState::ObjectRef } private: - const TestSerializedStringFilterState* data_; -}; - -class TestSerializedStringFilterStateFactory : public StreamInfo::FilterState::ObjectFactory { -public: - std::string name() const override { return "test_key"; } - std::unique_ptr - createFromBytes(absl::string_view) const override { - return nullptr; - } - std::unique_ptr - reflect(const StreamInfo::FilterState::Object* data) const override { - return std::make_unique( - dynamic_cast(data)); - } + std::string raw_string_; }; -REGISTER_FACTORY(TestSerializedStringFilterStateFactory, StreamInfo::FilterState::ObjectFactory); - // Test tests multiple versions of variadic template method parseSubcommand // extracting tokens. TEST(SubstitutionFormatParser, commandParser) { diff --git a/test/extensions/clusters/original_dst/original_dst_cluster_test.cc b/test/extensions/clusters/original_dst/original_dst_cluster_test.cc index d0e1cd8e82f7..ba0777efa77b 100644 --- a/test/extensions/clusters/original_dst/original_dst_cluster_test.cc +++ b/test/extensions/clusters/original_dst/original_dst_cluster_test.cc @@ -1163,10 +1163,8 @@ TEST(DestinationAddress, ObjectFactory) { auto object = factory->createFromBytes(address); ASSERT_NE(nullptr, object); EXPECT_EQ(address, object->serializeAsString()); - auto mirror = factory->reflect(object.get()); - ASSERT_NE(nullptr, mirror); - EXPECT_THAT(mirror->getField("ip"), testing::VariantWith("10.0.0.10")); - EXPECT_THAT(mirror->getField("port"), testing::VariantWith(8080)); + EXPECT_THAT(object->getField("ip"), testing::VariantWith("10.0.0.10")); + EXPECT_THAT(object->getField("port"), testing::VariantWith(8080)); EXPECT_EQ(nullptr, factory->createFromBytes("foo")); } diff --git a/test/extensions/filters/common/expr/context_test.cc b/test/extensions/filters/common/expr/context_test.cc index c766d47e258d..f7cec4fb53a2 100644 --- a/test/extensions/filters/common/expr/context_test.cc +++ b/test/extensions/filters/common/expr/context_test.cc @@ -777,6 +777,8 @@ TEST(Context, FilterStateAttributes) { StreamInfo::FilterStateImpl filter_state(StreamInfo::FilterState::LifeSpan::FilterChain); ProtobufWkt::Arena arena; FilterStateWrapper wrapper(arena, filter_state); + auto status_or = wrapper.ListKeys(&arena); + EXPECT_EQ(status_or.status().message(), "ListKeys() is not implemented"); const std::string key = "filter_state_key"; const std::string serialized = "filter_state_value"; @@ -941,6 +943,21 @@ TEST(Context, XDSAttributes) { } } +TEST(Context, EmptyXdsWrapper) { + Protobuf::Arena arena; + XDSWrapper wrapper(arena, nullptr, nullptr); + + { + const auto value = wrapper[CelValue::CreateStringView(Node)]; + EXPECT_FALSE(value.has_value()); + } + + { + const auto value = wrapper[CelValue::CreateStringView(ClusterName)]; + EXPECT_FALSE(value.has_value()); + } +} + } // namespace } // namespace Expr } // namespace Common