diff --git a/source/common/tcp_proxy/tcp_proxy.cc b/source/common/tcp_proxy/tcp_proxy.cc index 4c10ffc6e2..17df75177d 100644 --- a/source/common/tcp_proxy/tcp_proxy.cc +++ b/source/common/tcp_proxy/tcp_proxy.cc @@ -239,6 +239,12 @@ void Filter::initialize(Network::ReadFilterCallbacks& callbacks, bool set_connec getStreamInfo().setDownstreamLocalAddress(read_callbacks_->connection().localAddress()); getStreamInfo().setDownstreamRemoteAddress(read_callbacks_->connection().remoteAddress()); getStreamInfo().setDownstreamSslConnection(read_callbacks_->connection().ssl()); + read_callbacks_->connection().streamInfo().setDownstreamLocalAddress( + read_callbacks_->connection().localAddress()); + read_callbacks_->connection().streamInfo().setDownstreamRemoteAddress( + read_callbacks_->connection().remoteAddress()); + read_callbacks_->connection().streamInfo().setDownstreamSslConnection( + read_callbacks_->connection().ssl()); // Need to disable reads so that we don't write to an upstream that might fail // in onData(). This will get re-enabled when the upstream connection is @@ -469,6 +475,10 @@ void Filter::onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr&& conn_data, getStreamInfo().onUpstreamHostSelected(host); getStreamInfo().setUpstreamLocalAddress(connection.localAddress()); getStreamInfo().setUpstreamSslConnection(connection.streamInfo().downstreamSslConnection()); + read_callbacks_->connection().streamInfo().onUpstreamHostSelected(host); + read_callbacks_->connection().streamInfo().setUpstreamLocalAddress(connection.localAddress()); + read_callbacks_->connection().streamInfo().setUpstreamSslConnection( + connection.streamInfo().downstreamSslConnection()); read_callbacks_->connection().streamInfo().setUpstreamFilterState( connection.streamInfo().filterState()); diff --git a/source/extensions/common/wasm/context.cc b/source/extensions/common/wasm/context.cc index e3dd1198ae..7bb8a10a34 100644 --- a/source/extensions/common/wasm/context.cc +++ b/source/extensions/common/wasm/context.cc @@ -348,32 +348,29 @@ WasmResult serializeValue(Filters::Common::Expr::CelValue value, std::string* re class WasmStateWrapper : public google::api::expr::runtime::CelMap { public: WasmStateWrapper(const StreamInfo::FilterState& filter_state, - const StreamInfo::FilterState* connection_filter_state) - : filter_state_(filter_state), connection_filter_state_(connection_filter_state) {} - WasmStateWrapper(const StreamInfo::FilterState& filter_state) - : filter_state_(filter_state), connection_filter_state_(nullptr) {} + const StreamInfo::FilterState* upstream_connection_filter_state) + : filter_state_(filter_state), + upstream_connection_filter_state_(upstream_connection_filter_state) {} absl::optional operator[](google::api::expr::runtime::CelValue key) const override { if (!key.IsString()) { return {}; } auto value = key.StringOrDie().value(); - try { + if (filter_state_.hasData(value)) { const WasmState& result = filter_state_.getDataReadOnly(value); return google::api::expr::runtime::CelValue::CreateBytes(&result.value()); - } catch (const EnvoyException& e) { - // If doesn't exist in request filter state, try looking up in connection filter state. - try { - if (connection_filter_state_) { - const WasmState& result = connection_filter_state_->getDataReadOnly(value); - return google::api::expr::runtime::CelValue::CreateBytes(&result.value()); - } - } catch (const EnvoyException& e) { - return {}; - } - return {}; } + + if (upstream_connection_filter_state_ && + upstream_connection_filter_state_->hasData(value)) { + const WasmState& result = + upstream_connection_filter_state_->getDataReadOnly(value); + return google::api::expr::runtime::CelValue::CreateBytes(&result.value()); + } + return {}; } + int size() const override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } bool empty() const override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } const google::api::expr::runtime::CelList* ListKeys() const override { @@ -382,7 +379,7 @@ class WasmStateWrapper : public google::api::expr::runtime::CelMap { private: const StreamInfo::FilterState& filter_state_; - const StreamInfo::FilterState* connection_filter_state_; + const StreamInfo::FilterState* upstream_connection_filter_state_; }; #define PROPERTY_TOKENS(_f) \ @@ -423,14 +420,9 @@ Context::FindValue(absl::string_view name, Protobuf::Arena* arena) const { break; case PropertyToken::FILTER_STATE: if (info) { - const Envoy::Network::Connection* connection = getConnection(); - if (connection) { - return CelValue::CreateMap(Protobuf::Arena::Create( - arena, info->filterState(), &connection->streamInfo().filterState())); - } else { - return CelValue::CreateMap( - Protobuf::Arena::Create(arena, info->filterState())); - } + + return CelValue::CreateMap(Protobuf::Arena::Create( + arena, info->filterState(), info->upstreamFilterState().get())); } break; case PropertyToken::REQUEST: @@ -1004,6 +996,10 @@ const Network::Connection* Context::getConnection() const { return encoder_callbacks_->connection(); } else if (decoder_callbacks_) { return decoder_callbacks_->connection(); + } else if (network_read_filter_callbacks_) { + return &network_read_filter_callbacks_->connection(); + } else if (network_write_filter_callbacks_) { + return &network_write_filter_callbacks_->connection(); } return nullptr; }