Skip to content

Commit

Permalink
fix(driver, net): get control message length from RecvMsg
Browse files Browse the repository at this point in the history
  • Loading branch information
AsakuraMizu committed Aug 3, 2024
1 parent 84ed77b commit d41a42f
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 29 deletions.
13 changes: 11 additions & 2 deletions compio-driver/src/iocp/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,7 @@ pub struct RecvMsg<T: IoVectoredBufMut, C: IoBufMut, S> {
fd: SharedFd<S>,
buffer: T,
control: C,
control_len: u32,
_p: PhantomPinned,
}

Expand All @@ -806,16 +807,22 @@ impl<T: IoVectoredBufMut, C: IoBufMut, S> RecvMsg<T, C, S> {
fd,
buffer,
control,
control_len: 0,
_p: PhantomPinned,
}
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S> IntoInner for RecvMsg<T, C, S> {
type Inner = ((T, C), SOCKADDR_STORAGE, socklen_t);
type Inner = ((T, C), SOCKADDR_STORAGE, socklen_t, usize);

fn into_inner(self) -> Self::Inner {
((self.buffer, self.control), self.addr, self.addr_len)
(
(self.buffer, self.control),
self.addr,
self.addr_len,
self.control_len as _,
)
}
}

Expand All @@ -837,6 +844,7 @@ impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsg<T, C, S> {
Control: std::mem::transmute::<IoSliceMut, WSABUF>(this.control.as_io_slice_mut()),
dwFlags: 0,
};
this.control_len = 0;

let mut received = 0;
let res = recvmsg_fn(
Expand All @@ -846,6 +854,7 @@ impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsg<T, C, S> {
optr,
None,
);
this.control_len = msg.Control.len;
winsock_result(res, received)
}

Expand Down
38 changes: 31 additions & 7 deletions compio-driver/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,47 @@ impl<T: SetBufInit, O> BufResultExt for BufResult<(usize, O), T> {
}
}

/// Helper trait for [`RecvFrom`] and [`RecvFromVectored`].
impl<T: SetBufInit, C: SetBufInit, O> BufResultExt for BufResult<(usize, usize, O), (T, C)> {
fn map_advanced(self) -> Self {
self.map(
|(init_buffer, init_control, obj), (mut buffer, mut control)| {
unsafe {
buffer.set_buf_init(init_buffer);
control.set_buf_init(init_control);
}
((init_buffer, init_control, obj), (buffer, control))
},
)
}
}

/// Helper trait for [`RecvFrom`], [`RecvFromVectored`] and [`RecvMsg`].
pub trait RecvResultExt {
/// The mapped result.
type RecvFromResult;
type RecvResult;

/// Create [`SockAddr`] if the result is [`Ok`].
fn map_addr(self) -> Self::RecvFromResult;
fn map_addr(self) -> Self::RecvResult;
}

impl<T> RecvResultExt for BufResult<usize, (T, sockaddr_storage, socklen_t)> {
type RecvFromResult = BufResult<(usize, SockAddr), T>;
type RecvResult = BufResult<(usize, SockAddr), T>;

fn map_addr(self) -> Self::RecvResult {
self.map_buffer(|(buffer, addr_buffer, addr_size)| (buffer, addr_buffer, addr_size, 0))
.map_addr()
.map_res(|(res, _, addr)| (res, addr))
}
}

impl<T> RecvResultExt for BufResult<usize, (T, sockaddr_storage, socklen_t, usize)> {
type RecvResult = BufResult<(usize, usize, SockAddr), T>;

fn map_addr(self) -> Self::RecvFromResult {
fn map_addr(self) -> Self::RecvResult {
self.map2(
|res, (buffer, addr_buffer, addr_size)| {
|res, (buffer, addr_buffer, addr_size, len)| {
let addr = unsafe { SockAddr::new(addr_buffer, addr_size) };
((res, addr), buffer)
((res, len, addr), buffer)
},
|(buffer, ..)| buffer,
)
Expand Down
11 changes: 8 additions & 3 deletions compio-driver/src/unix/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,15 +411,20 @@ impl<T: IoVectoredBufMut, C: IoBufMut, S> RecvMsg<T, C, S> {
self.msg.msg_iov = self.slices.as_mut_ptr() as _;
self.msg.msg_iovlen = self.slices.len() as _;
self.msg.msg_control = self.control.as_buf_mut_ptr() as _;
self.msg.msg_controllen = self.control.buf_len() as _;
self.msg.msg_controllen = self.control.buf_capacity() as _;
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S> IntoInner for RecvMsg<T, C, S> {
type Inner = ((T, C), sockaddr_storage, socklen_t);
type Inner = ((T, C), sockaddr_storage, socklen_t, usize);

fn into_inner(self) -> Self::Inner {
((self.buffer, self.control), self.addr, self.msg.msg_namelen)
(
(self.buffer, self.control),
self.addr,
self.msg.msg_namelen,
self.msg.msg_controllen as _,
)
}
}

Expand Down
12 changes: 3 additions & 9 deletions compio-net/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ impl Socket {
&self,
buffer: T,
control: C,
) -> BufResult<(usize, SockAddr), (T, C)> {
) -> BufResult<(usize, usize, SockAddr), (T, C)> {
self.recv_msg_vectored([buffer], control)
.await
.map_buffer(|([buffer], control)| (buffer, control))
Expand All @@ -271,20 +271,14 @@ impl Socket {
&self,
buffer: T,
control: C,
) -> BufResult<(usize, SockAddr), (T, C)> {
) -> BufResult<(usize, usize, SockAddr), (T, C)> {
let fd = self.to_shared_fd();
let op = RecvMsg::new(fd, buffer, control);
compio_runtime::submit(op)
.await
.into_inner()
.map_addr()
.map(|(init, obj), (mut buffer, control)| {
// SAFETY: The number of bytes received would not bypass the buffer capacity.
unsafe {
buffer.set_buf_init(init);
}
((init, obj), (buffer, control))
})
.map_advanced()
}

pub async fn send_to<T: IoBuf>(&self, buffer: T, addr: &SockAddr) -> BufResult<usize, T> {
Expand Down
8 changes: 4 additions & 4 deletions compio-net/src/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,11 @@ impl UdpSocket {
&self,
buffer: T,
control: C,
) -> BufResult<(usize, SocketAddr), (T, C)> {
) -> BufResult<(usize, usize, SocketAddr), (T, C)> {
self.inner
.recv_msg(buffer, control)
.await
.map_res(|(n, addr)| (n, addr.as_socket().expect("should be SocketAddr")))
.map_res(|(n, m, addr)| (n, m, addr.as_socket().expect("should be SocketAddr")))
}

/// Receives a single datagram message and ancillary data on the socket. On
Expand All @@ -241,11 +241,11 @@ impl UdpSocket {
&self,
buffer: T,
control: C,
) -> BufResult<(usize, SocketAddr), (T, C)> {
) -> BufResult<(usize, usize, SocketAddr), (T, C)> {
self.inner
.recv_msg_vectored(buffer, control)
.await
.map_res(|(n, addr)| (n, addr.as_socket().expect("should be SocketAddr")))
.map_res(|(n, m, addr)| (n, m, addr.as_socket().expect("should be SocketAddr")))
}

/// Sends data on the socket to the given address. On success, returns the
Expand Down
12 changes: 8 additions & 4 deletions compio-net/tests/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,18 @@ async fn send_msg_with_ipv6_ecn() {

active.send_msg(MSG, control, passive_addr).await.unwrap();

let res = passive.recv_msg(Vec::with_capacity(20), [0u8; 32]).await;
assert_eq!(res.0.unwrap().1, active_addr);
assert_eq!(res.1.0, MSG.as_bytes());
let ((_, _, addr), (buffer, control)) = passive
.recv_msg(Vec::with_capacity(20), Vec::with_capacity(32))
.await
.unwrap();
assert_eq!(addr, active_addr);
assert_eq!(buffer, MSG.as_bytes());
unsafe {
let mut iter = CMsgIter::new(&res.1.1);
let mut iter = CMsgIter::new(&control);
let cmsg = iter.next().unwrap();
assert_eq!(cmsg.level(), IPPROTO_IPV6);
assert_eq!(cmsg.ty(), IPV6_TCLASS);
assert_eq!(cmsg.data::<i32>(), &ECN_BITS);
assert!(iter.next().is_none());
}
}

0 comments on commit d41a42f

Please sign in to comment.