Skip to content

Commit

Permalink
Make cmsg_space! usable in const contexts. (#889)
Browse files Browse the repository at this point in the history
Make `cmsg_space!` usable in const contexts, so that it can be used as a
buffer size argument, and add a version of tests/net/unix.rs that uses
stack-allocated buffers instead of `Vec`s.

This exposes an alignment sublety, that buffers must be aligned to the
needed alignment of `cmsghdr`; handle this by auto-aligning the provided
buffer to the needed boundary.
  • Loading branch information
sunfishcode authored Oct 22, 2023
1 parent dd5dc44 commit ff9c7fb
Show file tree
Hide file tree
Showing 4 changed files with 691 additions and 31 deletions.
61 changes: 42 additions & 19 deletions src/net/send_recv/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::net::UCred;

use core::iter::FusedIterator;
use core::marker::PhantomData;
use core::mem::{size_of, size_of_val, take};
use core::mem::{align_of, size_of, size_of_val, take};
use core::{ptr, slice};

use super::{RecvFlags, SendFlags, SocketAddrAny, SocketAddrV4, SocketAddrV6};
Expand Down Expand Up @@ -40,8 +40,19 @@ macro_rules! cmsg_space {
}

#[doc(hidden)]
pub fn __cmsg_space(len: usize) -> usize {
unsafe { c::CMSG_SPACE(len.try_into().expect("CMSG_SPACE size overflow")) as usize }
pub const fn __cmsg_space(len: usize) -> usize {
// Add `align_of::<c::cmsghdr>()` so that we can align the user-provided
// `&[u8]` to the required alignment boundary.
let len = len + align_of::<c::cmsghdr>();

// Convert `len` to `u32` for `CMSG_SPACE`. This would be `try_into()` if
// we could call that in a `const fn`.
let converted_len = len as u32;
if converted_len as usize != len {
unreachable!(); // `CMSG_SPACE` size overflow
}

unsafe { c::CMSG_SPACE(converted_len) as usize }
}

/// Ancillary message for [`sendmsg`], [`sendmsg_v4`], [`sendmsg_v6`],
Expand All @@ -59,19 +70,11 @@ impl SendAncillaryMessage<'_, '_> {
/// Get the maximum size of an ancillary message.
///
/// This can be helpful in determining the size of the buffer you allocate.
pub fn size(&self) -> usize {
let total_bytes = match self {
Self::ScmRights(slice) => size_of_val(*slice),
pub const fn size(&self) -> usize {
match self {
Self::ScmRights(slice) => cmsg_space!(ScmRights(slice.len())),
#[cfg(linux_kernel)]
Self::ScmCredentials(ucred) => size_of_val(ucred),
};

unsafe {
c::CMSG_SPACE(
total_bytes
.try_into()
.expect("size too large for CMSG_SPACE"),
) as usize
Self::ScmCredentials(_) => cmsg_space!(ScmCredentials(1)),
}
}
}
Expand Down Expand Up @@ -107,15 +110,20 @@ impl<'buf> From<&'buf mut [u8]> for SendAncillaryBuffer<'buf, '_, '_> {

impl Default for SendAncillaryBuffer<'_, '_, '_> {
fn default() -> Self {
Self::new(&mut [])
Self {
buffer: &mut [],
length: 0,
_phantom: PhantomData,
}
}
}

impl<'buf, 'slice, 'fd> SendAncillaryBuffer<'buf, 'slice, 'fd> {
/// Create a new, empty `SendAncillaryBuffer` from a raw byte buffer.
#[inline]
pub fn new(buffer: &'buf mut [u8]) -> Self {
Self {
buffer,
buffer: align_for_cmsghdr(buffer),
length: 0,
_phantom: PhantomData,
}
Expand Down Expand Up @@ -234,15 +242,20 @@ impl<'buf> From<&'buf mut [u8]> for RecvAncillaryBuffer<'buf> {

impl Default for RecvAncillaryBuffer<'_> {
fn default() -> Self {
Self::new(&mut [])
Self {
buffer: &mut [],
read: 0,
length: 0,
}
}
}

impl<'buf> RecvAncillaryBuffer<'buf> {
/// Create a new, empty `RecvAncillaryBuffer` from a raw byte buffer.
#[inline]
pub fn new(buffer: &'buf mut [u8]) -> Self {
Self {
buffer,
buffer: align_for_cmsghdr(buffer),
read: 0,
length: 0,
}
Expand Down Expand Up @@ -297,6 +310,16 @@ impl Drop for RecvAncillaryBuffer<'_> {
}
}

/// Return a slice of `buffer` starting at the first `cmsghdr` alignment
/// boundary.
#[inline]
fn align_for_cmsghdr(buffer: &mut [u8]) -> &mut [u8] {
let align = align_of::<c::cmsghdr>();
let addr = buffer.as_ptr() as usize;
let adjusted = (addr + (align - 1)) & align.wrapping_neg();
&mut buffer[adjusted - addr..]
}

/// An iterator that drains messages from a [`RecvAncillaryBuffer`].
pub struct AncillaryDrain<'buf> {
/// Inner iterator over messages.
Expand Down
2 changes: 2 additions & 0 deletions tests/net/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ mod poll;
mod sockopt;
#[cfg(unix)]
mod unix;
#[cfg(unix)]
mod unix_alloc;
mod v4;
mod v6;

Expand Down
24 changes: 12 additions & 12 deletions tests/net/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn server(ready: Arc<(Mutex<bool>, Condvar)>, path: &Path) {
cvar.notify_all();
}

let mut buffer = vec![0; BUFFER_SIZE];
let mut buffer = [0; BUFFER_SIZE];
'exit: loop {
let data_socket = accept(&connection_socket).unwrap();
let mut sum = 0;
Expand Down Expand Up @@ -68,7 +68,7 @@ fn client(ready: Arc<(Mutex<bool>, Condvar)>, path: &Path, runs: &[(&[&str], i32
}

let addr = SocketAddrUnix::new(path).unwrap();
let mut buffer = vec![0; BUFFER_SIZE];
let mut buffer = [0; BUFFER_SIZE];

for (args, sum) in runs {
let data_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap();
Expand Down Expand Up @@ -136,7 +136,7 @@ fn do_test_unix_msg(addr: SocketAddrUnix) {
listen(&connection_socket, 1).unwrap();

move || {
let mut buffer = vec![0; BUFFER_SIZE];
let mut buffer = [0; BUFFER_SIZE];
'exit: loop {
let data_socket = accept(&connection_socket).unwrap();
let mut sum = 0;
Expand Down Expand Up @@ -173,7 +173,7 @@ fn do_test_unix_msg(addr: SocketAddrUnix) {
};

let client = move || {
let mut buffer = vec![0; BUFFER_SIZE];
let mut buffer = [0; BUFFER_SIZE];
let runs: &[(&[&str], i32)] = &[
(&["1", "2"], 3),
(&["4", "77", "103"], 184),
Expand Down Expand Up @@ -266,7 +266,7 @@ fn do_test_unix_msg_unconnected(addr: SocketAddrUnix) {
bind_unix(&data_socket, &addr).unwrap();

move || {
let mut buffer = vec![0; BUFFER_SIZE];
let mut buffer = [0; BUFFER_SIZE];
for expected_sum in runs {
let mut sum = 0;
loop {
Expand Down Expand Up @@ -434,8 +434,8 @@ fn test_unix_msg_with_scm_rights() {
move || {
let mut pipe_end = None;

let mut buffer = vec![0; BUFFER_SIZE];
let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(1))];
let mut buffer = [0; BUFFER_SIZE];
let mut cmsg_space = [0; rustix::cmsg_space!(ScmRights(1))];

'exit: loop {
let data_socket = accept(&connection_socket).unwrap();
Expand Down Expand Up @@ -495,7 +495,7 @@ fn test_unix_msg_with_scm_rights() {
let client = move || {
let addr = SocketAddrUnix::new(path).unwrap();
let (read_end, write_end) = pipe().unwrap();
let mut buffer = vec![0; BUFFER_SIZE];
let mut buffer = [0; BUFFER_SIZE];
let runs: &[(&[&str], i32)] = &[
(&["1", "2"], 3),
(&["4", "77", "103"], 184),
Expand Down Expand Up @@ -543,7 +543,7 @@ fn test_unix_msg_with_scm_rights() {
// Format the CMSG.
let we = [write_end.as_fd()];
let msg = SendAncillaryMessage::ScmRights(&we);
let mut space = vec![0; msg.size()];
let mut space = [0; rustix::cmsg_space!(ScmRights(1))];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
assert!(cmsg_buffer.push(msg));

Expand Down Expand Up @@ -606,7 +606,7 @@ fn test_unix_peercred() {
assert_eq!(ucred.gid, getgid());

let msg = SendAncillaryMessage::ScmCredentials(ucred);
let mut space = vec![0; msg.size()];
let mut space = [0; rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
assert!(cmsg_buffer.push(msg));

Expand All @@ -618,10 +618,10 @@ fn test_unix_peercred() {
)
.unwrap();

let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_space = [0; rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);

let mut buffer = vec![0; BUFFER_SIZE];
let mut buffer = [0; BUFFER_SIZE];
recvmsg(
&recv_sock,
&mut [IoSliceMut::new(&mut buffer)],
Expand Down
Loading

0 comments on commit ff9c7fb

Please sign in to comment.