Skip to content

Commit

Permalink
refactoring: refactored the FilterState object field support (#36399)
Browse files Browse the repository at this point in the history
The filter state reflection provides a great feature to access the inner
status/property of filter state. However, it has two limitations:
1. It requires the object key be same with the factory key. This
limitation make we cannot set multiple objects that with same type.
2. It is a little complex to enable the Field support. We need to define
additional reflection class and a factory class.

This PR make things much simpler.


Risk Level: low.
Testing: n/a.
Docs Changes: n/a.
Release Notes: n/a.
Platform Specific Features: n/a.

---------

Signed-off-by: wangbaiping <wangbaiping@bytedance.com>
  • Loading branch information
wbpcode authored Oct 3, 2024
1 parent 37b725d commit 24fe164
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 110 deletions.
26 changes: 9 additions & 17 deletions envoy/stream_info/filter_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class FilterState {

class Object {
public:
using FieldType = absl::variant<absl::monostate, absl::string_view, int64_t>;

virtual ~Object() = default;

/**
Expand All @@ -102,21 +104,17 @@ class FilterState {
* This method can be used to get an unstructured serialization result.
*/
virtual absl::optional<std::string> serializeAsString() const { return absl::nullopt; }
};

/**
* Generic reflection support for the filter state objects.
*/
class ObjectReflection {
public:
virtual ~ObjectReflection() = default;

using FieldType = absl::variant<absl::monostate, absl::string_view, int64_t>;
/**
* @return bool true if the object supports field access. False if the object does not support
* field access. Default implementation returns false.
*/
virtual bool hasFieldSupport() const { return false; }

/**
* @return FieldType a field value for a field name.
* @return FieldType a single state property or field value for a name.
*/
virtual FieldType getField(absl::string_view) const PURE;
virtual FieldType getField(absl::string_view) const { return absl::monostate{}; }
};

/**
Expand All @@ -134,12 +132,6 @@ class FilterState {
* is malformed.
*/
virtual std::unique_ptr<Object> createFromBytes(absl::string_view data) const PURE;

/**
* @return std::unique_ptr<ObjectReflection> for the runtime object
* Note that the reflection object is a view and should not outlive the object.
*/
virtual std::unique_ptr<ObjectReflection> reflect(const Object*) const { return nullptr; }
};

struct FilterObject {
Expand Down
19 changes: 2 additions & 17 deletions source/common/formatter/stream_info_formatter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ FilterStateFormatter::FilterStateFormatter(absl::string_view key, absl::optional
if (!field_name.empty()) {
format_ = FilterStateFormat::Field;
field_name_ = std::string(field_name);
factory_ = Registry::FactoryRegistry<StreamInfo::FilterState::ObjectFactory>::getFactory(key);
} else if (serialize_as_string) {
format_ = FilterStateFormat::String;
} else {
Expand Down Expand Up @@ -264,14 +263,7 @@ FilterStateFormatter::format(const StreamInfo::StreamInfo& stream_info) const {
#endif
}
case FilterStateFormat::Field: {
if (!factory_) {
return absl::nullopt;
}
const auto reflection = factory_->reflect(state);
if (!reflection) {
return absl::nullopt;
}
auto field_value = reflection->getField(field_name_);
auto field_value = state->getField(field_name_);
auto string_value = absl::visit(StringFieldVisitor(), field_value);
if (!string_value) {
return absl::nullopt;
Expand Down Expand Up @@ -315,14 +307,7 @@ FilterStateFormatter::formatValue(const StreamInfo::StreamInfo& stream_info) con
return SubstitutionFormatUtils::unspecifiedValue();
}
case FilterStateFormat::Field: {
if (!factory_) {
return SubstitutionFormatUtils::unspecifiedValue();
}
const auto reflection = factory_->reflect(state);
if (!reflection) {
return SubstitutionFormatUtils::unspecifiedValue();
}
auto field_value = reflection->getField(field_name_);
auto field_value = state->getField(field_name_);
auto string_value = absl::visit(StringFieldVisitor(), field_value);
if (!string_value) {
return SubstitutionFormatUtils::unspecifiedValue();
Expand Down
1 change: 0 additions & 1 deletion source/common/formatter/stream_info_formatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ class FilterStateFormatter : public StreamInfoFormatterProvider {
const bool is_upstream_;
FilterStateFormat format_;
std::string field_name_;
StreamInfo::FilterState::ObjectFactory* factory_;
};

class CommonDurationFormatter : public StreamInfoFormatterProvider {
Expand Down
36 changes: 11 additions & 25 deletions source/common/network/filter_state_dst_address.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,25 @@ absl::optional<uint64_t> AddressObject::hash() const {
return HashUtil::xxHash64(address_->asStringView());
}

class AddressObjectReflection : public StreamInfo::FilterState::ObjectReflection {
public:
AddressObjectReflection(const AddressObject* object) : object_(object) {}
FieldType getField(absl::string_view field_name) const override {
const auto* ip = object_->address_->ip();
if (!ip) {
return {};
}
if (field_name == "ip") {
return ip->addressAsString();
} else if (field_name == "port") {
return int64_t(ip->port());
}
StreamInfo::FilterState::Object::FieldType
AddressObject::getField(absl::string_view field_name) const {
const auto* ip = address_->ip();
if (!ip) {
return {};
}

private:
const AddressObject* object_;
};
if (field_name == "ip") {
return ip->addressAsString();
} else if (field_name == "port") {
return int64_t(ip->port());
}
return {};
}

std::unique_ptr<StreamInfo::FilterState::Object>
BaseAddressObjectFactory::createFromBytes(absl::string_view data) const {
const auto address = Utility::parseInternetAddressAndPortNoThrow(std::string(data));
return address ? std::make_unique<AddressObject>(address) : nullptr;
}
std::unique_ptr<StreamInfo::FilterState::ObjectReflection>
BaseAddressObjectFactory::reflect(const StreamInfo::FilterState::Object* data) const {
const auto* object = dynamic_cast<const AddressObject*>(data);
if (object) {
return std::make_unique<AddressObjectReflection>(object);
}
return nullptr;
}

} // namespace Network
} // namespace Envoy
6 changes: 3 additions & 3 deletions source/common/network/filter_state_dst_address.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ class AddressObject : public StreamInfo::FilterState::Object, public Hashable {
absl::optional<std::string> serializeAsString() const override {
return address_ ? absl::make_optional(address_->asString()) : absl::nullopt;
}
bool hasFieldSupport() const override { return true; }
FieldType getField(absl::string_view field_name) const override;

// Implements hashing interface because the value is applied once per upstream connection.
// Multiple streams sharing the upstream connection must have the same address object.
absl::optional<uint64_t> hash() const override;

private:
const Network::Address::InstanceConstSharedPtr address_;
friend class AddressObjectReflection;
};

/**
Expand All @@ -33,8 +35,6 @@ class BaseAddressObjectFactory : public StreamInfo::FilterState::ObjectFactory {
public:
std::unique_ptr<StreamInfo::FilterState::Object>
createFromBytes(absl::string_view data) const override;
std::unique_ptr<StreamInfo::FilterState::ObjectReflection>
reflect(const StreamInfo::FilterState::Object* data) const override;
};

} // namespace Network
Expand Down
25 changes: 9 additions & 16 deletions source/extensions/filters/common/expr/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,13 +300,12 @@ absl::optional<CelValue> PeerWrapper::operator[](CelValue key) const {

class FilterStateObjectWrapper : public google::api::expr::runtime::CelMap {
public:
FilterStateObjectWrapper(const StreamInfo::FilterState::ObjectReflection* reflection)
: reflection_(reflection) {}
FilterStateObjectWrapper(const StreamInfo::FilterState::Object* object) : object_(object) {}
absl::optional<CelValue> operator[](CelValue key) const override {
if (reflection_ == nullptr || !key.IsString()) {
if (object_ == nullptr || !key.IsString()) {
return {};
}
auto field_value = reflection_->getField(key.StringOrDie().value());
auto field_value = object_->getField(key.StringOrDie().value());
return absl::visit(Visitor{}, field_value);
}
// Default stubs.
Expand All @@ -325,7 +324,7 @@ class FilterStateObjectWrapper : public google::api::expr::runtime::CelMap {
}
absl::optional<CelValue> operator()(absl::monostate) { return {}; }
};
const StreamInfo::FilterState::ObjectReflection* reflection_;
const StreamInfo::FilterState::Object* object_;
};

absl::optional<CelValue> FilterStateWrapper::operator[](CelValue key) const {
Expand All @@ -339,17 +338,11 @@ absl::optional<CelValue> FilterStateWrapper::operator[](CelValue key) const {
if (cel_state) {
return cel_state->exprValue(&arena_, false);
} else if (object != nullptr) {
// Attempt to find the reflection object.
auto factory =
Registry::FactoryRegistry<StreamInfo::FilterState::ObjectFactory>::getFactory(value);
if (factory) {
auto reflection = factory->reflect(object);
if (reflection) {
auto* raw_reflection = reflection.release();
arena_.Own(raw_reflection);
return CelValue::CreateMap(
ProtobufWkt::Arena::Create<FilterStateObjectWrapper>(&arena_, raw_reflection));
}
// TODO(wbpcode): the implementation of cannot handle the case where the object has provided
// field support, but callers only want to access the whole object.
if (object->hasFieldSupport()) {
return CelValue::CreateMap(
ProtobufWkt::Arena::Create<FilterStateObjectWrapper>(&arena_, object));
}
absl::optional<std::string> serialized = object->serializeAsString();
if (serialized.has_value()) {
Expand Down
30 changes: 3 additions & 27 deletions test/common/formatter/substitution_formatter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,44 +123,20 @@ class TestSerializedStringFilterState : public StreamInfo::FilterState::Object {
message->set_value(raw_string_ + " By TYPED");
return message;
}

private:
std::string raw_string_;
friend class TestSerializedStringReflection;
};

class TestSerializedStringReflection : public StreamInfo::FilterState::ObjectReflection {
public:
TestSerializedStringReflection(const TestSerializedStringFilterState* data) : data_(data) {}
bool hasFieldSupport() const override { return true; }
FieldType getField(absl::string_view field_name) const override {
if (field_name == "test_field") {
return data_->raw_string_;
return raw_string_;
} else if (field_name == "test_num") {
return 137;
}
return {};
}

private:
const TestSerializedStringFilterState* data_;
};

class TestSerializedStringFilterStateFactory : public StreamInfo::FilterState::ObjectFactory {
public:
std::string name() const override { return "test_key"; }
std::unique_ptr<StreamInfo::FilterState::Object>
createFromBytes(absl::string_view) const override {
return nullptr;
}
std::unique_ptr<StreamInfo::FilterState::ObjectReflection>
reflect(const StreamInfo::FilterState::Object* data) const override {
return std::make_unique<TestSerializedStringReflection>(
dynamic_cast<const TestSerializedStringFilterState*>(data));
}
std::string raw_string_;
};

REGISTER_FACTORY(TestSerializedStringFilterStateFactory, StreamInfo::FilterState::ObjectFactory);

// Test tests multiple versions of variadic template method parseSubcommand
// extracting tokens.
TEST(SubstitutionFormatParser, commandParser) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1163,10 +1163,8 @@ TEST(DestinationAddress, ObjectFactory) {
auto object = factory->createFromBytes(address);
ASSERT_NE(nullptr, object);
EXPECT_EQ(address, object->serializeAsString());
auto mirror = factory->reflect(object.get());
ASSERT_NE(nullptr, mirror);
EXPECT_THAT(mirror->getField("ip"), testing::VariantWith<absl::string_view>("10.0.0.10"));
EXPECT_THAT(mirror->getField("port"), testing::VariantWith<int64_t>(8080));
EXPECT_THAT(object->getField("ip"), testing::VariantWith<absl::string_view>("10.0.0.10"));
EXPECT_THAT(object->getField("port"), testing::VariantWith<int64_t>(8080));
EXPECT_EQ(nullptr, factory->createFromBytes("foo"));
}

Expand Down
17 changes: 17 additions & 0 deletions test/extensions/filters/common/expr/context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,8 @@ TEST(Context, FilterStateAttributes) {
StreamInfo::FilterStateImpl filter_state(StreamInfo::FilterState::LifeSpan::FilterChain);
ProtobufWkt::Arena arena;
FilterStateWrapper wrapper(arena, filter_state);
auto status_or = wrapper.ListKeys(&arena);
EXPECT_EQ(status_or.status().message(), "ListKeys() is not implemented");

const std::string key = "filter_state_key";
const std::string serialized = "filter_state_value";
Expand Down Expand Up @@ -941,6 +943,21 @@ TEST(Context, XDSAttributes) {
}
}

TEST(Context, EmptyXdsWrapper) {
Protobuf::Arena arena;
XDSWrapper wrapper(arena, nullptr, nullptr);

{
const auto value = wrapper[CelValue::CreateStringView(Node)];
EXPECT_FALSE(value.has_value());
}

{
const auto value = wrapper[CelValue::CreateStringView(ClusterName)];
EXPECT_FALSE(value.has_value());
}
}

} // namespace
} // namespace Expr
} // namespace Common
Expand Down

0 comments on commit 24fe164

Please sign in to comment.