diff --git a/Cargo.lock b/Cargo.lock index 866f8e719..c542522c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -484,6 +484,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +[[package]] +name = "hex-literal" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" + [[package]] name = "home" version = "0.5.5" @@ -1441,6 +1447,7 @@ dependencies = [ "derive_more", "dns-lookup", "etcetera", + "hex-literal", "humantime", "indexmap", "itertools", diff --git a/Cargo.toml b/Cargo.toml index 89c5fda38..23732fe3d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,6 +74,7 @@ windows-sys = { version = "0.48.0", features = [ ] } [dev-dependencies] +hex-literal = "0.4.1" rand = "0.8.5" test-case = "3.2.1" diff --git a/src/tracing/constants.rs b/src/tracing/constants.rs index 32bfba349..7919741eb 100644 --- a/src/tracing/constants.rs +++ b/src/tracing/constants.rs @@ -14,3 +14,9 @@ pub const MAX_SEQUENCE_PER_ROUND: u16 = 1024; /// require two rounds to ensure that delayed probe responses from the immediate prior round can be /// detected and excluded. pub const MAX_SEQUENCE: u16 = u16::MAX - (MAX_SEQUENCE_PER_ROUND * 2); + +/// The maximum number of extensions allowed per probe response. +pub const MAX_EXTENSIONS_PER_PROBE: usize = 8; + +/// The maximum number of members allowed per MPLS stack. +pub const MAX_MPLS_MEMBERS_PER_STACK: usize = 8; diff --git a/src/tracing/net.rs b/src/tracing/net.rs index 3c047d9e0..d2b627ea0 100644 --- a/src/tracing/net.rs +++ b/src/tracing/net.rs @@ -8,6 +8,9 @@ mod ipv4; /// IPv6 implementation. mod ipv6; +/// ICMP extensions. +mod extension; + /// Platform specific network code. mod platform; diff --git a/src/tracing/net/channel.rs b/src/tracing/net/channel.rs index 035fe0629..9243e2718 100644 --- a/src/tracing/net/channel.rs +++ b/src/tracing/net/channel.rs @@ -91,7 +91,7 @@ impl Network for TracerChannel { resp => Ok(resp), }, }?; - if let Some(resp) = prob_response { + if let Some(resp) = &prob_response { tracing::debug!(?resp); } Ok(prob_response) diff --git a/src/tracing/net/extension.rs b/src/tracing/net/extension.rs new file mode 100644 index 000000000..bcfb170f7 --- /dev/null +++ b/src/tracing/net/extension.rs @@ -0,0 +1,61 @@ +use crate::tracing::constants::{MAX_EXTENSIONS_PER_PROBE, MAX_MPLS_MEMBERS_PER_STACK}; +use crate::tracing::error::TracerError; +use crate::tracing::packet::icmp_extension::extension_object::ClassNum; +use crate::tracing::packet::icmp_extension::extension_structure::ExtensionStructure; +use crate::tracing::packet::icmp_extension::mpls_label_stack::MplsLabelStack; +use crate::tracing::packet::icmp_extension::mpls_label_stack_member::MplsLabelStackMember; +use crate::tracing::probe::{ + MplsExtensionData, MplsExtensionMember, ProbeResponseExtension, ProbeResponseExtensions, +}; +use crate::tracing::util::Required; + +impl TryFrom<&[u8]> for ProbeResponseExtensions { + type Error = TracerError; + + fn try_from(value: &[u8]) -> Result { + Self::try_from(ExtensionStructure::new_view(value).req()?) + } +} + +impl TryFrom> for ProbeResponseExtensions { + type Error = TracerError; + + fn try_from(value: ExtensionStructure<'_>) -> Result { + let extensions = value + .objects() + .take(MAX_EXTENSIONS_PER_PROBE) + .map(|obj| match obj.get_class_num() { + ClassNum::MultiProtocolLabelSwitchingLabelStack => { + MplsLabelStack::new_view(obj.payload()) + .req() + .map(|mpls| ProbeResponseExtension::Mpls(MplsExtensionData::from(mpls))) + } + _ => Ok(ProbeResponseExtension::Unknown), + }) + .collect::>()?; + Ok(Self { extensions }) + } +} + +impl From> for MplsExtensionData { + fn from(value: MplsLabelStack<'_>) -> Self { + Self { + members: value + .members() + .take(MAX_MPLS_MEMBERS_PER_STACK) + .map(MplsExtensionMember::from) + .collect(), + } + } +} + +impl From> for MplsExtensionMember { + fn from(value: MplsLabelStackMember<'_>) -> Self { + Self { + label: value.get_label(), + exp: value.get_exp(), + bos: value.get_bos(), + ttl: value.get_ttl(), + } + } +} diff --git a/src/tracing/net/ipv4.rs b/src/tracing/net/ipv4.rs index 4ae744fe6..f1494e70f 100644 --- a/src/tracing/net/ipv4.rs +++ b/src/tracing/net/ipv4.rs @@ -15,8 +15,8 @@ use crate::tracing::packet::tcp::TcpPacket; use crate::tracing::packet::udp::UdpPacket; use crate::tracing::packet::IpProtocol; use crate::tracing::probe::{ - ProbeResponse, ProbeResponseData, ProbeResponseSeq, ProbeResponseSeqIcmp, ProbeResponseSeqTcp, - ProbeResponseSeqUdp, + ProbeResponse, ProbeResponseData, ProbeResponseExtensions, ProbeResponseSeq, + ProbeResponseSeqIcmp, ProbeResponseSeqTcp, ProbeResponseSeqUdp, }; use crate::tracing::types::{PacketSize, PayloadPattern, Sequence, TraceId, TypeOfService}; use crate::tracing::util::Required; @@ -229,11 +229,10 @@ pub fn recv_tcp_socket( } if platform::is_host_unreachable_error(code) { let error_addr = tcp_socket.icmp_error_info()?; - return Ok(Some(ProbeResponse::TimeExceeded(ProbeResponseData::new( - SystemTime::now(), - error_addr, - resp_seq, - )))); + return Ok(Some(ProbeResponse::TimeExceeded( + ProbeResponseData::new(SystemTime::now(), error_addr, resp_seq), + None, + ))); } } } @@ -332,16 +331,28 @@ fn extract_probe_resp( Ok(match icmp_v4.get_icmp_type() { IcmpType::TimeExceeded => { let packet = TimeExceededPacket::new_view(icmp_v4.packet()).req()?; - let resp_seq = extract_probe_resp_seq(packet.payload(), protocol)?; - Some(ProbeResponse::TimeExceeded(ProbeResponseData::new( - recv, src, resp_seq, - ))) + let payload = packet.payload(); + let extension = packet + .extension() + .map(ProbeResponseExtensions::try_from) + .transpose()?; + let resp_seq = extract_probe_resp_seq(payload, protocol)?; + Some(ProbeResponse::TimeExceeded( + ProbeResponseData::new(recv, src, resp_seq), + extension, + )) } IcmpType::DestinationUnreachable => { let packet = DestinationUnreachablePacket::new_view(icmp_v4.packet()).req()?; - let resp_seq = extract_probe_resp_seq(packet.payload(), protocol)?; + let payload = packet.payload(); + let extension = packet + .extension() + .map(ProbeResponseExtensions::try_from) + .transpose()?; + let resp_seq = extract_probe_resp_seq(payload, protocol)?; Some(ProbeResponse::DestinationUnreachable( ProbeResponseData::new(recv, src, resp_seq), + extension, )) } IcmpType::EchoReply => match protocol { diff --git a/src/tracing/net/ipv6.rs b/src/tracing/net/ipv6.rs index 6a00170d9..991a98701 100644 --- a/src/tracing/net/ipv6.rs +++ b/src/tracing/net/ipv6.rs @@ -185,11 +185,10 @@ pub fn recv_tcp_socket( } if platform::is_host_unreachable_error(code) { let error_addr = tcp_socket.icmp_error_info()?; - return Ok(Some(ProbeResponse::TimeExceeded(ProbeResponseData::new( - SystemTime::now(), - error_addr, - resp_seq, - )))); + return Ok(Some(ProbeResponse::TimeExceeded( + ProbeResponseData::new(SystemTime::now(), error_addr, resp_seq), + None, + ))); } } } @@ -263,15 +262,17 @@ fn extract_probe_resp( IcmpType::TimeExceeded => { let packet = TimeExceededPacket::new_view(icmp_v6.packet()).req()?; let resp_seq = extract_probe_resp_seq(packet.payload(), protocol)?; - Some(ProbeResponse::TimeExceeded(ProbeResponseData::new( - recv, ip, resp_seq, - ))) + Some(ProbeResponse::TimeExceeded( + ProbeResponseData::new(recv, ip, resp_seq), + None, + )) } IcmpType::DestinationUnreachable => { let packet = DestinationUnreachablePacket::new_view(icmp_v6.packet()).req()?; let resp_seq = extract_probe_resp_seq(packet.payload(), protocol)?; Some(ProbeResponse::DestinationUnreachable( ProbeResponseData::new(recv, ip, resp_seq), + None, )) } IcmpType::EchoReply => match protocol { diff --git a/src/tracing/packet.rs b/src/tracing/packet.rs index ab9f56160..29e742a15 100644 --- a/src/tracing/packet.rs +++ b/src/tracing/packet.rs @@ -9,6 +9,9 @@ pub mod icmpv4; /// `ICMPv6` packets. pub mod icmpv6; +/// ICMP extensions +pub mod icmp_extension; + /// `IPv4` packets. pub mod ipv4; @@ -21,7 +24,8 @@ pub mod udp; /// `TCP` packets. pub mod tcp; -fn fmt_payload(bytes: &[u8]) -> String { +#[must_use] +pub fn fmt_payload(bytes: &[u8]) -> String { use itertools::Itertools as _; format!("{:02x}", bytes.iter().format(" ")) } diff --git a/src/tracing/packet/icmp_extension.rs b/src/tracing/packet/icmp_extension.rs new file mode 100644 index 000000000..9ebea3ee9 --- /dev/null +++ b/src/tracing/packet/icmp_extension.rs @@ -0,0 +1,950 @@ +pub mod extension_structure { + use crate::tracing::packet::buffer::Buffer; + use crate::tracing::packet::icmp_extension::extension_object::ExtensionObject; + + /// Represents an ICMP `ExtensionStructure` pseudo object. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct ExtensionStructure<'a> { + buf: Buffer<'a>, + } + + impl<'a> ExtensionStructure<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + + // TODO return Option here or &[u8]? + #[must_use] + pub fn header(&self) -> &[u8] { + &self.buf.as_slice()[..Self::minimum_packet_size()] + } + + /// An iterator of Extension Objects contained within this `ExtensionStructure`. + #[must_use] + pub fn objects(&self) -> ExtensionObjectIter<'_> { + ExtensionObjectIter::new(&self.buf) + } + } + + pub struct ExtensionObjectIter<'a> { + buf: &'a Buffer<'a>, + offset: usize, + } + + impl<'a> ExtensionObjectIter<'a> { + #[must_use] + pub fn new(buf: &'a Buffer<'_>) -> Self { + Self { + buf, + offset: ExtensionStructure::minimum_packet_size(), + } + } + } + + impl<'a> Iterator for ExtensionObjectIter<'a> { + type Item = ExtensionObject<'a>; // TODO or return &[u8]? + + fn next(&mut self) -> Option { + if self.offset >= self.buf.as_slice().len() { + None + } else { + // TODO check for edge cases here + ExtensionObject::new_view(&self.buf.as_slice()[self.offset..]).map(|obj| { + self.offset += usize::from(obj.get_length()); + obj + }) + } + } + } + + #[cfg(test)] + mod tests { + use super::*; + use crate::tracing::packet::icmp_extension::extension_header::ExtensionHeader; + use crate::tracing::packet::icmp_extension::extension_object::{ClassNum, ClassSubType}; + + #[test] + fn test_header() { + let buf = [ + 0x20, 0x00, 0x99, 0x3a, 0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01, + ]; + let extensions = ExtensionStructure::new_view(&buf).unwrap(); + let header = ExtensionHeader::new_view(extensions.header()).unwrap(); + assert_eq!(2, header.get_version()); + assert_eq!(0x993A, header.get_checksum()); + } + + #[test] + fn test_object_iterator() { + let buf = [ + 0x20, 0x00, 0x99, 0x3a, 0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01, + ]; + let extensions = ExtensionStructure::new_view(&buf).unwrap(); + let mut object_iter = extensions.objects(); + let object = object_iter.next().unwrap(); + assert_eq!(8, object.get_length()); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + object.get_class_num() + ); + assert_eq!(ClassSubType(1), object.get_class_subtype()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], object.payload()); + assert!(object_iter.next().is_none()); + } + } +} + +pub mod extension_header { + use crate::tracing::packet::buffer::Buffer; + use std::fmt::{Debug, Formatter}; + + const VERSION_OFFSET: usize = 0; + const CHECKSUM_OFFSET: usize = 2; + + /// Represents an ICMP `ExtensionHeader`. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct ExtensionHeader<'a> { + buf: Buffer<'a>, + } + + impl<'a> ExtensionHeader<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + #[must_use] + pub fn get_version(&self) -> u8 { + (self.buf.read(VERSION_OFFSET) & 0xf0) >> 4 + } + + #[must_use] + pub fn get_checksum(&self) -> u16 { + u16::from_be_bytes(self.buf.get_bytes(CHECKSUM_OFFSET)) + } + + pub fn set_version(&mut self, val: u8) { + *self.buf.write(VERSION_OFFSET) = + (self.buf.read(VERSION_OFFSET) & 0xf) | ((val & 0xf) << 4); + } + + pub fn set_checksum(&mut self, val: u16) { + self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + } + + impl Debug for ExtensionHeader<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExtensionHeader") + .field("version", &self.get_version()) + .field("checksum", &self.get_checksum()) + // .field("payload", &fmt_payload(self.payload())) + .finish() + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_version() { + let mut buf = [0_u8; ExtensionHeader::minimum_packet_size()]; + let mut extension = ExtensionHeader::new(&mut buf).unwrap(); + extension.set_version(0); + assert_eq!(0, extension.get_version()); + assert_eq!([0x00], extension.packet()[0..1]); + extension.set_version(2); + assert_eq!(2, extension.get_version()); + assert_eq!([0x20], extension.packet()[0..1]); + extension.set_version(15); + assert_eq!(15, extension.get_version()); + assert_eq!([0xF0], extension.packet()[0..1]); + } + + #[test] + fn test_checksum() { + let mut buf = [0_u8; ExtensionHeader::minimum_packet_size()]; + let mut extension = ExtensionHeader::new(&mut buf).unwrap(); + extension.set_checksum(0); + assert_eq!(0, extension.get_checksum()); + assert_eq!([0x00, 0x00], extension.packet()[2..=3]); + extension.set_checksum(1999); + assert_eq!(1999, extension.get_checksum()); + assert_eq!([0x07, 0xCF], extension.packet()[2..=3]); + extension.set_checksum(39226); + assert_eq!(39226, extension.get_checksum()); + assert_eq!([0x99, 0x3A], extension.packet()[2..=3]); + extension.set_checksum(u16::MAX); + assert_eq!(u16::MAX, extension.get_checksum()); + assert_eq!([0xFF, 0xFF], extension.packet()[2..=3]); + } + + #[test] + fn test_extension_header_view() { + let buf = [ + 0x20, 0x00, 0x99, 0x3a, 0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01, + ]; + let extension = ExtensionHeader::new_view(&buf).unwrap(); + assert_eq!(2, extension.get_version()); + assert_eq!(0x993A, extension.get_checksum()); + } + } +} + +pub mod extension_object { + use crate::tracing::packet::buffer::Buffer; + use crate::tracing::packet::fmt_payload; + use std::fmt::{Debug, Formatter}; + + /// The ICMP Extension Object Class Num. + #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] + pub enum ClassNum { + MultiProtocolLabelSwitchingLabelStack, + InterfaceInformationObject, + InterfaceIdentificationObject, + ExtendedInformation, + Other(u8), + } + + impl ClassNum { + #[must_use] + pub fn id(&self) -> u8 { + match self { + Self::MultiProtocolLabelSwitchingLabelStack => 1, + Self::InterfaceInformationObject => 2, + Self::InterfaceIdentificationObject => 3, + Self::ExtendedInformation => 4, + Self::Other(id) => *id, + } + } + } + + impl From for ClassNum { + fn from(val: u8) -> Self { + match val { + 1 => Self::MultiProtocolLabelSwitchingLabelStack, + 2 => Self::InterfaceInformationObject, + 3 => Self::InterfaceIdentificationObject, + 4 => Self::ExtendedInformation, + id => Self::Other(id), + } + } + } + + /// The ICMP Extension Object Class Sub-type. + #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] + pub struct ClassSubType(pub u8); + + impl From for ClassSubType { + fn from(val: u8) -> Self { + Self(val) + } + } + + const LENGTH_OFFSET: usize = 0; + const CLASS_NUM_OFFSET: usize = 2; + const CLASS_SUBTYPE_OFFSET: usize = 3; + + /// Represents an ICMP `ExtensionObject`. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct ExtensionObject<'a> { + buf: Buffer<'a>, + } + + impl<'a> ExtensionObject<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + pub fn set_length(&mut self, val: u16) { + self.buf.set_bytes(LENGTH_OFFSET, val.to_be_bytes()); + } + + pub fn set_class_num(&mut self, val: ClassNum) { + *self.buf.write(CLASS_NUM_OFFSET) = val.id(); + } + + pub fn set_class_subtype(&mut self, val: ClassSubType) { + *self.buf.write(CLASS_SUBTYPE_OFFSET) = val.0; + } + + pub fn set_payload(&mut self, vals: &[u8]) { + let current_offset = Self::minimum_packet_size(); + self.buf.as_slice_mut()[current_offset..current_offset + vals.len()] + .copy_from_slice(vals); + } + + #[must_use] + pub fn get_length(&self) -> u16 { + u16::from_be_bytes(self.buf.get_bytes(LENGTH_OFFSET)) + } + + #[must_use] + pub fn get_class_num(&self) -> ClassNum { + ClassNum::from(self.buf.read(CLASS_NUM_OFFSET)) + } + + #[must_use] + pub fn get_class_subtype(&self) -> ClassSubType { + ClassSubType::from(self.buf.read(CLASS_SUBTYPE_OFFSET)) + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + + // TODO should use the length here to get the payload for this object only + #[must_use] + pub fn payload(&self) -> &[u8] { + &self.buf.as_slice()[Self::minimum_packet_size()..] + } + } + + impl Debug for ExtensionObject<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExtensionObject") + .field("length", &self.get_length()) + .field("class_num", &self.get_class_num()) + .field("class_subtype", &self.get_class_subtype()) + .field("payload", &fmt_payload(self.payload())) + .finish() + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_length() { + let mut buf = [0_u8; ExtensionObject::minimum_packet_size()]; + let mut extension = ExtensionObject::new(&mut buf).unwrap(); + extension.set_length(0); + assert_eq!(0, extension.get_length()); + assert_eq!([0x00, 0x00], extension.packet()[0..=1]); + extension.set_length(8); + assert_eq!(8, extension.get_length()); + assert_eq!([0x00, 0x08], extension.packet()[0..=1]); + extension.set_length(u16::MAX); + assert_eq!(u16::MAX, extension.get_length()); + assert_eq!([0xFF, 0xFF], extension.packet()[0..=1]); + } + + #[test] + fn test_class_num() { + let mut buf = [0_u8; ExtensionObject::minimum_packet_size()]; + let mut extension = ExtensionObject::new(&mut buf).unwrap(); + extension.set_class_num(ClassNum::MultiProtocolLabelSwitchingLabelStack); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + extension.get_class_num() + ); + assert_eq!([0x01], extension.packet()[2..3]); + extension.set_class_num(ClassNum::InterfaceInformationObject); + assert_eq!( + ClassNum::InterfaceInformationObject, + extension.get_class_num() + ); + assert_eq!([0x02], extension.packet()[2..3]); + extension.set_class_num(ClassNum::InterfaceIdentificationObject); + assert_eq!( + ClassNum::InterfaceIdentificationObject, + extension.get_class_num() + ); + assert_eq!([0x03], extension.packet()[2..3]); + extension.set_class_num(ClassNum::ExtendedInformation); + assert_eq!(ClassNum::ExtendedInformation, extension.get_class_num()); + assert_eq!([0x04], extension.packet()[2..3]); + extension.set_class_num(ClassNum::Other(255)); + assert_eq!(ClassNum::Other(255), extension.get_class_num()); + assert_eq!([0xFF], extension.packet()[2..3]); + } + + #[test] + fn test_class_subtype() { + let mut buf = [0_u8; ExtensionObject::minimum_packet_size()]; + let mut extension = ExtensionObject::new(&mut buf).unwrap(); + extension.set_class_subtype(ClassSubType(0)); + assert_eq!(ClassSubType(0), extension.get_class_subtype()); + assert_eq!([0x00], extension.packet()[3..4]); + extension.set_class_subtype(ClassSubType(1)); + assert_eq!(ClassSubType(1), extension.get_class_subtype()); + assert_eq!([0x01], extension.packet()[3..4]); + extension.set_class_subtype(ClassSubType(255)); + assert_eq!(ClassSubType(255), extension.get_class_subtype()); + assert_eq!([0xff], extension.packet()[3..4]); + } + + #[test] + fn test_extension_header_view() { + let buf = [0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01]; + let object = ExtensionObject::new_view(&buf).unwrap(); + assert_eq!(8, object.get_length()); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + object.get_class_num() + ); + assert_eq!(ClassSubType(1), object.get_class_subtype()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], object.payload()); + } + } +} + +pub mod mpls_label_stack { + use crate::tracing::packet::buffer::Buffer; + use crate::tracing::packet::icmp_extension::mpls_label_stack_member::MplsLabelStackMember; + + /// Represents an ICMP `MplsLabelStack`. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct MplsLabelStack<'a> { + buf: Buffer<'a>, + } + + impl<'a> MplsLabelStack<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + + #[must_use] + pub fn members(&self) -> MplsLabelStackIter<'_> { + MplsLabelStackIter::new(&self.buf) + } + } + + pub struct MplsLabelStackIter<'a> { + buf: &'a Buffer<'a>, + offset: usize, + bos: u8, + } + + impl<'a> MplsLabelStackIter<'a> { + #[must_use] + pub fn new(buf: &'a Buffer<'_>) -> Self { + Self { + buf, + offset: 0, + bos: 0, + } + } + } + + impl<'a> Iterator for MplsLabelStackIter<'a> { + type Item = MplsLabelStackMember<'a>; + + fn next(&mut self) -> Option { + if self.bos > 0 || self.offset >= self.buf.as_slice().len() { + None + } else { + // TODO check for edge cases here + MplsLabelStackMember::new_view(&self.buf.as_slice()[self.offset..]).map(|obj| { + self.bos = obj.get_bos(); + self.offset += MplsLabelStackMember::minimum_packet_size(); + obj + }) + } + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_stack_member_iterator() { + let buf = [0x04, 0xbb, 0x41, 0x01]; + let stack = MplsLabelStack::new_view(&buf).unwrap(); + let mut member_iter = stack.members(); + let member = member_iter.next().unwrap(); + assert_eq!(19380, member.get_label()); + assert_eq!(0, member.get_exp()); + assert_eq!(1, member.get_bos()); + assert_eq!(1, member.get_ttl()); + assert!(member_iter.next().is_none()); + } + } +} + +pub mod mpls_label_stack_member { + use crate::tracing::packet::buffer::Buffer; + use std::fmt::{Debug, Formatter}; + + const LABEL_OFFSET: usize = 0; + const EXP_OFFSET: usize = 2; + const BOS_OFFSET: usize = 2; + const TTL_OFFSET: usize = 3; + + /// Represents an ICMP `MplsLabelStackMember`. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct MplsLabelStackMember<'a> { + buf: Buffer<'a>, + } + + impl<'a> MplsLabelStackMember<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + #[must_use] + pub fn get_label(&self) -> u32 { + u32::from_be_bytes([ + 0x0, + self.buf.read(LABEL_OFFSET), + self.buf.read(LABEL_OFFSET + 1), + self.buf.read(LABEL_OFFSET + 2), + ]) >> 4 + } + + #[must_use] + pub fn get_exp(&self) -> u8 { + (self.buf.read(EXP_OFFSET) & 0x0e) >> 1 + } + + #[must_use] + pub fn get_bos(&self) -> u8 { + self.buf.read(BOS_OFFSET) & 0x01 + } + + #[must_use] + pub fn get_ttl(&self) -> u8 { + self.buf.read(TTL_OFFSET) + } + + pub fn set_label(&mut self, val: u32) { + let bytes = (val << 4).to_be_bytes(); + *self.buf.write(LABEL_OFFSET) = bytes[1]; + *self.buf.write(LABEL_OFFSET + 1) = bytes[2]; + *self.buf.write(LABEL_OFFSET + 2) = + (self.buf.read(LABEL_OFFSET + 2) & 0x0f) | (bytes[3] & 0xf0); + } + + pub fn set_exp(&mut self, exp: u8) { + *self.buf.write(EXP_OFFSET) = (self.buf.read(EXP_OFFSET) & 0xf1) | ((exp << 1) & 0x0e); + } + + pub fn set_bos(&mut self, bos: u8) { + *self.buf.write(BOS_OFFSET) = (self.buf.read(BOS_OFFSET) & 0xfe) | (bos & 0x01); + } + + pub fn set_ttl(&mut self, ttl: u8) { + *self.buf.write(TTL_OFFSET) = ttl; + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + } + + impl Debug for MplsLabelStackMember<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MplsLabelStackMember") + .field("label", &self.get_label()) + .field("exp", &self.get_exp()) + .field("bos", &self.get_bos()) + .field("ttl", &self.get_ttl()) + .finish() + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_label() { + let mut buf = [0_u8; MplsLabelStackMember::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMember::new(&mut buf).unwrap(); + mpls_extension.set_label(0); + assert_eq!(0, mpls_extension.get_label()); + assert_eq!([0x00, 0x00, 0x00], mpls_extension.packet()[0..3]); + mpls_extension.set_label(19380); + assert_eq!(19380, mpls_extension.get_label()); + assert_eq!([0x04, 0xbb, 0x40], mpls_extension.packet()[0..3]); + mpls_extension.set_label(1_048_575); + assert_eq!(1_048_575, mpls_extension.get_label()); + assert_eq!([0xff, 0xff, 0xf0], mpls_extension.packet()[0..3]); + } + + #[test] + fn test_exp() { + let mut buf = [0_u8; MplsLabelStackMember::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMember::new(&mut buf).unwrap(); + mpls_extension.set_exp(0); + assert_eq!(0, mpls_extension.get_exp()); + assert_eq!([0x00], mpls_extension.packet()[2..3]); + mpls_extension.set_exp(7); + assert_eq!(7, mpls_extension.get_exp()); + assert_eq!([0x0e], mpls_extension.packet()[2..3]); + } + + #[test] + fn test_bos() { + let mut buf = [0_u8; MplsLabelStackMember::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMember::new(&mut buf).unwrap(); + mpls_extension.set_bos(0); + assert_eq!(0, mpls_extension.get_bos()); + assert_eq!([0x00], mpls_extension.packet()[2..3]); + mpls_extension.set_bos(1); + assert_eq!(1, mpls_extension.get_bos()); + assert_eq!([0x01], mpls_extension.packet()[2..3]); + } + + #[test] + fn test_ttl() { + let mut buf = [0_u8; MplsLabelStackMember::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMember::new(&mut buf).unwrap(); + mpls_extension.set_ttl(0); + assert_eq!(0, mpls_extension.get_ttl()); + assert_eq!([0x00], mpls_extension.packet()[3..4]); + mpls_extension.set_ttl(1); + assert_eq!(1, mpls_extension.get_ttl()); + assert_eq!([0x01], mpls_extension.packet()[3..4]); + mpls_extension.set_ttl(255); + assert_eq!(255, mpls_extension.get_ttl()); + assert_eq!([0xff], mpls_extension.packet()[3..4]); + } + + #[test] + fn test_combined() { + let mut buf = [0_u8; MplsLabelStackMember::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMember::new(&mut buf).unwrap(); + mpls_extension.set_label(19380); + mpls_extension.set_exp(0); + mpls_extension.set_bos(1); + mpls_extension.set_ttl(1); + assert_eq!(19380, mpls_extension.get_label()); + assert_eq!(0, mpls_extension.get_exp()); + assert_eq!(1, mpls_extension.get_bos()); + assert_eq!(1, mpls_extension.get_ttl()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], mpls_extension.packet()[0..4]); + mpls_extension.set_label(1_048_575); + mpls_extension.set_exp(7); + mpls_extension.set_bos(1); + mpls_extension.set_ttl(255); + assert_eq!(1_048_575, mpls_extension.get_label()); + assert_eq!(7, mpls_extension.get_exp()); + assert_eq!(1, mpls_extension.get_bos()); + assert_eq!(255, mpls_extension.get_ttl()); + assert_eq!([0xff, 0xff, 0xff, 0xff], mpls_extension.packet()[0..4]); + } + + #[test] + fn test_view() { + let buf = [0x04, 0xbb, 0x41, 0x01]; + let object = MplsLabelStackMember::new_view(&buf).unwrap(); + assert_eq!(19380, object.get_label()); + assert_eq!(0, object.get_exp()); + assert_eq!(1, object.get_bos()); + assert_eq!(1, object.get_ttl()); + } + } +} + +pub mod extension_splitter { + const ICMP_ORIG_DATAGRAM_MIN_LENGTH: usize = 128; + + /// Separate an ICMP payload from ICMP extensions as defined in rfc4884. + /// + /// Applies to `TimeExceeded` and `DestinationUnreachable` ICMP messages only. + #[must_use] + pub fn split(rfc4884_length: u8, icmp_payload: &[u8]) -> (&[u8], Option<&[u8]>) { + let orig_datagram_length = usize::from(rfc4884_length * 4); + + // TODO what to do if the claimed orig_datagram_length is bigger than the actual payload? + // we could truncate or we can err or we could return empty? + if orig_datagram_length > icmp_payload.len() { + return (&[], None); + } + + if orig_datagram_length > 0 { + // compliant message case + if icmp_payload.len() > orig_datagram_length { + // extension case (untested): the icmp_payload is longer than the orig_datagram and so whatever remains must be an extension + let extension_len = icmp_payload.len() - orig_datagram_length; + let extension = + &icmp_payload[orig_datagram_length..orig_datagram_length + extension_len]; + ( + &icmp_payload[..orig_datagram_length - extension_len], + Some(extension), + ) + } else { + (&icmp_payload[..orig_datagram_length], None) + } + // "Specifically, when a TRACEROUTE application operating in non- + // compliant mode receives a sufficiently long ICMP message that does + // not specify a length attribute, it will parse for a valid extension + // header at a fixed location, assuming a 128-octet "original datagram" + // field." + // TODO - have to include length of the extension header here? MTR does + } else if orig_datagram_length == 0 && icmp_payload.len() > ICMP_ORIG_DATAGRAM_MIN_LENGTH { + // extension present, non-compliant message + let extension_len = icmp_payload.len() - ICMP_ORIG_DATAGRAM_MIN_LENGTH; + let extension = &icmp_payload + [ICMP_ORIG_DATAGRAM_MIN_LENGTH..ICMP_ORIG_DATAGRAM_MIN_LENGTH + extension_len]; + ( + &icmp_payload[..icmp_payload.len() - extension_len], + Some(extension), + ) + } else { + // no extension present + (icmp_payload, None) + } + } + + #[cfg(test)] + mod tests { + use crate::tracing::packet::icmp_extension::extension_header::ExtensionHeader; + use crate::tracing::packet::icmp_extension::extension_object::{ClassNum, ClassSubType}; + use crate::tracing::packet::icmp_extension::extension_structure::ExtensionStructure; + use crate::tracing::packet::icmp_extension::mpls_label_stack_member::MplsLabelStackMember; + use crate::tracing::packet::icmpv4::echo_request::EchoRequestPacket; + use crate::tracing::packet::icmpv4::time_exceeded::TimeExceededPacket; + use crate::tracing::packet::icmpv4::{IcmpCode, IcmpType}; + use crate::tracing::packet::ipv4::Ipv4Packet; + use std::net::Ipv4Addr; + + // This ICMP TimeExceeded packet does not have a `length` field and is therefore rfc4884 non-complaint and has a + // single `MPLS` extension object. + #[test] + #[allow(clippy::cognitive_complexity)] + fn test_split_extension_ipv4_time_exceeded_non_compliant_mpls() { + let buf = hex_literal::hex!( + " + 0b 00 f4 ff 00 00 00 00 45 00 00 54 cc 1c 40 00 + 01 01 b5 f4 c0 a8 01 15 5d b8 d8 22 08 00 0f e3 + 65 da 82 42 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 20 00 99 3a 00 08 01 01 + 04 bb 41 01 + " + ); + let time_exceeded_packet = TimeExceededPacket::new_view(&buf).unwrap(); + assert_eq!(IcmpType::TimeExceeded, time_exceeded_packet.get_icmp_type()); + assert_eq!(IcmpCode(0), time_exceeded_packet.get_icmp_code()); + assert_eq!(62719, time_exceeded_packet.get_checksum()); + assert_eq!(0, time_exceeded_packet.get_length()); + assert_eq!(&buf[8..136], time_exceeded_packet.payload()); + assert_eq!(Some(&buf[136..]), time_exceeded_packet.extension()); + + let nested_ipv4 = Ipv4Packet::new_view(&buf[8..136]).unwrap(); + assert_eq!(Ipv4Addr::from([192, 168, 1, 21]), nested_ipv4.get_source()); + assert_eq!( + Ipv4Addr::from([93, 184, 216, 34]), + nested_ipv4.get_destination() + ); + assert_eq!(&buf[28..136], nested_ipv4.payload()); + + let nested_echo = EchoRequestPacket::new_view(nested_ipv4.payload()).unwrap(); + assert_eq!(IcmpCode(0), nested_echo.get_icmp_code()); + assert_eq!(IcmpType::EchoRequest, nested_echo.get_icmp_type()); + assert_eq!(0x0FE3, nested_echo.get_checksum()); + assert_eq!(26074, nested_echo.get_identifier()); + assert_eq!(33346, nested_echo.get_sequence()); + assert_eq!(&buf[36..136], nested_echo.payload()); + + let extensions = + ExtensionStructure::new_view(time_exceeded_packet.extension().unwrap()).unwrap(); + + let extension_header = ExtensionHeader::new_view(extensions.header()).unwrap(); + assert_eq!(2, extension_header.get_version()); + assert_eq!(0x993A, extension_header.get_checksum()); + + let extension_object = extensions.objects().next().unwrap(); + assert_eq!(8, extension_object.get_length()); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + extension_object.get_class_num() + ); + assert_eq!(ClassSubType(1), extension_object.get_class_subtype()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], extension_object.payload()); + + let mpls = MplsLabelStackMember::new_view(extension_object.payload()).unwrap(); + assert_eq!(19380, mpls.get_label()); + assert_eq!(0, mpls.get_exp()); + assert_eq!(1, mpls.get_bos()); + assert_eq!(1, mpls.get_ttl()); + } + + // This ICMP TimeExceeded packet has a rfc4884 complaint `length` field and does not have any ICMP extensions. + #[test] + fn test_split_extension_ipv4_time_exceeded_compliant_no_extension() { + let buf = hex_literal::hex!( + " + 0b 00 f4 ee 00 11 00 00 45 00 00 54 a2 ee 40 00 + 01 01 df 22 c0 a8 01 15 5d b8 d8 22 08 00 0f e1 + 65 da 82 44 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 + " + ); + let time_exceeded_packet = TimeExceededPacket::new_view(&buf).unwrap(); + assert_eq!(IcmpType::TimeExceeded, time_exceeded_packet.get_icmp_type()); + assert_eq!(IcmpCode(0), time_exceeded_packet.get_icmp_code()); + assert_eq!(62702, time_exceeded_packet.get_checksum()); + assert_eq!(17, time_exceeded_packet.get_length()); + assert_eq!(&buf[8..76], time_exceeded_packet.payload()); + assert_eq!(None, time_exceeded_packet.extension()); + + let nested_ipv4 = Ipv4Packet::new_view(&buf[8..76]).unwrap(); + assert_eq!(Ipv4Addr::from([192, 168, 1, 21]), nested_ipv4.get_source()); + assert_eq!( + Ipv4Addr::from([93, 184, 216, 34]), + nested_ipv4.get_destination() + ); + assert_eq!(&buf[28..76], nested_ipv4.payload()); + + let nested_echo = EchoRequestPacket::new_view(nested_ipv4.payload()).unwrap(); + assert_eq!(IcmpCode(0), nested_echo.get_icmp_code()); + assert_eq!(IcmpType::EchoRequest, nested_echo.get_icmp_type()); + assert_eq!(0x0FE1, nested_echo.get_checksum()); + assert_eq!(26074, nested_echo.get_identifier()); + assert_eq!(33348, nested_echo.get_sequence()); + assert_eq!(&buf[36..76], nested_echo.payload()); + } + } +} diff --git a/src/tracing/packet/icmpv4.rs b/src/tracing/packet/icmpv4.rs index 577e41ec0..cba2bbd4c 100644 --- a/src/tracing/packet/icmpv4.rs +++ b/src/tracing/packet/icmpv4.rs @@ -631,12 +631,14 @@ pub mod echo_reply { pub mod time_exceeded { use crate::tracing::packet::buffer::Buffer; use crate::tracing::packet::fmt_payload; + use crate::tracing::packet::icmp_extension::extension_splitter::split; use crate::tracing::packet::icmpv4::{IcmpCode, IcmpType}; use std::fmt::{Debug, Formatter}; const TYPE_OFFSET: usize = 0; const CODE_OFFSET: usize = 1; const CHECKSUM_OFFSET: usize = 2; + const LENGTH_OFFSET: usize = 5; /// Represents an ICMP `TimeExceeded` packet. /// @@ -689,6 +691,11 @@ pub mod time_exceeded { u16::from_be_bytes(self.buf.get_bytes(CHECKSUM_OFFSET)) } + #[must_use] + pub fn get_length(&self) -> u8 { + self.buf.read(LENGTH_OFFSET) + } + pub fn set_icmp_type(&mut self, val: IcmpType) { *self.buf.write(TYPE_OFFSET) = val.id(); } @@ -701,6 +708,10 @@ pub mod time_exceeded { self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); } + pub fn set_length(&mut self, val: u8) { + *self.buf.write(LENGTH_OFFSET) = val; + } + pub fn set_payload(&mut self, vals: &[u8]) { let current_offset = Self::minimum_packet_size(); self.buf.as_slice_mut()[current_offset..current_offset + vals.len()] @@ -714,7 +725,20 @@ pub mod time_exceeded { #[must_use] pub fn payload(&self) -> &[u8] { - &self.buf.as_slice()[Self::minimum_packet_size()..] + let (payload, _) = self.split_payload_extension(); + payload + } + + #[must_use] + pub fn extension(&self) -> Option<&[u8]> { + let (_, extension) = self.split_payload_extension(); + extension + } + + fn split_payload_extension(&self) -> (&[u8], Option<&[u8]>) { + let rfc4884_length = self.get_length(); + let icmp_payload = &self.buf.as_slice()[Self::minimum_packet_size()..]; + split(rfc4884_length, icmp_payload) } } @@ -724,6 +748,7 @@ pub mod time_exceeded { .field("icmp_type", &self.get_icmp_type()) .field("icmp_code", &self.get_icmp_code()) .field("checksum", &self.get_checksum()) + .field("length", &self.get_length()) .field("payload", &fmt_payload(self.payload())) .finish() } @@ -799,13 +824,14 @@ pub mod time_exceeded { pub mod destination_unreachable { use crate::tracing::packet::buffer::Buffer; use crate::tracing::packet::fmt_payload; + use crate::tracing::packet::icmp_extension::extension_splitter::split; use crate::tracing::packet::icmpv4::{IcmpCode, IcmpType}; use std::fmt::{Debug, Formatter}; const TYPE_OFFSET: usize = 0; const CODE_OFFSET: usize = 1; const CHECKSUM_OFFSET: usize = 2; - const UNUSED_OFFSET: usize = 4; + const LENGTH_OFFSET: usize = 5; const NEXT_HOP_MTU_OFFSET: usize = 6; /// Represents an ICMP `DestinationUnreachable` packet. @@ -860,8 +886,8 @@ pub mod destination_unreachable { } #[must_use] - pub fn get_unused(&self) -> u16 { - u16::from_be_bytes(self.buf.get_bytes(UNUSED_OFFSET)) + pub fn get_length(&self) -> u8 { + self.buf.read(LENGTH_OFFSET) } #[must_use] @@ -881,8 +907,8 @@ pub mod destination_unreachable { self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); } - pub fn set_unused(&mut self, val: u16) { - self.buf.set_bytes(UNUSED_OFFSET, val.to_be_bytes()); + pub fn set_length(&mut self, val: u8) { + *self.buf.write(LENGTH_OFFSET) = val; } pub fn set_next_hop_mtu(&mut self, val: u16) { @@ -902,7 +928,20 @@ pub mod destination_unreachable { #[must_use] pub fn payload(&self) -> &[u8] { - &self.buf.as_slice()[Self::minimum_packet_size()..] + let (payload, _) = self.split_payload_extension(); + payload + } + + #[must_use] + pub fn extension(&self) -> Option<&[u8]> { + let (_, extension) = self.split_payload_extension(); + extension + } + + fn split_payload_extension(&self) -> (&[u8], Option<&[u8]>) { + let rfc4884_length = self.get_length(); + let icmp_payload = &self.buf.as_slice()[Self::minimum_packet_size()..]; + split(rfc4884_length, icmp_payload) } } @@ -912,7 +951,7 @@ pub mod destination_unreachable { .field("icmp_type", &self.get_icmp_type()) .field("icmp_code", &self.get_icmp_code()) .field("checksum", &self.get_checksum()) - .field("unused", &self.get_unused()) + .field("length", &self.get_length()) .field("next_hop_mtu", &self.get_next_hop_mtu()) .field("payload", &fmt_payload(self.payload())) .finish() diff --git a/src/tracing/probe.rs b/src/tracing/probe.rs index c9518d5e3..a1b4414bc 100644 --- a/src/tracing/probe.rs +++ b/src/tracing/probe.rs @@ -1,4 +1,6 @@ +use crate::tracing::constants::{MAX_EXTENSIONS_PER_PROBE, MAX_MPLS_MEMBERS_PER_STACK}; use crate::tracing::types::{Port, Round, Sequence, TimeToLive, TraceId}; +use arrayvec::ArrayVec; use std::net::IpAddr; use std::time::{Duration, SystemTime}; @@ -128,17 +130,45 @@ pub enum IcmpPacketType { } /// The response to a probe. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub enum ProbeResponse { - TimeExceeded(ProbeResponseData), - DestinationUnreachable(ProbeResponseData), + TimeExceeded(ProbeResponseData, Option), + DestinationUnreachable(ProbeResponseData, Option), EchoReply(ProbeResponseData), TcpReply(ProbeResponseData), TcpRefused(ProbeResponseData), } +/// The ICMP extensions for a probe response. +#[derive(Debug, Clone)] +pub struct ProbeResponseExtensions { + pub extensions: ArrayVec, +} + +/// A probe response extension. +#[derive(Debug, Clone)] +pub enum ProbeResponseExtension { + Mpls(MplsExtensionData), + Unknown, +} + +/// The members of a MPLS probe response extension. +#[derive(Debug, Clone)] +pub struct MplsExtensionData { + pub members: ArrayVec, +} + +/// A member of a MPLS probe response extension. +#[derive(Debug, Clone)] +pub struct MplsExtensionMember { + pub label: u32, + pub exp: u8, + pub bos: u8, + pub ttl: u8, +} + /// The data in the probe response. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ProbeResponseData { /// Timestamp of the probe response. pub recv: SystemTime, @@ -158,14 +188,14 @@ impl ProbeResponseData { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub enum ProbeResponseSeq { Icmp(ProbeResponseSeqIcmp), Udp(ProbeResponseSeqUdp), Tcp(ProbeResponseSeqTcp), } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ProbeResponseSeqIcmp { pub identifier: u16, pub sequence: u16, @@ -180,7 +210,7 @@ impl ProbeResponseSeqIcmp { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ProbeResponseSeqUdp { pub identifier: u16, pub src_port: u16, @@ -199,7 +229,7 @@ impl ProbeResponseSeqUdp { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ProbeResponseSeqTcp { pub src_port: u16, pub dest_port: u16, diff --git a/src/tracing/tracer.rs b/src/tracing/tracer.rs index f3c8c5aca..ef60159b8 100644 --- a/src/tracing/tracer.rs +++ b/src/tracing/tracer.rs @@ -2,8 +2,8 @@ use self::state::TracerState; use crate::tracing::error::{TraceResult, TracerError}; use crate::tracing::net::Network; use crate::tracing::probe::{ - ProbeResponse, ProbeResponseData, ProbeResponseSeq, ProbeResponseSeqIcmp, ProbeResponseSeqTcp, - ProbeResponseSeqUdp, + ProbeResponse, ProbeResponseData, ProbeResponseExtension, ProbeResponseSeq, + ProbeResponseSeqIcmp, ProbeResponseSeqTcp, ProbeResponseSeqUdp, }; use crate::tracing::types::{Sequence, TimeToLive, TraceId}; use crate::tracing::{MultipathStrategy, PortDirection, TracerProtocol}; @@ -144,14 +144,34 @@ impl)> Tracer { fn recv_response(&self, network: &mut N, st: &mut TracerState) -> TraceResult<()> { let next = network.recv_probe()?; match next { - Some(ProbeResponse::TimeExceeded(data)) => { + Some(ProbeResponse::TimeExceeded(data, extensions)) => { let (trace_id, sequence, received, host) = self.extract(&data); + + // TODO + if let Some(ext) = extensions { + for ext in ext.extensions { + match ext { + ProbeResponseExtension::Mpls(mpls) => { + for member in mpls.members { + println!( + "mpls extension object: label={}, exp={}, bos={}, ttl={}", + member.label, member.exp, member.bos, member.ttl + ); + } + } + ProbeResponseExtension::Unknown => { + println!("unknown extension object") + } + } + } + } + let is_target = host == self.config.target_addr; if self.check_trace_id(trace_id) && st.in_round(sequence) { st.complete_probe_time_exceeded(sequence, host, received, is_target); } } - Some(ProbeResponse::DestinationUnreachable(data)) => { + Some(ProbeResponse::DestinationUnreachable(data, _extensions)) => { let (trace_id, sequence, received, host) = self.extract(&data); if self.check_trace_id(trace_id) && st.in_round(sequence) { st.complete_probe_unreachable(sequence, host, received);