From d41a42f55aa047cccf0be10c907d6c9c475dad28 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Thu, 1 Aug 2024 01:26:27 +0800 Subject: [PATCH 1/2] fix(driver, net): get control message length from `RecvMsg` --- compio-driver/src/iocp/op.rs | 13 ++++++++++-- compio-driver/src/op.rs | 38 +++++++++++++++++++++++++++++------- compio-driver/src/unix/op.rs | 11 ++++++++--- compio-net/src/socket.rs | 12 +++--------- compio-net/src/udp.rs | 8 ++++---- compio-net/tests/udp.rs | 12 ++++++++---- 6 files changed, 65 insertions(+), 29 deletions(-) diff --git a/compio-driver/src/iocp/op.rs b/compio-driver/src/iocp/op.rs index 705971b8..cf2b2740 100644 --- a/compio-driver/src/iocp/op.rs +++ b/compio-driver/src/iocp/op.rs @@ -786,6 +786,7 @@ pub struct RecvMsg { fd: SharedFd, buffer: T, control: C, + control_len: u32, _p: PhantomPinned, } @@ -806,16 +807,22 @@ impl RecvMsg { fd, buffer, control, + control_len: 0, _p: PhantomPinned, } } } impl IntoInner for RecvMsg { - type Inner = ((T, C), SOCKADDR_STORAGE, socklen_t); + type Inner = ((T, C), SOCKADDR_STORAGE, socklen_t, usize); fn into_inner(self) -> Self::Inner { - ((self.buffer, self.control), self.addr, self.addr_len) + ( + (self.buffer, self.control), + self.addr, + self.addr_len, + self.control_len as _, + ) } } @@ -837,6 +844,7 @@ impl OpCode for RecvMsg { Control: std::mem::transmute::(this.control.as_io_slice_mut()), dwFlags: 0, }; + this.control_len = 0; let mut received = 0; let res = recvmsg_fn( @@ -846,6 +854,7 @@ impl OpCode for RecvMsg { optr, None, ); + this.control_len = msg.Control.len; winsock_result(res, received) } diff --git a/compio-driver/src/op.rs b/compio-driver/src/op.rs index 89b60d10..743a6c5f 100644 --- a/compio-driver/src/op.rs +++ b/compio-driver/src/op.rs @@ -49,23 +49,47 @@ impl BufResultExt for BufResult<(usize, O), T> { } } -/// Helper trait for [`RecvFrom`] and [`RecvFromVectored`]. +impl BufResultExt for BufResult<(usize, usize, O), (T, C)> { + fn map_advanced(self) -> Self { + self.map( + |(init_buffer, init_control, obj), (mut buffer, mut control)| { + unsafe { + buffer.set_buf_init(init_buffer); + control.set_buf_init(init_control); + } + ((init_buffer, init_control, obj), (buffer, control)) + }, + ) + } +} + +/// Helper trait for [`RecvFrom`], [`RecvFromVectored`] and [`RecvMsg`]. pub trait RecvResultExt { /// The mapped result. - type RecvFromResult; + type RecvResult; /// Create [`SockAddr`] if the result is [`Ok`]. - fn map_addr(self) -> Self::RecvFromResult; + fn map_addr(self) -> Self::RecvResult; } impl RecvResultExt for BufResult { - type RecvFromResult = BufResult<(usize, SockAddr), T>; + type RecvResult = BufResult<(usize, SockAddr), T>; + + fn map_addr(self) -> Self::RecvResult { + self.map_buffer(|(buffer, addr_buffer, addr_size)| (buffer, addr_buffer, addr_size, 0)) + .map_addr() + .map_res(|(res, _, addr)| (res, addr)) + } +} + +impl RecvResultExt for BufResult { + type RecvResult = BufResult<(usize, usize, SockAddr), T>; - fn map_addr(self) -> Self::RecvFromResult { + fn map_addr(self) -> Self::RecvResult { self.map2( - |res, (buffer, addr_buffer, addr_size)| { + |res, (buffer, addr_buffer, addr_size, len)| { let addr = unsafe { SockAddr::new(addr_buffer, addr_size) }; - ((res, addr), buffer) + ((res, len, addr), buffer) }, |(buffer, ..)| buffer, ) diff --git a/compio-driver/src/unix/op.rs b/compio-driver/src/unix/op.rs index aff5899a..2d14bc1b 100644 --- a/compio-driver/src/unix/op.rs +++ b/compio-driver/src/unix/op.rs @@ -411,15 +411,20 @@ impl RecvMsg { self.msg.msg_iov = self.slices.as_mut_ptr() as _; self.msg.msg_iovlen = self.slices.len() as _; self.msg.msg_control = self.control.as_buf_mut_ptr() as _; - self.msg.msg_controllen = self.control.buf_len() as _; + self.msg.msg_controllen = self.control.buf_capacity() as _; } } impl IntoInner for RecvMsg { - type Inner = ((T, C), sockaddr_storage, socklen_t); + type Inner = ((T, C), sockaddr_storage, socklen_t, usize); fn into_inner(self) -> Self::Inner { - ((self.buffer, self.control), self.addr, self.msg.msg_namelen) + ( + (self.buffer, self.control), + self.addr, + self.msg.msg_namelen, + self.msg.msg_controllen as _, + ) } } diff --git a/compio-net/src/socket.rs b/compio-net/src/socket.rs index 0f84dd08..d7bc5a8d 100644 --- a/compio-net/src/socket.rs +++ b/compio-net/src/socket.rs @@ -261,7 +261,7 @@ impl Socket { &self, buffer: T, control: C, - ) -> BufResult<(usize, SockAddr), (T, C)> { + ) -> BufResult<(usize, usize, SockAddr), (T, C)> { self.recv_msg_vectored([buffer], control) .await .map_buffer(|([buffer], control)| (buffer, control)) @@ -271,20 +271,14 @@ impl Socket { &self, buffer: T, control: C, - ) -> BufResult<(usize, SockAddr), (T, C)> { + ) -> BufResult<(usize, usize, SockAddr), (T, C)> { let fd = self.to_shared_fd(); let op = RecvMsg::new(fd, buffer, control); compio_runtime::submit(op) .await .into_inner() .map_addr() - .map(|(init, obj), (mut buffer, control)| { - // SAFETY: The number of bytes received would not bypass the buffer capacity. - unsafe { - buffer.set_buf_init(init); - } - ((init, obj), (buffer, control)) - }) + .map_advanced() } pub async fn send_to(&self, buffer: T, addr: &SockAddr) -> BufResult { diff --git a/compio-net/src/udp.rs b/compio-net/src/udp.rs index 1063eed9..13e59d73 100644 --- a/compio-net/src/udp.rs +++ b/compio-net/src/udp.rs @@ -228,11 +228,11 @@ impl UdpSocket { &self, buffer: T, control: C, - ) -> BufResult<(usize, SocketAddr), (T, C)> { + ) -> BufResult<(usize, usize, SocketAddr), (T, C)> { self.inner .recv_msg(buffer, control) .await - .map_res(|(n, addr)| (n, addr.as_socket().expect("should be SocketAddr"))) + .map_res(|(n, m, addr)| (n, m, addr.as_socket().expect("should be SocketAddr"))) } /// Receives a single datagram message and ancillary data on the socket. On @@ -241,11 +241,11 @@ impl UdpSocket { &self, buffer: T, control: C, - ) -> BufResult<(usize, SocketAddr), (T, C)> { + ) -> BufResult<(usize, usize, SocketAddr), (T, C)> { self.inner .recv_msg_vectored(buffer, control) .await - .map_res(|(n, addr)| (n, addr.as_socket().expect("should be SocketAddr"))) + .map_res(|(n, m, addr)| (n, m, addr.as_socket().expect("should be SocketAddr"))) } /// Sends data on the socket to the given address. On success, returns the diff --git a/compio-net/tests/udp.rs b/compio-net/tests/udp.rs index fe7dc263..d813290e 100644 --- a/compio-net/tests/udp.rs +++ b/compio-net/tests/udp.rs @@ -103,14 +103,18 @@ async fn send_msg_with_ipv6_ecn() { active.send_msg(MSG, control, passive_addr).await.unwrap(); - let res = passive.recv_msg(Vec::with_capacity(20), [0u8; 32]).await; - assert_eq!(res.0.unwrap().1, active_addr); - assert_eq!(res.1.0, MSG.as_bytes()); + let ((_, _, addr), (buffer, control)) = passive + .recv_msg(Vec::with_capacity(20), Vec::with_capacity(32)) + .await + .unwrap(); + assert_eq!(addr, active_addr); + assert_eq!(buffer, MSG.as_bytes()); unsafe { - let mut iter = CMsgIter::new(&res.1.1); + let mut iter = CMsgIter::new(&control); let cmsg = iter.next().unwrap(); assert_eq!(cmsg.level(), IPPROTO_IPV6); assert_eq!(cmsg.ty(), IPV6_TCLASS); assert_eq!(cmsg.data::(), &ECN_BITS); + assert!(iter.next().is_none()); } } From 970f328302e6eca859d9ab2d80cc8322c08dbbac Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Sat, 3 Aug 2024 21:58:25 +0800 Subject: [PATCH 2/2] fix(driver): `IoVectoredBuf` usage --- compio-driver/src/iocp/op.rs | 4 ++-- compio-driver/src/unix/op.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/compio-driver/src/iocp/op.rs b/compio-driver/src/iocp/op.rs index cf2b2740..8369b234 100644 --- a/compio-driver/src/iocp/op.rs +++ b/compio-driver/src/iocp/op.rs @@ -835,7 +835,7 @@ impl OpCode for RecvMsg { })?; let this = self.get_unchecked_mut(); - let mut slices = this.buffer.as_io_slices_mut(); + let mut slices = this.buffer.io_slices_mut(); let mut msg = WSAMSG { name: &mut this.addr as *mut _ as _, namelen: this.addr_len, @@ -906,7 +906,7 @@ impl OpCode for SendMsg { unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll> { let this = self.get_unchecked_mut(); - let slices = this.buffer.as_io_slices(); + let slices = this.buffer.io_slices(); let msg = WSAMSG { name: this.addr.as_ptr() as _, namelen: this.addr.len(), diff --git a/compio-driver/src/unix/op.rs b/compio-driver/src/unix/op.rs index 2d14bc1b..e22dee4d 100644 --- a/compio-driver/src/unix/op.rs +++ b/compio-driver/src/unix/op.rs @@ -404,7 +404,7 @@ impl RecvMsg { } pub(crate) unsafe fn set_msg(&mut self) { - self.slices = self.buffer.as_io_slices_mut(); + self.slices = self.buffer.io_slices_mut(); self.msg.msg_name = std::ptr::addr_of_mut!(self.addr) as _; self.msg.msg_namelen = std::mem::size_of_val(&self.addr) as _; @@ -463,7 +463,7 @@ impl SendMsg { } pub(crate) unsafe fn set_msg(&mut self) { - self.slices = self.buffer.as_io_slices(); + self.slices = self.buffer.io_slices(); self.msg.msg_name = self.addr.as_ptr() as _; self.msg.msg_namelen = self.addr.len();