Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial implementation of auth service #69

Merged
merged 8 commits into from
Aug 28, 2024
14 changes: 5 additions & 9 deletions src/attribute.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::configuration::Path;
use crate::filter::http_context::Filter;
use chrono::{DateTime, FixedOffset};
use proxy_wasm::traits::Context;
use proxy_wasm::hostcalls;

pub trait Attribute {
fn parse(raw_attribute: Vec<u8>) -> Result<Self, String>
Expand Down Expand Up @@ -105,15 +104,12 @@ impl Attribute for DateTime<FixedOffset> {
}

#[allow(dead_code)]
pub fn get_attribute<T>(f: &Filter, attr: &str) -> Result<T, String>
pub fn get_attribute<T>(attr: &str) -> Result<T, String>
where
T: Attribute,
{
match f.get_property(Path::from(attr).tokens()) {
None => Err(format!(
"#{} get_attribute: not found: {}",
f.context_id, attr
)),
Some(attribute_bytes) => T::parse(attribute_bytes),
match hostcalls::get_property(Path::from(attr).tokens()) {
Ok(Some(attribute_bytes)) => T::parse(attribute_bytes),
_ => Err(format!("get_attribute: not found: {}", attr)),
}
}
7 changes: 7 additions & 0 deletions src/configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,13 @@ pub enum FailureMode {
Allow,
}

#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ExtensionType {
Auth,
RateLimit,
}

#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PluginConfiguration {
Expand Down
7 changes: 7 additions & 0 deletions src/envoy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ mod token_bucket;
mod value;

pub use {
address::{Address, SocketAddress},
attribute_context::{
AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer,
AttributeContext_Request,
},
base::Metadata,
external_auth::CheckRequest,
ratelimit::{RateLimitDescriptor, RateLimitDescriptor_Entry},
rls::{RateLimitRequest, RateLimitResponse, RateLimitResponse_Code},
};
Expand Down
43 changes: 9 additions & 34 deletions src/filter/http_context.rs
Original file line number Diff line number Diff line change
@@ -1,42 +1,19 @@
use crate::configuration::{FailureMode, FilterConfig};
use crate::configuration::{ExtensionType, FailureMode, FilterConfig};
use crate::envoy::{RateLimitResponse, RateLimitResponse_Code};
use crate::filter::http_context::TracingHeader::{Baggage, Traceparent, Tracestate};
use crate::policy::Policy;
use crate::service::rate_limit::RateLimitService;
use crate::service::Service;
use crate::service::{GrpcServiceHandler, HeaderResolver};
use log::{debug, warn};
use protobuf::Message;
use proxy_wasm::traits::{Context, HttpContext};
use proxy_wasm::types::{Action, Bytes};
use proxy_wasm::types::Action;
use std::rc::Rc;

// tracing headers
#[derive(Clone)]
pub enum TracingHeader {
Traceparent,
Tracestate,
Baggage,
}

impl TracingHeader {
fn all() -> [Self; 3] {
[Traceparent, Tracestate, Baggage]
}

pub fn as_str(&self) -> &'static str {
match self {
Traceparent => "traceparent",
Tracestate => "tracestate",
Baggage => "baggage",
}
}
}

pub struct Filter {
pub context_id: u32,
pub config: Rc<FilterConfig>,
pub response_headers_to_add: Vec<(String, String)>,
pub tracing_headers: Vec<(TracingHeader, Bytes)>,
pub header_resolver: Rc<HeaderResolver>,
}

impl Filter {
Expand All @@ -63,7 +40,11 @@ impl Filter {
return Action::Continue;
}

let rls = RateLimitService::new(rlp.service.as_str(), self.tracing_headers.clone());
let rls = GrpcServiceHandler::new(
ExtensionType::RateLimit,
rlp.service.clone(),
Rc::clone(&self.header_resolver),
);
let message = RateLimitService::message(rlp.domain.clone(), descriptors);

match rls.send(message) {
Expand Down Expand Up @@ -98,12 +79,6 @@ impl HttpContext for Filter {
fn on_http_request_headers(&mut self, _: usize, _: bool) -> Action {
debug!("#{} on_http_request_headers", self.context_id);

for header in TracingHeader::all() {
if let Some(value) = self.get_http_request_header_bytes(header.as_str()) {
self.tracing_headers.push((header, value))
}
}

match self
.config
.index
Expand Down
3 changes: 2 additions & 1 deletion src/filter/root_context.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::configuration::{FilterConfig, PluginConfiguration};
use crate::filter::http_context::Filter;
use crate::service::HeaderResolver;
use const_format::formatcp;
use log::{debug, error, info};
use proxy_wasm::traits::{Context, HttpContext, RootContext};
Expand Down Expand Up @@ -40,7 +41,7 @@ impl RootContext for FilterRoot {
context_id,
config: Rc::clone(&self.config),
response_headers_to_add: Vec::default(),
tracing_headers: Vec::default(),
header_resolver: Rc::new(HeaderResolver::new()),
}))
}

Expand Down
122 changes: 119 additions & 3 deletions src/service.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,124 @@
pub(crate) mod auth;
pub(crate) mod rate_limit;

use crate::configuration::ExtensionType;
use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME};
use crate::service::rate_limit::{RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME};
use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate};
use protobuf::Message;
use proxy_wasm::types::Status;
use proxy_wasm::hostcalls;
use proxy_wasm::hostcalls::dispatch_grpc_call;
use proxy_wasm::types::{Bytes, MapType, Status};
use std::cell::OnceCell;
use std::rc::Rc;
use std::time::Duration;

pub trait Service<M: Message> {
fn send(&self, message: M) -> Result<u32, Status>;
pub struct GrpcServiceHandler {
endpoint: String,
service_name: String,
method_name: String,
header_resolver: Rc<HeaderResolver>,
}

impl GrpcServiceHandler {
fn build(
endpoint: String,
service_name: &str,
method_name: &str,
header_resolver: Rc<HeaderResolver>,
) -> Self {
Self {
endpoint: endpoint.to_owned(),
service_name: service_name.to_owned(),
method_name: method_name.to_owned(),
header_resolver,
}
}

pub fn new(
extension_type: ExtensionType,
endpoint: String,
header_resolver: Rc<HeaderResolver>,
) -> Self {
match extension_type {
ExtensionType::Auth => Self::build(
endpoint,
AUTH_SERVICE_NAME,
AUTH_METHOD_NAME,
header_resolver,
),
ExtensionType::RateLimit => Self::build(
endpoint,
RATELIMIT_SERVICE_NAME,
RATELIMIT_METHOD_NAME,
header_resolver,
),
}
}

pub fn send<M: Message>(&self, message: M) -> Result<u32, Status> {
let msg = Message::write_to_bytes(&message).unwrap();
let metadata = self
.header_resolver
.get()
.iter()
.map(|(header, value)| (*header, value.as_slice()))
.collect();

dispatch_grpc_call(
self.endpoint.as_str(),
self.service_name.as_str(),
self.method_name.as_str(),
metadata,
Some(&msg),
Duration::from_secs(5),
)
}
}

pub struct HeaderResolver {
headers: OnceCell<Vec<(&'static str, Bytes)>>,
}

impl HeaderResolver {
pub fn new() -> Self {
Self {
headers: OnceCell::new(),
}
}

pub fn get(&self) -> &Vec<(&'static str, Bytes)> {
self.headers.get_or_init(|| {
let mut headers = Vec::new();
for header in TracingHeader::all() {
if let Ok(Some(value)) =
hostcalls::get_map_value_bytes(MapType::HttpRequestHeaders, (*header).as_str())
{
headers.push(((*header).as_str(), value));
}
}
headers
})
}
}

// tracing headers
pub enum TracingHeader {
Traceparent,
Tracestate,
Baggage,
}

impl TracingHeader {
fn all() -> &'static [Self; 3] {
&[Traceparent, Tracestate, Baggage]
}

pub fn as_str(&self) -> &'static str {
match self {
Traceparent => "traceparent",
Tracestate => "tracestate",
Baggage => "baggage",
}
}
}
81 changes: 81 additions & 0 deletions src/service/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use crate::attribute::get_attribute;
use crate::envoy::{
Address, AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer,
AttributeContext_Request, CheckRequest, Metadata, SocketAddress,
};
use chrono::{DateTime, FixedOffset, Timelike};
use protobuf::well_known_types::Timestamp;
use proxy_wasm::hostcalls;
use proxy_wasm::types::MapType;
use std::collections::HashMap;

pub const AUTH_SERVICE_NAME: &str = "envoy.service.auth.v3.Authorization";
pub const AUTH_METHOD_NAME: &str = "Check";

pub struct AuthService;

#[allow(dead_code)]
impl AuthService {
pub fn message(ce_host: String) -> CheckRequest {
AuthService::build_check_req(ce_host)
}

fn build_check_req(ce_host: String) -> CheckRequest {
let mut auth_req = CheckRequest::default();
let mut attr = AttributeContext::default();
attr.set_request(AuthService::build_request());
attr.set_destination(AuthService::build_peer(
get_attribute::<String>("destination.address").unwrap_or_default(),
get_attribute::<i64>("destination.port").unwrap_or_default() as u32,
));
attr.set_source(AuthService::build_peer(
get_attribute::<String>("source.address").unwrap_or_default(),
get_attribute::<i64>("source.port").unwrap_or_default() as u32,
));
// the ce_host is the identifier for authorino to determine which authconfig to use
let context_extensions = HashMap::from([("host".to_string(), ce_host)]);
attr.set_context_extensions(context_extensions);
attr.set_metadata_context(Metadata::default());
auth_req.set_attributes(attr);
auth_req
}

fn build_request() -> AttributeContext_Request {
let mut request = AttributeContext_Request::default();
let mut http = AttributeContext_HttpRequest::default();
let headers: HashMap<String, String> = hostcalls::get_map(MapType::HttpRequestHeaders)
.unwrap()
.into_iter()
.collect();

http.set_host(get_attribute::<String>("request.host").unwrap_or_default());
http.set_method(get_attribute::<String>("request.method").unwrap_or_default());
http.set_scheme(get_attribute::<String>("request.scheme").unwrap_or_default());
http.set_path(get_attribute::<String>("request.path").unwrap_or_default());
http.set_protocol(get_attribute::<String>("request.protocol").unwrap_or_default());

http.set_headers(headers);
request.set_time(get_attribute("request.time").map_or(
Timestamp::new(),
|date_time: DateTime<FixedOffset>| Timestamp {
nanos: date_time.nanosecond() as i32,
seconds: date_time.second() as i64,
unknown_fields: Default::default(),
cached_size: Default::default(),
},
));
adam-cattermole marked this conversation as resolved.
Show resolved Hide resolved
request.set_http(http);
request
}

fn build_peer(host: String, port: u32) -> AttributeContext_Peer {
let mut peer = AttributeContext_Peer::default();
let mut address = Address::default();
let mut socket_address = SocketAddress::default();
socket_address.set_address(host);
socket_address.set_port_value(port);
address.set_socket_address(socket_address);
peer.set_address(address);
peer
}
}
Loading
Loading