diff --git a/Cargo.lock b/Cargo.lock index c1677d9..6499e35 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -32,12 +32,24 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bitflags" version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.6.0" @@ -56,6 +68,70 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "defmt" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a99dd22262668b887121d4672af5a64b238f026099f1a2a1b322066c9ecfe9e0" +dependencies = [ + "bitflags 1.3.2", + "defmt-macros", +] + +[[package]] +name = "defmt-macros" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9f309eff1f79b3ebdf252954d90ae440599c26c2c553fe87a2d17195f2dcb" +dependencies = [ + "defmt-parser", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 2.0.70", +] + +[[package]] +name = "defmt-parser" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff4a5fefe330e8d7f31b16a318f9ce81000d8e35e69b93eae154d16d2278f70f" +dependencies = [ + "thiserror", +] + +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "fastrand" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" + [[package]] name = "foreign-types" version = "0.3.2" @@ -88,18 +164,49 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" +[[package]] +name = "hash32" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47d60b12902ba28e2730cd37e95b8c9223af2808df9e902d4df49588d1470606" +dependencies = [ + "byteorder", +] + +[[package]] +name = "heapless" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad" +dependencies = [ + "hash32", + "stable_deref_trait", +] + [[package]] name = "libc" version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + [[package]] name = "log" version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "managed" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ca88d725a0a943b096803bd34e73a4437208b6077654cc4ecb2947a5f91618d" + [[package]] name = "memchr" version = "2.7.4" @@ -126,6 +233,23 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "object" version = "0.36.1" @@ -147,7 +271,7 @@ version = "0.10.64" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" dependencies = [ - "bitflags", + "bitflags 2.6.0", "cfg-if", "foreign-types", "libc", @@ -164,9 +288,15 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.70", ] +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + [[package]] name = "openssl-sys" version = "0.9.102" @@ -197,6 +327,30 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.86" @@ -211,9 +365,10 @@ name = "pterodapter" version = "0.0.1" dependencies = [ "log", - "openssl", "rand", + "smoltcp", "tokio", + "tokio-native-tls", ] [[package]] @@ -261,6 +416,51 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustix" +version = "0.38.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +dependencies = [ + "bitflags 2.6.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + +[[package]] +name = "schannel" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "security-framework" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -270,6 +470,20 @@ dependencies = [ "libc", ] +[[package]] +name = "smoltcp" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a1a996951e50b5971a2c8c0fa05a381480d70a933064245c4a223ddc87ccc97" +dependencies = [ + "bitflags 1.3.2", + "byteorder", + "cfg-if", + "defmt", + "heapless", + "managed", +] + [[package]] name = "socket2" version = "0.5.7" @@ -280,6 +494,22 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.70" @@ -291,6 +521,38 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tempfile" +version = "3.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +dependencies = [ + "cfg-if", + "fastrand", + "rustix", + "windows-sys 0.52.0", +] + +[[package]] +name = "thiserror" +version = "1.0.62" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2675633b1499176c2dff06b0856a27976a8f9d436737b4cf4f312d4d91d8bbb" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.62" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d20468752b09f49e909e55a5d338caa8bedf615594e9d80bc4c565d30faf798c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.70", +] + [[package]] name = "tokio" version = "1.38.0" @@ -307,6 +569,16 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "unicode-ident" version = "1.0.12" @@ -319,6 +591,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index 405e97a..14ad37f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,8 +10,9 @@ rust-version = "1.77" [dependencies] log = { version = "0.4", default-features = false } tokio = { version = "1.38", default-features = false, features = ["rt", "io-util", "signal", "net", "time", "sync"] } +tokio-native-tls = { version = "0.3", default-features = false } +smoltcp = { version = "0.11", default-features = false, features = ["std", "medium-ip", "proto-ipv4", "socket-tcp"] } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } -openssl = { version = "0.10", default-features = false } [profile.release] strip = true diff --git a/src/main.rs b/src/main.rs index e4ddcf9..624bc5a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,15 +1,15 @@ use std::{ - env, fmt, fs, - net::{IpAddr, Ipv6Addr}, + env, fmt, + net::{IpAddr, Ipv6Addr, SocketAddr}, process, str::FromStr, }; -mod ikev2; mod logger; +mod socks; enum Action { - Serve(ikev2::Config), + Proxy(socks::Config), } pub struct Args { @@ -17,15 +17,12 @@ pub struct Args { action: Action, } -const USAGE_INSTRUCTIONS: &str = "Usage: pterodapter [OPTIONS] serve\n\n\ +const USAGE_INSTRUCTIONS: &str = "Usage: pterodapter [OPTIONS] proxy\n\n\ Options:\ -\n --log-level= Log level [default: info]\ -\n --listen-ip= Listen IP address, multiple options can be provided [default: ::]\ -\n --id-hostname= Hostname for identification [default: pterodapter]\ -\n --cacert= Path to root CA certificate (in PKCS 8 PEM format)\ -\n --cert= Path to public certificate (in PKCS 8 PEM format)\ -\n --key= Path to private key (in PKCS 8 PEM format)\ -\n --help Print help"; +\n --log-level= Log level [default: info]\ +\n --listen-address= Listen IP address [default: :::5328]\ +\n --destination= Destination FortiVPN address, e.g. sslvpn.example.com:443\ +\n --help Print help"; impl Args { fn parse() -> Args { @@ -39,11 +36,8 @@ impl Args { }; let mut log_level = log::LevelFilter::Info; - let mut listen_ips = vec![]; - let mut id_hostname = None; - let mut root_ca = None; - let mut private_key = None; - let mut public_cert = None; + let mut listen_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 5328); + let mut destination = None; for arg in env::args() .take(env::args().len().saturating_sub(1)) @@ -74,44 +68,22 @@ impl Args { process::exit(2); } }; - } else if name == "--listen-ip" { - match IpAddr::from_str(value) { - Ok(ip) => { - listen_ips.push(ip); - } + } else if name == "--listen-address" { + match SocketAddr::from_str(value) { + Ok(addr) => listen_addr = addr, Err(err) => fail_with_error( name, value, - format_args!("Failed to parse IP address: {}", err), + format_args!("Failed to parse listen address: {}", err), ), }; - } else if name == "--id-hostname" { - id_hostname = Some(value.into()); - } else if name == "--cacert" { - match fs::read_to_string(value) { - Ok(cert) => root_ca = Some(cert), + } else if name == "--destination" { + match SocketAddr::from_str(value) { + Ok(addr) => destination = Some(addr), Err(err) => fail_with_error( name, value, - format_args!("Failed to read root CA cert: {}", err), - ), - }; - } else if name == "--cert" { - match fs::read_to_string(value) { - Ok(cert) => public_cert = Some(cert), - Err(err) => fail_with_error( - name, - value, - format_args!("Failed to read root CA cert: {}", err), - ), - }; - } else if name == "--key" { - match fs::read_to_string(value) { - Ok(cert) => private_key = Some(cert), - Err(err) => fail_with_error( - name, - value, - format_args!("Failed to read root CA cert: {}", err), + format_args!("Failed to parse destination address: {}", err), ), }; } else { @@ -128,21 +100,14 @@ impl Args { }; match action.as_str() { - "serve" => { - let server_cert = match (public_cert.clone(), private_key.clone()) { - (Some(public_cert), Some(private_key)) => Some((public_cert, private_key)), - _ => None, - }; - if listen_ips.is_empty() { - listen_ips = vec![IpAddr::V6(Ipv6Addr::UNSPECIFIED)]; + "proxy" => { + if destination.is_none() { + eprintln!("No destination specified"); + println!("{}", USAGE_INSTRUCTIONS); + process::exit(2); } - let action = Action::Serve(ikev2::Config { - hostname: id_hostname.clone(), - listen_ips: listen_ips.clone(), - root_ca: root_ca.clone(), - server_cert, - }); + let action = Action::Proxy(socks::Config { listen_addr }); Args { log_level, action } } _ => { @@ -165,18 +130,14 @@ fn main() { eprintln!("Failed to set up logger, error is {}", err); } match args.action { - Action::Serve(config) => { - let server = match ikev2::Server::new(config) { + Action::Proxy(config) => { + match socks::run(config) { Ok(server) => server, Err(err) => { - println!("Failed to create server, error is {}", err); + println!("Failed to run server, error is {}", err); std::process::exit(1) } }; - if let Err(err) = server.run() { - println!("Failed to run server, error is {}", err); - std::process::exit(1); - } } } } diff --git a/src/socks.rs b/src/socks.rs new file mode 100644 index 0000000..7778646 --- /dev/null +++ b/src/socks.rs @@ -0,0 +1,382 @@ +use std::{ + error, fmt, io, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + time::Duration, +}; + +use log::{debug, info, warn}; +use tokio::{ + io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}, + net::{tcp, TcpListener, TcpStream}, + runtime, signal, + sync::mpsc, + task::JoinHandle, +}; + +pub struct Config { + pub listen_addr: SocketAddr, +} +pub fn run(config: Config) -> Result<(), SocksError> { + let server = Server::new(config)?; + let rt = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build()?; + let (accept_sender, accept_receiver) = mpsc::channel(10); + let handles = vec![rt.spawn(async move { server.listen_socket(accept_sender).await })]; + + rt.block_on(Server::wait_termination(handles))?; + rt.shutdown_timeout(Duration::from_secs(60)); + + info!("Stopped server"); + Ok(()) +} + +pub struct Server { + listen_addr: SocketAddr, +} + +impl Server { + pub fn new(config: Config) -> Result { + Ok(Server { + listen_addr: config.listen_addr, + }) + } + + async fn listen_socket( + &self, + offload_connections: mpsc::Sender, + ) -> Result<(), SocksError> { + let listener = match TcpListener::bind(&self.listen_addr).await { + Ok(listener) => listener, + Err(err) => { + warn!("Failed to bind listener on {}: {}", self.listen_addr, err); + return Err("Failed to bind listener".into()); + } + }; + info!("Started server on {}", self.listen_addr); + loop { + match listener.accept().await { + Ok((socket, addr)) => { + debug!("Received connection from {}", addr); + let mut handler = SocksConnection::new(socket); + let rt = runtime::Handle::current(); + rt.spawn(async move { + if let Err(err) = handler.handle_connection().await { + warn!("SOCKS connection failed: {}", err); + } + }); + } + Err(err) => warn!("Failed to accept incoming connection: {}", err), + } + } + } + + async fn wait_termination( + handles: Vec>>, + ) -> Result<(), SocksError> { + signal::ctrl_c().await?; + handles.iter().for_each(|handle| handle.abort()); + Ok(()) + } +} + +struct SocksConnection { + reader: BufReader, + writer: BufWriter, +} + +impl SocksConnection { + fn new<'a>(socket: TcpStream) -> SocksConnection { + let (reader, writer) = socket.into_split(); + let reader = BufReader::new(reader); + let writer = BufWriter::new(writer); + SocksConnection { reader, writer } + } + + async fn handle_connection(&mut self) -> Result<(), SocksError> { + self.perform_handshake().await?; + self.read_request().await?; + + let mut request = String::new(); + while self.reader.read_line(&mut request).await? > 0 { + println!("Received request {}", request); + request.clear(); + } + + Ok(()) + } + + async fn perform_handshake(&mut self) -> Result<(), SocksError> { + let version = self.reader.read_u8().await?; + if version != SOCKS5_VERSION { + return Err("Unsupported SOCKS version".into()); + } + let nmethods = self.reader.read_u8().await?; + let mut selected_method = AuthenticationMethod::NO_ACCEPTABLE_METHODS; + for _ in 0..nmethods { + let method = AuthenticationMethod::from_u8(self.reader.read_u8().await?); + if method == AuthenticationMethod::NO_AUTHENTICATION_REQUIRED { + selected_method = method; + } + } + self.writer.write_u8(SOCKS5_VERSION).await?; + self.writer.write_u8(selected_method.0).await?; + self.writer.flush().await?; + Ok(()) + } + + async fn read_request(&mut self) -> Result<(), SocksError> { + println!("Prepating to read request"); + let version = self.reader.read_u8().await?; + println!("Read to read request"); + if version != SOCKS5_VERSION { + return Err("Unsupported SOCKS version".into()); + } + let cmd = SocksCommand::from_u8(self.reader.read_u8().await?); + let _ = self.reader.read_u8().await?; // Reserved byte. + let addr_type = SocksAddressType::from_u8(self.reader.read_u8().await?); + let addr = match addr_type { + SocksAddressType::IPV4 => { + let mut octets = [0u8; 4]; + self.reader.read_exact(&mut octets).await?; + Some(DestinationAddress::IpAddr(IpAddr::V4(Ipv4Addr::from( + octets, + )))) + } + SocksAddressType::DOMAINNAME => { + let len = self.reader.read_u8().await?; + let mut dest = vec![0; len as usize]; + self.reader.read_exact(dest.as_mut_slice()).await?; + let domain = match String::from_utf8(dest) { + Ok(domain) => Ok(domain), + Err(err) => { + debug!("Failed to decode domain name: {}", err); + Err("Failed to decode domain") + } + }?; + Some(DestinationAddress::Domain(domain)) + } + SocksAddressType::IPV6 => { + let mut octets = [0u8; 16]; + self.reader.read_exact(&mut octets).await?; + Some(DestinationAddress::IpAddr(IpAddr::V6(Ipv6Addr::from( + octets, + )))) + } + _ => None, + }; + let port = self.reader.read_u16().await?; + + self.writer.write_u8(SOCKS5_VERSION).await?; + if cmd != SocksCommand::CONNECT { + self.write_error_response(CommandResponse::COMMAND_NOT_SUPPORTED) + .await?; + debug!("Command {} is not supported", cmd); + return Err("Command is not supported".into()); + } + let addr = if let Some(addr) = addr { + addr + } else { + self.write_error_response(CommandResponse::ADDRESS_TYPE_NOT_SUPPORTED) + .await?; + debug!("Address type {} is not supported", addr_type); + return Err("Address type is not supported".into()); + }; + + self.connect_to_host(addr, port).await?; + Ok(()) + } + + async fn connect_to_host( + &mut self, + addr: DestinationAddress, + port: u16, + ) -> Result<(), SocksError> { + // TODO: open SmolTCP/FortiVPN connection here. + let bnd_addr = Ipv4Addr::LOCALHOST; + let bnd_port = port; + + self.writer.write_u8(CommandResponse::SUCCEDED.0).await?; + self.writer.write_u8(0).await?; // Reserved byte. + self.writer.write_u8(SocksAddressType::IPV4.0).await?; + self.writer.write_all(&bnd_addr.octets()).await?; + self.writer.write_u16(bnd_port).await?; + self.writer.flush().await?; + Ok(()) + } + + async fn write_error_response(&mut self, response: CommandResponse) -> Result<(), SocksError> { + self.writer.write_u8(response.0).await?; + self.writer.write_u8(0).await?; // Reserved byte. + self.writer.write_u8(SocksAddressType::DOMAINNAME.0).await?; + self.writer.write_u8(0).await?; // Empty domain name. + self.writer.write_u16(0).await?; + self.writer.flush().await?; + Ok(()) + } +} + +const SOCKS5_VERSION: u8 = 0x05; + +#[derive(Clone, Copy, PartialEq, Eq)] +struct AuthenticationMethod(u8); +impl AuthenticationMethod { + const NO_AUTHENTICATION_REQUIRED: AuthenticationMethod = AuthenticationMethod(0x00); + const GSSAPI: AuthenticationMethod = AuthenticationMethod(0x01); + const USERNAME_PASSWORD: AuthenticationMethod = AuthenticationMethod(0x02); + const NO_ACCEPTABLE_METHODS: AuthenticationMethod = AuthenticationMethod(0xff); + + fn from_u8(method: u8) -> AuthenticationMethod { + AuthenticationMethod(method) + } +} +impl fmt::Display for AuthenticationMethod { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + &Self::NO_AUTHENTICATION_REQUIRED => write!(f, "NO AUTHENTICATION REQUIRED"), + &Self::GSSAPI => write!(f, "GSSAPI"), + &Self::USERNAME_PASSWORD => write!(f, "USERNAME/PASSWORD"), + &Self::NO_ACCEPTABLE_METHODS => write!(f, "NO ACCEPTABLE METHODS"), + _ => write!(f, "Unknown authentication method {}", self.0), + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +struct SocksCommand(u8); +impl SocksCommand { + const CONNECT: SocksCommand = SocksCommand(0x01); + const BIND: SocksCommand = SocksCommand(0x02); + const UDP_ASSOCIATE: SocksCommand = SocksCommand(0x03); + + fn from_u8(method: u8) -> SocksCommand { + SocksCommand(method) + } +} +impl fmt::Display for SocksCommand { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + &Self::CONNECT => write!(f, "CONNECT"), + &Self::BIND => write!(f, "BIND"), + &Self::UDP_ASSOCIATE => write!(f, "UDP ASSOCIATE"), + _ => write!(f, "Unknown command {}", self.0), + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +struct CommandResponse(u8); +impl CommandResponse { + const SUCCEDED: CommandResponse = CommandResponse(0x00); + const GENERAL_FAILURE: CommandResponse = CommandResponse(0x01); + const CONNECTION_NOT_ALLOWED: CommandResponse = CommandResponse(0x02); + const NETWORK_UNREACHABLE: CommandResponse = CommandResponse(0x03); + const HOST_UNREACHABLE: CommandResponse = CommandResponse(0x04); + const CONNECTION_REFUSED: CommandResponse = CommandResponse(0x05); + const TTL_EXPIRED: CommandResponse = CommandResponse(0x06); + const COMMAND_NOT_SUPPORTED: CommandResponse = CommandResponse(0x07); + const ADDRESS_TYPE_NOT_SUPPORTED: CommandResponse = CommandResponse(0x08); + + fn from_u8(method: u8) -> SocksCommand { + SocksCommand(method) + } +} +impl fmt::Display for CommandResponse { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + &Self::SUCCEDED => write!(f, "succeeded"), + &Self::GENERAL_FAILURE => write!(f, "general SOCKS server failure"), + &Self::CONNECTION_NOT_ALLOWED => write!(f, "connection not allowed by ruleset"), + &Self::NETWORK_UNREACHABLE => write!(f, "Network unreachable"), + &Self::HOST_UNREACHABLE => write!(f, "Host unreachable"), + &Self::CONNECTION_REFUSED => write!(f, "Connection refused"), + &Self::TTL_EXPIRED => write!(f, "TTL expired"), + &Self::COMMAND_NOT_SUPPORTED => write!(f, "Command not supported"), + &Self::ADDRESS_TYPE_NOT_SUPPORTED => write!(f, "Address type not supported"), + _ => write!(f, "Unknown command response {}", self.0), + } + } +} + +enum DestinationAddress { + IpAddr(IpAddr), + Domain(String), +} +impl fmt::Display for DestinationAddress { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::IpAddr(ref addr) => addr.fmt(f), + Self::Domain(ref domain) => domain.fmt(f), + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +struct SocksAddressType(u8); +impl SocksAddressType { + const IPV4: SocksAddressType = SocksAddressType(0x01); + const DOMAINNAME: SocksAddressType = SocksAddressType(0x02); + const IPV6: SocksAddressType = SocksAddressType(0x04); + + fn from_u8(method: u8) -> SocksAddressType { + SocksAddressType(method) + } +} +impl fmt::Display for SocksAddressType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + &Self::IPV4 => write!(f, "IP V4 address"), + &Self::DOMAINNAME => write!(f, "DOMAINNAME"), + &Self::IPV6 => write!(f, "IP V6 address"), + _ => write!(f, "Unknown address type {}", self.0), + } + } +} + +#[derive(Debug)] +pub enum SocksError { + Internal(&'static str), + Join(tokio::task::JoinError), + Io(io::Error), +} + +impl fmt::Display for SocksError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Self::Internal(msg) => f.write_str(msg), + Self::Join(ref e) => write!(f, "Tokio join error: {}", e), + Self::Io(ref e) => { + write!(f, "IO error: {}", e) + } + } + } +} + +impl error::Error for SocksError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match *self { + Self::Internal(_msg) => None, + Self::Join(ref err) => Some(err), + Self::Io(ref err) => Some(err), + } + } +} + +impl From<&'static str> for SocksError { + fn from(msg: &'static str) -> SocksError { + Self::Internal(msg) + } +} + +impl From for SocksError { + fn from(err: tokio::task::JoinError) -> SocksError { + Self::Join(err) + } +} + +impl From for SocksError { + fn from(err: io::Error) -> SocksError { + Self::Io(err) + } +}