From 40c4af79bf5b32b8fbdbf6f2e5c16290e1d3d406 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Mon, 2 Nov 2020 10:15:04 -0800 Subject: [PATCH] net: add `TcpSocket::set_{send, recv}_buffer_size` (#1384) This commit adds methods for setting `SO_RECVBUF` and `SO_SNDBUF` to Mio's `TcpSocket` type. It would be nice to eventually expose these in Tokio, so adding them to Mio is the first step. See tokio-rs/tokio#3082 for details. Signed-off-by: Eliza Weisman --- src/net/tcp/socket.rs | 58 ++++++++++++++++++++++++++++++++++++++ src/sys/shell/tcp.rs | 16 +++++++++++ src/sys/unix/tcp.rs | 57 ++++++++++++++++++++++++++++++++++++-- src/sys/windows/tcp.rs | 63 +++++++++++++++++++++++++++++++++++++++++- tests/tcp_socket.rs | 45 ++++++++++++++++++++++++++++++ 5 files changed, 236 insertions(+), 3 deletions(-) diff --git a/src/net/tcp/socket.rs b/src/net/tcp/socket.rs index f3e27c370..b43690b24 100644 --- a/src/net/tcp/socket.rs +++ b/src/net/tcp/socket.rs @@ -106,6 +106,64 @@ impl TcpSocket { sys::tcp::set_linger(self.sys, dur) } + /// Sets the value of `SO_RCVBUF` on this socket. + pub fn set_recv_buffer_size(&self, size: u32) -> io::Result<()> { + sys::tcp::set_recv_buffer_size(self.sys, size) + } + + /// Get the value of `SO_RCVBUF` set on this socket. + /// + /// Note that if [`set_recv_buffer_size`] has been called on this socket + /// previously, the value returned by this function may not be the same as + /// the argument provided to `set_recv_buffer_size`. This is for the + /// following reasons: + /// + /// * Most operating systems have minimum and maximum allowed sizes for the + /// receive buffer, and will clamp the provided value if it is below the + /// minimum or above the maximum. The minimum and maximum buffer sizes are + /// OS-dependent. + /// * Linux will double the buffer size to account for internal bookkeeping + /// data, and returns the doubled value from `getsockopt(2)`. As per `man + /// 7 socket`: + /// > Sets or gets the maximum socket receive buffer in bytes. The + /// > kernel doubles this value (to allow space for bookkeeping + /// > overhead) when it is set using `setsockopt(2)`, and this doubled + /// > value is returned by `getsockopt(2)`. + /// + /// [`set_recv_buffer_size`]: #method.set_recv_buffer_size + pub fn get_recv_buffer_size(&self) -> io::Result { + sys::tcp::get_recv_buffer_size(self.sys) + } + + /// Sets the value of `SO_SNDBUF` on this socket. + pub fn set_send_buffer_size(&self, size: u32) -> io::Result<()> { + sys::tcp::set_send_buffer_size(self.sys, size) + } + + /// Get the value of `SO_SNDBUF` set on this socket. + /// + /// Note that if [`set_send_buffer_size`] has been called on this socket + /// previously, the value returned by this function may not be the same as + /// the argument provided to `set_send_buffer_size`. This is for the + /// following reasons: + /// + /// * Most operating systems have minimum and maximum allowed sizes for the + /// receive buffer, and will clamp the provided value if it is below the + /// minimum or above the maximum. The minimum and maximum buffer sizes are + /// OS-dependent. + /// * Linux will double the buffer size to account for internal bookkeeping + /// data, and returns the doubled value from `getsockopt(2)`. As per `man + /// 7 socket`: + /// > Sets or gets the maximum socket send buffer in bytes. The + /// > kernel doubles this value (to allow space for bookkeeping + /// > overhead) when it is set using `setsockopt(2)`, and this doubled + /// > value is returned by `getsockopt(2)`. + /// + /// [`set_send_buffer_size`]: #method.set_send_buffer_size + pub fn get_send_buffer_size(&self) -> io::Result { + sys::tcp::get_send_buffer_size(self.sys) + } + /// Returns the local address of this socket /// /// Will return `Err` result in windows if called before calling `bind` diff --git a/src/sys/shell/tcp.rs b/src/sys/shell/tcp.rs index 3073d42f7..b67e33db3 100644 --- a/src/sys/shell/tcp.rs +++ b/src/sys/shell/tcp.rs @@ -50,6 +50,22 @@ pub(crate) fn set_linger(_: TcpSocket, _: Option) -> io::Result<()> { os_required!(); } +pub(crate) fn set_recv_buffer_size(_: TcpSocket, _: u32) -> io::Result<()> { + os_required!(); +} + +pub(crate) fn get_recv_buffer_size(_: TcpSocket) -> io::Result { + os_required!(); +} + +pub(crate) fn set_send_buffer_size(_: TcpSocket, _: u32) -> io::Result<()> { + os_required!(); +} + +pub(crate) fn get_send_buffer_size(_: TcpSocket) -> io::Result { + os_required!(); +} + pub fn accept(_: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { os_required!(); } diff --git a/src/sys/unix/tcp.rs b/src/sys/unix/tcp.rs index 65b7400e9..8d5f5aa5c 100644 --- a/src/sys/unix/tcp.rs +++ b/src/sys/unix/tcp.rs @@ -1,4 +1,5 @@ use std::io; +use std::convert::TryInto; use std::mem; use std::mem::{size_of, MaybeUninit}; use std::net::{self, SocketAddr}; @@ -37,8 +38,6 @@ pub(crate) fn connect(socket: TcpSocket, addr: SocketAddr) -> io::Result io::Result { - use std::convert::TryInto; - let backlog = backlog.try_into().unwrap_or(i32::max_value()); syscall!(listen(socket, backlog))?; Ok(unsafe { net::TcpListener::from_raw_fd(socket) }) @@ -130,6 +129,60 @@ pub(crate) fn set_linger(socket: TcpSocket, dur: Option) -> io::Result )).map(|_| ()) } +pub(crate) fn set_recv_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()> { + let size = size.try_into().ok().unwrap_or_else(i32::max_value); + syscall!(setsockopt( + socket, + libc::SOL_SOCKET, + libc::SO_RCVBUF, + &size as *const _ as *const libc::c_void, + size_of::() as libc::socklen_t + )) + .map(|_| ()) +} + +pub(crate) fn get_recv_buffer_size(socket: TcpSocket) -> io::Result { + let mut optval: libc::c_int = 0; + let mut optlen = size_of::() as libc::socklen_t; + + syscall!(getsockopt( + socket, + libc::SOL_SOCKET, + libc::SO_RCVBUF, + &mut optval as *mut _ as *mut _, + &mut optlen, + ))?; + + Ok(optval as u32) +} + +pub(crate) fn set_send_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()> { + let size = size.try_into().ok().unwrap_or_else(i32::max_value); + syscall!(setsockopt( + socket, + libc::SOL_SOCKET, + libc::SO_SNDBUF, + &size as *const _ as *const libc::c_void, + size_of::() as libc::socklen_t + )) + .map(|_| ()) +} + +pub(crate) fn get_send_buffer_size(socket: TcpSocket) -> io::Result { + let mut optval: libc::c_int = 0; + let mut optlen = size_of::() as libc::socklen_t; + + syscall!(getsockopt( + socket, + libc::SOL_SOCKET, + libc::SO_SNDBUF, + &mut optval as *mut _ as *mut _, + &mut optlen, + ))?; + + Ok(optval as u32) +} + pub fn accept(listener: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { let mut addr: MaybeUninit = MaybeUninit::uninit(); let mut length = size_of::() as libc::socklen_t; diff --git a/src/sys/windows/tcp.rs b/src/sys/windows/tcp.rs index b78d86479..e14f1c8bd 100644 --- a/src/sys/windows/tcp.rs +++ b/src/sys/windows/tcp.rs @@ -2,6 +2,7 @@ use std::io; use std::mem::size_of; use std::net::{self, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::time::Duration; +use std::convert::TryInto; use std::os::windows::io::FromRawSocket; use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64. @@ -12,7 +13,7 @@ use winapi::shared::ws2ipdef::SOCKADDR_IN6_LH; use winapi::shared::minwindef::{BOOL, TRUE, FALSE}; use winapi::um::winsock2::{ self, closesocket, linger, setsockopt, getsockopt, getsockname, PF_INET, PF_INET6, SOCKET, SOCKET_ERROR, - SOCK_STREAM, SOL_SOCKET, SO_LINGER, SO_REUSEADDR, + SOCK_STREAM, SOL_SOCKET, SO_LINGER, SO_REUSEADDR, SO_RCVBUF, SO_SNDBUF, }; use crate::sys::windows::net::{init, new_socket, socket_addr}; @@ -149,6 +150,66 @@ pub(crate) fn set_linger(socket: TcpSocket, dur: Option) -> io::Result } } + +pub(crate) fn set_recv_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()> { + let size = size.try_into().ok().unwrap_or_else(i32::max_value); + match unsafe { setsockopt( + socket, + SOL_SOCKET, + SO_RCVBUF, + &size as *const _ as *const c_char, + size_of::() as c_int + ) } { + SOCKET_ERROR => Err(io::Error::last_os_error()), + _ => Ok(()), + } +} + +pub(crate) fn get_recv_buffer_size(socket: TcpSocket) -> io::Result { + let mut optval: c_int = 0; + let mut optlen = size_of::() as c_int; + match unsafe { getsockopt( + socket, + SOL_SOCKET, + SO_RCVBUF, + &mut optval as *mut _ as *mut _, + &mut optlen as *mut _, + ) } { + SOCKET_ERROR => Err(io::Error::last_os_error()), + _ => Ok(optval as u32), + } +} + +pub(crate) fn set_send_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()> { + let size = size.try_into().ok().unwrap_or_else(i32::max_value); + match unsafe { setsockopt( + socket, + SOL_SOCKET, + SO_SNDBUF, + &size as *const _ as *const c_char, + size_of::() as c_int + ) } { + SOCKET_ERROR => Err(io::Error::last_os_error()), + _ => Ok(()), + } +} + +pub(crate) fn get_send_buffer_size(socket: TcpSocket) -> io::Result { + let mut optval: c_int = 0; + let mut optlen = size_of::() as c_int; + match unsafe { getsockopt( + socket, + SOL_SOCKET, + SO_SNDBUF, + &mut optval as *mut _ as *mut _, + &mut optlen as *mut _, + ) } { + SOCKET_ERROR => Err(io::Error::last_os_error()), + _ => Ok(optval as u32), + } +} + + pub(crate) fn accept(listener: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { // The non-blocking state of `listener` is inherited. See // https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#remarks. diff --git a/tests/tcp_socket.rs b/tests/tcp_socket.rs index 0ad2c7ba5..bb57ade71 100644 --- a/tests/tcp_socket.rs +++ b/tests/tcp_socket.rs @@ -1,6 +1,7 @@ #![cfg(all(feature = "os-poll", feature = "tcp"))] use mio::net::TcpSocket; +use std::io; #[test] fn is_send_and_sync() { @@ -56,3 +57,47 @@ fn get_localaddr() { let _ = socket.listen(128).unwrap(); } + +#[test] +fn send_buffer_size_roundtrips() { + test_buffer_sizes( + TcpSocket::set_send_buffer_size, + TcpSocket::get_send_buffer_size, + ) +} + +#[test] +fn recv_buffer_size_roundtrips() { + test_buffer_sizes( + TcpSocket::set_recv_buffer_size, + TcpSocket::get_recv_buffer_size, + ) +} + +// Helper for testing send/recv buffer size. +fn test_buffer_sizes( + set: impl Fn(&TcpSocket, u32) -> io::Result<()>, + get: impl Fn(&TcpSocket) -> io::Result, +) { + let test = |size: u32| { + println!("testing buffer size: {}", size); + let socket = TcpSocket::new_v4().unwrap(); + set(&socket, size).unwrap(); + // Note that this doesn't assert that the values are equal: on Linux, + // the kernel doubles the requested buffer size, and returns the doubled + // value from `getsockopt`. As per `man socket(7)`: + // > Sets or gets the maximum socket send buffer in bytes. The + // > kernel doubles this value (to allow space for bookkeeping + // > overhead) when it is set using setsockopt(2), and this doubled + // > value is returned by getsockopt(2). + // + // Additionally, the buffer size may be clamped above a minimum value, + // and this minimum value is OS-dependent. + let actual = get(&socket).unwrap(); + assert!(actual >= size, "\tactual: {}\n\texpected: {}", actual, size); + }; + + test(256); + test(4096); + test(65512); +}