diff --git a/compio-driver/src/iocp/op.rs b/compio-driver/src/iocp/op.rs index 705971b8..cf2b2740 100644 --- a/compio-driver/src/iocp/op.rs +++ b/compio-driver/src/iocp/op.rs @@ -786,6 +786,7 @@ pub struct RecvMsg { fd: SharedFd, buffer: T, control: C, + control_len: u32, _p: PhantomPinned, } @@ -806,16 +807,22 @@ impl RecvMsg { fd, buffer, control, + control_len: 0, _p: PhantomPinned, } } } impl IntoInner for RecvMsg { - type Inner = ((T, C), SOCKADDR_STORAGE, socklen_t); + type Inner = ((T, C), SOCKADDR_STORAGE, socklen_t, usize); fn into_inner(self) -> Self::Inner { - ((self.buffer, self.control), self.addr, self.addr_len) + ( + (self.buffer, self.control), + self.addr, + self.addr_len, + self.control_len as _, + ) } } @@ -837,6 +844,7 @@ impl OpCode for RecvMsg { Control: std::mem::transmute::(this.control.as_io_slice_mut()), dwFlags: 0, }; + this.control_len = 0; let mut received = 0; let res = recvmsg_fn( @@ -846,6 +854,7 @@ impl OpCode for RecvMsg { optr, None, ); + this.control_len = msg.Control.len; winsock_result(res, received) } diff --git a/compio-driver/src/op.rs b/compio-driver/src/op.rs index 89b60d10..743a6c5f 100644 --- a/compio-driver/src/op.rs +++ b/compio-driver/src/op.rs @@ -49,23 +49,47 @@ impl BufResultExt for BufResult<(usize, O), T> { } } -/// Helper trait for [`RecvFrom`] and [`RecvFromVectored`]. +impl BufResultExt for BufResult<(usize, usize, O), (T, C)> { + fn map_advanced(self) -> Self { + self.map( + |(init_buffer, init_control, obj), (mut buffer, mut control)| { + unsafe { + buffer.set_buf_init(init_buffer); + control.set_buf_init(init_control); + } + ((init_buffer, init_control, obj), (buffer, control)) + }, + ) + } +} + +/// Helper trait for [`RecvFrom`], [`RecvFromVectored`] and [`RecvMsg`]. pub trait RecvResultExt { /// The mapped result. - type RecvFromResult; + type RecvResult; /// Create [`SockAddr`] if the result is [`Ok`]. - fn map_addr(self) -> Self::RecvFromResult; + fn map_addr(self) -> Self::RecvResult; } impl RecvResultExt for BufResult { - type RecvFromResult = BufResult<(usize, SockAddr), T>; + type RecvResult = BufResult<(usize, SockAddr), T>; + + fn map_addr(self) -> Self::RecvResult { + self.map_buffer(|(buffer, addr_buffer, addr_size)| (buffer, addr_buffer, addr_size, 0)) + .map_addr() + .map_res(|(res, _, addr)| (res, addr)) + } +} + +impl RecvResultExt for BufResult { + type RecvResult = BufResult<(usize, usize, SockAddr), T>; - fn map_addr(self) -> Self::RecvFromResult { + fn map_addr(self) -> Self::RecvResult { self.map2( - |res, (buffer, addr_buffer, addr_size)| { + |res, (buffer, addr_buffer, addr_size, len)| { let addr = unsafe { SockAddr::new(addr_buffer, addr_size) }; - ((res, addr), buffer) + ((res, len, addr), buffer) }, |(buffer, ..)| buffer, ) diff --git a/compio-driver/src/unix/op.rs b/compio-driver/src/unix/op.rs index aff5899a..2d14bc1b 100644 --- a/compio-driver/src/unix/op.rs +++ b/compio-driver/src/unix/op.rs @@ -411,15 +411,20 @@ impl RecvMsg { self.msg.msg_iov = self.slices.as_mut_ptr() as _; self.msg.msg_iovlen = self.slices.len() as _; self.msg.msg_control = self.control.as_buf_mut_ptr() as _; - self.msg.msg_controllen = self.control.buf_len() as _; + self.msg.msg_controllen = self.control.buf_capacity() as _; } } impl IntoInner for RecvMsg { - type Inner = ((T, C), sockaddr_storage, socklen_t); + type Inner = ((T, C), sockaddr_storage, socklen_t, usize); fn into_inner(self) -> Self::Inner { - ((self.buffer, self.control), self.addr, self.msg.msg_namelen) + ( + (self.buffer, self.control), + self.addr, + self.msg.msg_namelen, + self.msg.msg_controllen as _, + ) } } diff --git a/compio-net/src/socket.rs b/compio-net/src/socket.rs index 0f84dd08..d7bc5a8d 100644 --- a/compio-net/src/socket.rs +++ b/compio-net/src/socket.rs @@ -261,7 +261,7 @@ impl Socket { &self, buffer: T, control: C, - ) -> BufResult<(usize, SockAddr), (T, C)> { + ) -> BufResult<(usize, usize, SockAddr), (T, C)> { self.recv_msg_vectored([buffer], control) .await .map_buffer(|([buffer], control)| (buffer, control)) @@ -271,20 +271,14 @@ impl Socket { &self, buffer: T, control: C, - ) -> BufResult<(usize, SockAddr), (T, C)> { + ) -> BufResult<(usize, usize, SockAddr), (T, C)> { let fd = self.to_shared_fd(); let op = RecvMsg::new(fd, buffer, control); compio_runtime::submit(op) .await .into_inner() .map_addr() - .map(|(init, obj), (mut buffer, control)| { - // SAFETY: The number of bytes received would not bypass the buffer capacity. - unsafe { - buffer.set_buf_init(init); - } - ((init, obj), (buffer, control)) - }) + .map_advanced() } pub async fn send_to(&self, buffer: T, addr: &SockAddr) -> BufResult { diff --git a/compio-net/src/udp.rs b/compio-net/src/udp.rs index 1063eed9..13e59d73 100644 --- a/compio-net/src/udp.rs +++ b/compio-net/src/udp.rs @@ -228,11 +228,11 @@ impl UdpSocket { &self, buffer: T, control: C, - ) -> BufResult<(usize, SocketAddr), (T, C)> { + ) -> BufResult<(usize, usize, SocketAddr), (T, C)> { self.inner .recv_msg(buffer, control) .await - .map_res(|(n, addr)| (n, addr.as_socket().expect("should be SocketAddr"))) + .map_res(|(n, m, addr)| (n, m, addr.as_socket().expect("should be SocketAddr"))) } /// Receives a single datagram message and ancillary data on the socket. On @@ -241,11 +241,11 @@ impl UdpSocket { &self, buffer: T, control: C, - ) -> BufResult<(usize, SocketAddr), (T, C)> { + ) -> BufResult<(usize, usize, SocketAddr), (T, C)> { self.inner .recv_msg_vectored(buffer, control) .await - .map_res(|(n, addr)| (n, addr.as_socket().expect("should be SocketAddr"))) + .map_res(|(n, m, addr)| (n, m, addr.as_socket().expect("should be SocketAddr"))) } /// Sends data on the socket to the given address. On success, returns the diff --git a/compio-net/tests/udp.rs b/compio-net/tests/udp.rs index fe7dc263..d813290e 100644 --- a/compio-net/tests/udp.rs +++ b/compio-net/tests/udp.rs @@ -103,14 +103,18 @@ async fn send_msg_with_ipv6_ecn() { active.send_msg(MSG, control, passive_addr).await.unwrap(); - let res = passive.recv_msg(Vec::with_capacity(20), [0u8; 32]).await; - assert_eq!(res.0.unwrap().1, active_addr); - assert_eq!(res.1.0, MSG.as_bytes()); + let ((_, _, addr), (buffer, control)) = passive + .recv_msg(Vec::with_capacity(20), Vec::with_capacity(32)) + .await + .unwrap(); + assert_eq!(addr, active_addr); + assert_eq!(buffer, MSG.as_bytes()); unsafe { - let mut iter = CMsgIter::new(&res.1.1); + let mut iter = CMsgIter::new(&control); let cmsg = iter.next().unwrap(); assert_eq!(cmsg.level(), IPPROTO_IPV6); assert_eq!(cmsg.ty(), IPV6_TCLASS); assert_eq!(cmsg.data::(), &ECN_BITS); + assert!(iter.next().is_none()); } }