Skip to content

Commit

Permalink
Don't assume memory layout of std::net::SocketAddr
Browse files Browse the repository at this point in the history
Fixes #1388.
  • Loading branch information
faern committed Nov 7, 2020
1 parent 13e82ce commit 152e075
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 46 deletions.
104 changes: 86 additions & 18 deletions src/sys/unix/net.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::net::SocketAddr;
use std::mem;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

pub(crate) fn new_ip_socket(
addr: SocketAddr,
Expand Down Expand Up @@ -64,32 +65,99 @@ pub(crate) fn new_socket(
socket
}

pub(crate) fn socket_addr(addr: &SocketAddr) -> (*const libc::sockaddr, libc::socklen_t) {
use std::mem::size_of_val;
/// A type with the same memory layout as `libc::sockaddr`. Used in converting Rust level
/// SocketAddr* types into their system representation. The benefit of this specific
/// type over using `libc::sockaddr_storage` is that this type is exactly as large as it
/// needs to be and not a lot larger. And it can be initialized cleaner from Rust.
#[repr(C)]
pub(crate) union SocketAddrCRepr {
v4: libc::sockaddr_in,
v6: libc::sockaddr_in6,
}

impl SocketAddrCRepr {
pub(crate) fn as_ptr(&self) -> *const libc::sockaddr {
self as *const _ as *const libc::sockaddr
}
}

/// Converts a Rust `SocketAddr` into the system representation.
pub(crate) fn socket_addr(addr: &SocketAddr) -> (SocketAddrCRepr, libc::socklen_t) {
match addr {
SocketAddr::V4(ref addr) => (
addr as *const _ as *const libc::sockaddr,
size_of_val(addr) as libc::socklen_t,
),
SocketAddr::V6(ref addr) => (
addr as *const _ as *const libc::sockaddr,
size_of_val(addr) as libc::socklen_t,
),
SocketAddr::V4(ref addr) => {
// `s_addr` is stored as BE on all machine and the array is in BE order.
// So the native endian conversion method is used so that it's never swapped.
let sin_addr = libc::in_addr { s_addr: u32::from_ne_bytes(addr.ip().octets()) };

let sockaddr_in = libc::sockaddr_in {
sin_family: libc::AF_INET as libc::sa_family_t,
sin_port: addr.port().to_be(),
sin_addr,
sin_zero: [0; 8],
#[cfg(any(
target_os = "dragonfly",
target_os = "freebsd",
target_os = "ios",
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd"
))]
sin_len: 0,
};

let sockaddr = SocketAddrCRepr { v4: sockaddr_in };
(sockaddr, mem::size_of::<libc::sockaddr_in>() as libc::socklen_t)
}
SocketAddr::V6(ref addr) => {
let sockaddr_in6 = libc::sockaddr_in6 {
sin6_family: libc::AF_INET6 as libc::sa_family_t,
sin6_port: addr.port().to_be(),
sin6_addr: libc::in6_addr { s6_addr: addr.ip().octets() },
sin6_flowinfo: addr.flowinfo(),
sin6_scope_id: addr.scope_id(),
#[cfg(any(
target_os = "dragonfly",
target_os = "freebsd",
target_os = "ios",
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd"
))]
sin6_len: 0,
#[cfg(any(target_os = "solaris", target_os = "illumos"))]
__sin6_src_id: 0,
};

let sockaddr = SocketAddrCRepr { v6: sockaddr_in6 };
(sockaddr, mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t)
}
}
}

/// `storage` must be initialised to `sockaddr_in` or `sockaddr_in6`.
/// Converts a `libc::sockaddr` compatible struct into a native Rust `SocketAddr`.
///
/// # Safety
///
/// `storage` must have the `ss_family` field correctly initialized.
/// `storage` must be initialised to a `sockaddr_in` or `sockaddr_in6`.
pub(crate) unsafe fn to_socket_addr(
storage: *const libc::sockaddr_storage,
) -> std::io::Result<SocketAddr> {
match (*storage).ss_family as libc::c_int {
libc::AF_INET => Ok(SocketAddr::V4(
*(storage as *const libc::sockaddr_in as *const _),
)),
libc::AF_INET6 => Ok(SocketAddr::V6(
*(storage as *const libc::sockaddr_in6 as *const _),
)),
libc::AF_INET => {
// Safety: if the ss_family field is AF_INET then storage must be a sockaddr_in.
let addr: &libc::sockaddr_in = &*(storage as *const libc::sockaddr_in);
let ip = Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes());
let port = u16::from_be(addr.sin_port);
Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)))
},
libc::AF_INET6 => {
// Safety: if the ss_family field is AF_INET6 then storage must be a sockaddr_in6.
let addr: &libc::sockaddr_in6 = &*(storage as *const libc::sockaddr_in6);
let ip = Ipv6Addr::from(addr.sin6_addr.s6_addr);
let port = u16::from_be(addr.sin6_port);
Ok(SocketAddr::V6(SocketAddrV6::new(ip, port, addr.sin6_flowinfo, addr.sin6_scope_id)))
},
_ => Err(std::io::ErrorKind::InvalidInput.into()),
}
}
4 changes: 2 additions & 2 deletions src/sys/unix/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ pub(crate) fn new_v6_socket() -> io::Result<TcpSocket> {

pub(crate) fn bind(socket: TcpSocket, addr: SocketAddr) -> io::Result<()> {
let (raw_addr, raw_addr_length) = socket_addr(&addr);
syscall!(bind(socket, raw_addr, raw_addr_length))?;
syscall!(bind(socket, raw_addr.as_ptr(), raw_addr_length))?;
Ok(())
}

pub(crate) fn connect(socket: TcpSocket, addr: SocketAddr) -> io::Result<net::TcpStream> {
let (raw_addr, raw_addr_length) = socket_addr(&addr);

match syscall!(connect(socket, raw_addr, raw_addr_length)) {
match syscall!(connect(socket, raw_addr.as_ptr(), raw_addr_length)) {
Err(err) if err.raw_os_error() != Some(libc::EINPROGRESS) => {
Err(err)
}
Expand Down
2 changes: 1 addition & 1 deletion src/sys/unix/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub fn bind(addr: SocketAddr) -> io::Result<net::UdpSocket> {

socket.and_then(|socket| {
let (raw_addr, raw_addr_length) = socket_addr(&addr);
syscall!(bind(socket, raw_addr, raw_addr_length))
syscall!(bind(socket, raw_addr.as_ptr(), raw_addr_length))
.map_err(|err| {
// Close the socket if we hit an error, ignoring the error
// from closing since we can't pass back two errors.
Expand Down
75 changes: 64 additions & 11 deletions src/sys/windows/net.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use std::io;
use std::mem::size_of_val;
use std::mem;
use std::net::SocketAddr;
use std::sync::Once;

use winapi::ctypes::c_int;
use winapi::shared::ws2def::SOCKADDR;
use winapi::shared::inaddr::{in_addr_S_un, IN_ADDR};
use winapi::shared::in6addr::{in6_addr_u, IN6_ADDR};
use winapi::shared::ws2def::{AF_INET, AF_INET6, ADDRESS_FAMILY, SOCKADDR, SOCKADDR_IN};
use winapi::shared::ws2ipdef::{SOCKADDR_IN6_LH, SOCKADDR_IN6_LH_u};
use winapi::um::winsock2::{ioctlsocket, socket, FIONBIO, INVALID_SOCKET, SOCKET};

/// Initialise the network stack for Windows.
Expand Down Expand Up @@ -41,15 +44,65 @@ pub(crate) fn new_socket(domain: c_int, socket_type: c_int) -> io::Result<SOCKET
})
}

pub(crate) fn socket_addr(addr: &SocketAddr) -> (*const SOCKADDR, c_int) {
/// A type with the same memory layout as `SOCKADDR`. Used in converting Rust level
/// SocketAddr* types into their system representation. The benefit of this specific
/// type over using `SOCKADDR_STORAGE` is that this type is exactly as large as it
/// needs to be and not a lot larger. And it can be initialized cleaner from Rust.
#[repr(C)]
pub(crate) union SocketAddrCRepr {
v4: SOCKADDR_IN,
v6: SOCKADDR_IN6_LH,
}

impl SocketAddrCRepr {
pub(crate) fn as_ptr(&self) -> *const SOCKADDR {
self as *const _ as *const SOCKADDR
}
}

pub(crate) fn socket_addr(addr: &SocketAddr) -> (SocketAddrCRepr, c_int) {
match addr {
SocketAddr::V4(ref addr) => (
addr as *const _ as *const SOCKADDR,
size_of_val(addr) as c_int,
),
SocketAddr::V6(ref addr) => (
addr as *const _ as *const SOCKADDR,
size_of_val(addr) as c_int,
),
SocketAddr::V4(ref addr) => {
// `s_addr` is stored as BE on all machine and the array is in BE order.
// So the native endian conversion method is used so that it's never swapped.
let sin_addr = unsafe {
let mut s_un = mem::zeroed::<in_addr_S_un>();
*s_un.S_addr_mut() = u32::from_ne_bytes(addr.ip().octets());
IN_ADDR { S_un: s_un }
};

let sockaddr_in = SOCKADDR_IN {
sin_family: AF_INET as ADDRESS_FAMILY,
sin_port: addr.port().to_be(),
sin_addr,
sin_zero: [0; 8],
};

let sockaddr = SocketAddrCRepr { v4: sockaddr_in };
(sockaddr, mem::size_of::<SOCKADDR_IN>() as c_int)
},
SocketAddr::V6(ref addr) => {
let sin6_addr = unsafe {
let mut u = mem::zeroed::<in6_addr_u>();
*u.Byte_mut() = addr.ip().octets();
IN6_ADDR { u }
};
let u = unsafe {
let mut u = mem::zeroed::<SOCKADDR_IN6_LH_u>();
*u.sin6_scope_id_mut() = addr.scope_id();
u
};

let sockaddr_in6 = SOCKADDR_IN6_LH {
sin6_family: AF_INET6 as ADDRESS_FAMILY,
sin6_port: addr.port().to_be(),
sin6_addr,
sin6_flowinfo: addr.flowinfo(),
u,
};

let sockaddr = SocketAddrCRepr { v6: sockaddr_in6 };
(sockaddr, mem::size_of::<SOCKADDR_IN6_LH>() as c_int)
}
}
}
35 changes: 22 additions & 13 deletions src/sys/windows/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::io;
use std::mem::size_of;
use std::net::{self, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::time::Duration;
use std::convert::TryInto;
use std::os::windows::io::FromRawSocket;
use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64.

use winapi::ctypes::{c_char, c_int, c_ushort};
use winapi::shared::ws2def::{SOCKADDR_STORAGE, AF_INET, SOCKADDR_IN};
use winapi::shared::ws2def::{SOCKADDR_STORAGE, AF_INET, AF_INET6, SOCKADDR_IN};
use winapi::shared::ws2ipdef::SOCKADDR_IN6_LH;

use winapi::shared::minwindef::{BOOL, TRUE, FALSE};
Expand Down Expand Up @@ -35,7 +35,7 @@ pub(crate) fn bind(socket: TcpSocket, addr: SocketAddr) -> io::Result<()> {

let (raw_addr, raw_addr_length) = socket_addr(&addr);
syscall!(
bind(socket, raw_addr, raw_addr_length),
bind(socket, raw_addr.as_ptr(), raw_addr_length),
PartialEq::eq,
SOCKET_ERROR
)?;
Expand All @@ -48,7 +48,7 @@ pub(crate) fn connect(socket: TcpSocket, addr: SocketAddr) -> io::Result<net::Tc
let (raw_addr, raw_addr_length) = socket_addr(&addr);

let res = syscall!(
connect(socket, raw_addr, raw_addr_length),
connect(socket, raw_addr.as_ptr(), raw_addr_length),
PartialEq::eq,
SOCKET_ERROR
);
Expand Down Expand Up @@ -108,23 +108,32 @@ pub(crate) fn get_reuseaddr(socket: TcpSocket) -> io::Result<bool> {
}

pub(crate) fn get_localaddr(socket: TcpSocket) -> io::Result<SocketAddr> {
let mut addr: SOCKADDR_STORAGE = unsafe { std::mem::zeroed() };
let mut length = std::mem::size_of_val(&addr) as c_int;
let mut storage: SOCKADDR_STORAGE = unsafe { std::mem::zeroed() };
let mut length = std::mem::size_of_val(&storage) as c_int;

match unsafe { getsockname(
socket,
&mut addr as *mut _ as *mut _,
&mut storage as *mut _ as *mut _,
&mut length
) } {
SOCKET_ERROR => Err(io::Error::last_os_error()),
_ => {
let storage: *const SOCKADDR_STORAGE = (&addr) as *const _;
if addr.ss_family as c_int == AF_INET {
let sock_addr : SocketAddrV4 = unsafe { *(storage as *const SOCKADDR_IN as *const _) };
Ok(sock_addr.into())
if storage.ss_family as c_int == AF_INET {
// Safety: if the ss_family field is AF_INET then storage must be a sockaddr_in.
let addr: &SOCKADDR_IN = unsafe { &*(&storage as *const _ as *const SOCKADDR_IN) };
let ip_bytes = unsafe { addr.sin_addr.S_un.S_un_b() };
let ip = Ipv4Addr::from([ip_bytes.s_b1, ip_bytes.s_b2, ip_bytes.s_b3, ip_bytes.s_b4]);
let port = u16::from_be(addr.sin_port);
Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)))
} else if storage.ss_family as c_int == AF_INET6 {
// Safety: if the ss_family field is AF_INET6 then storage must be a sockaddr_in6.
let addr: &SOCKADDR_IN6_LH = unsafe { &*(&storage as *const _ as *const SOCKADDR_IN6_LH) };
let ip = Ipv6Addr::from(*unsafe { addr.sin6_addr.u.Byte() });
let port = u16::from_be(addr.sin6_port);
let scope_id = unsafe { *addr.u.sin6_scope_id() };
Ok(SocketAddr::V6(SocketAddrV6::new(ip, port, addr.sin6_flowinfo, scope_id)))
} else {
let sock_addr : SocketAddrV6 = unsafe { *(storage as *const SOCKADDR_IN6_LH as *const _) };
Ok(sock_addr.into())
Err(std::io::ErrorKind::InvalidInput.into())
}
},
}
Expand Down
2 changes: 1 addition & 1 deletion src/sys/windows/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub fn bind(addr: SocketAddr) -> io::Result<net::UdpSocket> {
new_ip_socket(addr, SOCK_DGRAM).and_then(|socket| {
let (raw_addr, raw_addr_length) = socket_addr(&addr);
syscall!(
win_bind(socket, raw_addr, raw_addr_length,),
win_bind(socket, raw_addr.as_ptr(), raw_addr_length,),
PartialEq::eq,
SOCKET_ERROR
)
Expand Down

0 comments on commit 152e075

Please sign in to comment.