Skip to content

Commit

Permalink
Decode parts of IP headers and apply TS filtering.
Browse files Browse the repository at this point in the history
In addition to having a correct implementation, this also helps with
debugging, and finding the right ESP session (if multiple child SAs
would be used later in the future).
  • Loading branch information
zlogic committed Sep 3, 2024
1 parent 4ef9e03 commit d3d7253
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 9 deletions.
212 changes: 209 additions & 3 deletions src/ikev2/esp.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::{error, fmt, net::SocketAddr};
use std::{
error, fmt,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
};

use log::warn;

Expand Down Expand Up @@ -59,8 +62,14 @@ impl SecurityAssociation {
8 + self.crypto_stack.encrypted_payload_length(msg_len)
}

fn contains(&self, remote_addr: &SocketAddr, local_addr: &SocketAddr) -> bool {
ts_accepts(&self.ts_remote, remote_addr) && ts_accepts(&self.ts_local, local_addr)
pub fn accepts_esp_to_vpn(&self, hdr: &IpHeader) -> bool {
ts_accepts_header(&self.ts_local, &hdr, TsCheck::Destination)
&& ts_accepts_header(&self.ts_remote, &hdr, TsCheck::Source)
}

pub fn accepts_vpn_to_esp(&self, hdr: &IpHeader) -> bool {
ts_accepts_header(&self.ts_remote, &hdr, TsCheck::Destination)
&& ts_accepts_header(&self.ts_local, &hdr, TsCheck::Source)
}

pub fn handle_esp<'a>(&self, data: &'a mut [u8]) -> Result<&'a [u8], EspError> {
Expand Down Expand Up @@ -146,6 +155,203 @@ pub fn ts_accepts(ts: &[message::TrafficSelector], addr: &SocketAddr) -> bool {
.any(|ts| ts.addr_range().contains(&addr.ip()) && ts.port_range().contains(&addr.port()))
}

enum TsCheck {
Source,
Destination,
}

fn ts_accepts_header(ts: &[message::TrafficSelector], hdr: &IpHeader, ts_check: TsCheck) -> bool {
ts.iter().any(|ts| {
let accepts_procotol = ts.ip_protocol() == message::IPProtocolType::ANY
|| ts.ip_protocol() == hdr.transport_protocol;
if !accepts_procotol {
return false;
}
let check_addr = match ts_check {
TsCheck::Source => &hdr.src_addr,
TsCheck::Destination => &hdr.dst_addr,
};
if !ts.addr_range().contains(check_addr) {
return false;
}
let check_port = match ts_check {
TsCheck::Source => hdr.src_port.as_ref(),
TsCheck::Destination => hdr.dst_port.as_ref(),
};
if let Some(check_port) = check_port {
ts.port_range().contains(check_port)
} else {
// If no port specified for TCP or UDP, this is an error.
hdr.transport_protocol != message::IPProtocolType::TCP
&& hdr.transport_protocol != message::IPProtocolType::UDP
}
})
}

pub struct IpHeader {
src_addr: IpAddr,
dst_addr: IpAddr,
src_port: Option<u16>,
dst_port: Option<u16>,
transport_protocol: message::IPProtocolType,
}

#[derive(Clone, Copy, PartialEq, Eq)]
struct Ipv6NextHeader(u8);

impl Ipv6NextHeader {
const HOP_BY_HOP_OPTIONS: Ipv6NextHeader = Ipv6NextHeader(0);
const ROUTING: Ipv6NextHeader = Ipv6NextHeader(43);
const FRAGMENT: Ipv6NextHeader = Ipv6NextHeader(44);
const DESTINATION_OPTIONS: Ipv6NextHeader = Ipv6NextHeader(60);
const NO_NEXT_HEADER: Ipv6NextHeader = Ipv6NextHeader(59);
}

impl Ipv6NextHeader {
fn length(&self, data: &[u8]) -> Option<usize> {
match *self {
Self::HOP_BY_HOP_OPTIONS => Some(data[1] as usize + 1),
Self::ROUTING => Some(data[1] as usize + 1),
Self::FRAGMENT => Some(8),
Self::DESTINATION_OPTIONS => Some(data[1] as usize + 1),
Self::NO_NEXT_HEADER => None,
_ => None,
}
}

fn min_bytes(&self) -> usize {
match *self {
Self::HOP_BY_HOP_OPTIONS => 2,
Self::ROUTING => 2,
Self::FRAGMENT => 8,
Self::DESTINATION_OPTIONS => 2,
Self::NO_NEXT_HEADER => 0,
_ => 0,
}
}
}

impl IpHeader {
pub fn from_packet(data: &[u8]) -> Result<IpHeader, EspError> {
if data.is_empty() {
return Err("IP packet is empty, cannot extract header data".into());
}
match data[0] >> 4 {
4 => Self::from_ipv4_packet(data),
6 => Self::from_ipv6_packet(data),
_ => {
warn!("ESP IP packet is not a supported IP version: {:x}", data[0]);
return Err("Unsupported IP prococol version".into());
}
}
}

fn from_ipv4_packet(data: &[u8]) -> Result<IpHeader, EspError> {
if data.len() < 20 {
return Err("Not enough bytes in IPv4 header".into());
}
let header_length = (data[0] & 0x0f) as usize * 4;
if data.len() < header_length {
return Err("IPv4 header length overflow".into());
}
let transport_protocol = message::IPProtocolType::from_u8(data[9]);
let (src_port, dst_port) = match transport_protocol {
message::IPProtocolType::TCP | message::IPProtocolType::UDP => {
Self::extract_ports(&data[header_length..])?
}
message::IPProtocolType::ANY => return Err("IPv4 protocol is 0".into()),
_ => (None, None),
};
let mut src_addr = [0u8; 4];
src_addr.copy_from_slice(&data[12..16]);
let src_addr = IpAddr::V4(Ipv4Addr::from(src_addr));
let mut dst_addr = [0u8; 4];
dst_addr.copy_from_slice(&data[16..20]);
let dst_addr = IpAddr::V4(Ipv4Addr::from(dst_addr));
Ok(IpHeader {
src_addr,
dst_addr,
src_port,
dst_port,
transport_protocol,
})
}

fn from_ipv6_packet(data: &[u8]) -> Result<IpHeader, EspError> {
if data.len() < 40 {
return Err("Not enough bytes in IPv6 header".into());
}
// TODO: test that this works.
let mut next_header = Ipv6NextHeader(data[6]);
let mut next_header_start = 40;
loop {
if next_header_start + next_header.min_bytes() > data.len() {
return Err("IPv6 header length overlow".into());
}
if let Some(header_length) = next_header.length(&data[next_header_start..]) {
next_header = Ipv6NextHeader(data[next_header_start]);
next_header_start += header_length;
} else {
break;
}
}
let transport_protocol = message::IPProtocolType::from_u8(next_header.0);
let (src_port, dst_port) = match transport_protocol {
message::IPProtocolType::TCP | message::IPProtocolType::UDP => {
Self::extract_ports(&data[next_header_start..])?
}
message::IPProtocolType::ANY => return Err("IPv4 protocol is 0".into()),
_ => (None, None),
};
let mut src_addr = [0u8; 16];
src_addr.copy_from_slice(&data[8..24]);
let src_addr = IpAddr::V6(Ipv6Addr::from(src_addr));
let mut dst_addr = [0u8; 16];
dst_addr.copy_from_slice(&data[24..40]);
let dst_addr = IpAddr::V6(Ipv6Addr::from(dst_addr));
Ok(IpHeader {
src_addr,
dst_addr,
src_port,
dst_port,
transport_protocol,
})
}

fn extract_ports(data: &[u8]) -> Result<(Option<u16>, Option<u16>), EspError> {
if data.len() < 4 {
return Err("Not enough data in transport layer to extract ports".into());
}
let mut src_port = [0u8; 2];
src_port.copy_from_slice(&data[0..2]);
let src_port = u16::from_be_bytes(src_port);
let mut dst_port = [0u8; 2];
dst_port.copy_from_slice(&data[2..4]);
let dst_port = u16::from_be_bytes(dst_port);
Ok((Some(src_port), Some(dst_port)))
}
}

impl fmt::Display for IpHeader {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(src_port) = self.src_port {
write!(
f,
"{} {}:{} -> ",
self.transport_protocol, self.src_addr, src_port
)?;
} else {
write!(f, "{} {} -> ", self.transport_protocol, self.src_addr)?;
}
if let Some(dst_port) = self.dst_port {
write!(f, "{}:{}", self.dst_addr, dst_port)?;
} else {
write!(f, "{}", self.dst_addr)?;
}
Ok(())
}
}

#[derive(Debug)]
pub enum EspError {
Internal(&'static str),
Expand Down
2 changes: 1 addition & 1 deletion src/ikev2/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1977,7 +1977,7 @@ impl IPProtocolType {
pub const TCP: IPProtocolType = IPProtocolType(6);
pub const UDP: IPProtocolType = IPProtocolType(17);

fn from_u8(value: u8) -> IPProtocolType {
pub fn from_u8(value: u8) -> IPProtocolType {
IPProtocolType(value)
}
}
Expand Down
27 changes: 22 additions & 5 deletions src/ikev2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,11 @@ impl Sessions {
datagram.remote_addr,
decrypted_slice
);
let hdr = esp::IpHeader::from_packet(decrypted_slice)?;
trace!("IP header {}", hdr);
if !sa.accepts_esp_to_vpn(&hdr) {
return Err("ESP packet dropped by traffic selector".into());
}
if decrypted_slice.len() > MAX_ESP_PACKET_SIZE {
warn!(
"Decrypted packet size {} exceeds MTU {}",
Expand All @@ -740,10 +745,22 @@ impl Sessions {
}

async fn process_vpn_packet(&mut self, mut data: Vec<u8>) -> Result<(), IKEv2Error> {
// TODO: select SA based on packet data.
// TODO: log protocol & address data
trace!("Received packet from VPN\n{:?}", data);
if let Some(sa) = self.security_associations.values_mut().next() {
let hdr = match esp::IpHeader::from_packet(&data) {
Ok(hdr) => hdr,
Err(err) => {
warn!(
"Failed to read header in IP packet from VPN: {}\n{:?}",
err, data
);
return Err("Failed to read header in IP packet from VPN".into());
}
};
trace!("Received packet from VPN {}\n{:?}", hdr, data);
if let Some(sa) = self
.security_associations
.values_mut()
.find(|sa| sa.accepts_vpn_to_esp(&hdr))
{
let msg_len = data.len();
if data.len() >= MAX_ESP_PACKET_SIZE {
return Err("Vector doesn't have capacity for ESP headers".into());
Expand All @@ -760,7 +777,7 @@ impl Sessions {
.await?;
Ok(())
} else {
Err("Target Security Association not found".into())
Err("No matchig Security Associations found".into())
}
}
}
Expand Down

0 comments on commit d3d7253

Please sign in to comment.