diff --git a/Cargo.lock b/Cargo.lock index dc9d07c86..7eb862c69 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1709,6 +1709,7 @@ dependencies = [ "test-case", "thiserror", "tokio", + "tokio-util", "toml", "tracing", "tracing-chrome", diff --git a/Cargo.toml b/Cargo.toml index cd2355f21..5cb012366 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -85,6 +85,7 @@ test-case = "3.3.1" tun = { version = "0.6.1", features = [ "async" ] } serde_yaml = "0.9.30" tokio = { version = "1.35.1", features = [ "full" ] } +tokio-util = "0.7.10" [features] # Enable simulation integration tests diff --git a/tests/sim/network.rs b/tests/sim/network.rs index 2d5fab8ce..920d7ba9d 100644 --- a/tests/sim/network.rs +++ b/tests/sim/network.rs @@ -2,11 +2,8 @@ use crate::simulation::{Response, Simulation, SingleHost}; use crate::tun_device::TunDevice; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; -use std::time::Duration; -use tokio::sync::mpsc::error::TryRecvError; -use tokio::sync::mpsc::Receiver; use tokio::sync::Mutex; -use tokio::time::timeout; +use tokio_util::sync::CancellationToken; use trippy::tracing::packet::checksum::{icmp_ipv4_checksum, ipv4_header_checksum}; use trippy::tracing::packet::icmpv4::echo_request::EchoRequestPacket; use trippy::tracing::packet::icmpv4::time_exceeded::TimeExceededPacket; @@ -17,40 +14,48 @@ use trippy::tracing::packet::IpProtocol; pub async fn run( tun: &Arc>, sim: Arc, - mut recv: Receiver<()>, + token: CancellationToken, ) -> anyhow::Result<()> { + println!("before lock"); let mut tun = tun.lock().await; + println!("locked"); + loop { - match recv.try_recv() { - Ok(_) => { - println!("shutdown"); - return Ok(()); - } - Err(TryRecvError::Disconnected) => { - println!("Disconnected shutdown"); - return Ok(()); - } - Err(TryRecvError::Empty) => {} - } + // match recv.try_recv() { + // Ok(_) => { + // println!("shutdown"); + // return Ok(()); + // } + // Err(TryRecvError::Disconnected) => { + // println!("Disconnected shutdown"); + // return Ok(()); + // } + // Err(TryRecvError::Empty) => {} + // } + + // let bytes_read = match timeout(Duration::from_millis(1000), tun.read(&mut buf)).await { + // Ok(bytes) => bytes?, + // Err(_err) => { + // println!("timeout"); + // continue; + // } + // }; let mut buf = [0_u8; 4096]; - let bytes_read = match timeout(Duration::from_millis(1000), tun.read(&mut buf)).await { - Ok(bytes) => bytes?, - Err(_err) => { - println!("timeout"); - continue; - } - }; + println!("before select"); + + let bytes_read = tokio::select!( + _ = token.cancelled() => { + println!("shutdown"); + return Ok(()) + }, + bytes_read = tun.read(&mut buf) => { + println!("read bytes"); + bytes_read? + }, + ); + println!("after select"); - // let bytes_read = tokio::select!( - // _ = recv.recv() => { - // println!("shutdown"); - // return Ok(()) - // }, - // bytes_read = timeout(Duration::from_millis(1000), tun.read(&mut buf)) => { - // bytes_read - // }?, - // ); // let bytes_read = tun.read(&mut buf).await?; let ipv4 = Ipv4Packet::new_view(&buf[..bytes_read])?; if ipv4.get_version() != 4 { @@ -85,7 +90,10 @@ pub async fn run( ipv4.get_source(), te_packet.packet(), )?; + println!("before write"); + tun.write(ipv4_packet.packet()).await?; + println!("after write"); } } diff --git a/tests/sim/tests.rs b/tests/sim/tests.rs index 9e511f763..b75fbfd13 100644 --- a/tests/sim/tests.rs +++ b/tests/sim/tests.rs @@ -2,6 +2,7 @@ use crate::simulation::Simulation; use crate::tun_device::tun; use crate::{network, tracer}; use std::sync::Arc; +use tokio_util::sync::CancellationToken; macro_rules! sim { ($path:expr) => {{ @@ -12,7 +13,7 @@ macro_rules! sim { }}; } -#[tokio::test] +#[tokio::test(flavor = "multi_thread")] async fn test_simulations() -> anyhow::Result<()> { sim!("ipv4_icmp_simple.yaml"); sim!("ipv4_icmp_simple2.yaml"); @@ -22,10 +23,12 @@ async fn test_simulations() -> anyhow::Result<()> { async fn run_simulation(simulation: Simulation) -> anyhow::Result<()> { let tun = tun(); let sim = Arc::new(simulation); - let (send, recv) = tokio::sync::mpsc::channel(1); + // let (send, recv) = tokio::sync::mpsc::channel(1); + + let token = CancellationToken::new(); // spawn the network simulator task but do not join it yet. - let handle = tokio::spawn(network::run(tun, sim.clone(), recv)); + let handle = tokio::spawn(network::run(tun, sim.clone(), token.clone())); // spawn the tracer as a blocking task and wait for it to finish or fail. tokio::task::spawn_blocking(move || tracer::Tracer::new(sim).trace()).await??; @@ -33,9 +36,9 @@ async fn run_simulation(simulation: Simulation) -> anyhow::Result<()> { println!("tracing complete"); // notify the network simulator to stop. - send.send(()).await?; + token.cancel(); - println!("send complete"); + println!("cancel signalled"); // join the network simulator task once it has shutdown. handle.await??;