Skip to content

Commit

Permalink
Merge pull request #283 from AsakuraMizu/master
Browse files Browse the repository at this point in the history
fix(driver/net): ancillary data
  • Loading branch information
George-Miao authored Aug 3, 2024
2 parents 84ed77b + 970f328 commit d24ce51
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 33 deletions.
17 changes: 13 additions & 4 deletions compio-driver/src/iocp/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,7 @@ pub struct RecvMsg<T: IoVectoredBufMut, C: IoBufMut, S> {
fd: SharedFd<S>,
buffer: T,
control: C,
control_len: u32,
_p: PhantomPinned,
}

Expand All @@ -806,16 +807,22 @@ impl<T: IoVectoredBufMut, C: IoBufMut, S> RecvMsg<T, C, S> {
fd,
buffer,
control,
control_len: 0,
_p: PhantomPinned,
}
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S> IntoInner for RecvMsg<T, C, S> {
type Inner = ((T, C), SOCKADDR_STORAGE, socklen_t);
type Inner = ((T, C), SOCKADDR_STORAGE, socklen_t, usize);

fn into_inner(self) -> Self::Inner {
((self.buffer, self.control), self.addr, self.addr_len)
(
(self.buffer, self.control),
self.addr,
self.addr_len,
self.control_len as _,
)
}
}

Expand All @@ -828,7 +835,7 @@ impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsg<T, C, S> {
})?;

let this = self.get_unchecked_mut();
let mut slices = this.buffer.as_io_slices_mut();
let mut slices = this.buffer.io_slices_mut();
let mut msg = WSAMSG {
name: &mut this.addr as *mut _ as _,
namelen: this.addr_len,
Expand All @@ -837,6 +844,7 @@ impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsg<T, C, S> {
Control: std::mem::transmute::<IoSliceMut, WSABUF>(this.control.as_io_slice_mut()),
dwFlags: 0,
};
this.control_len = 0;

let mut received = 0;
let res = recvmsg_fn(
Expand All @@ -846,6 +854,7 @@ impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsg<T, C, S> {
optr,
None,
);
this.control_len = msg.Control.len;
winsock_result(res, received)
}

Expand Down Expand Up @@ -897,7 +906,7 @@ impl<T: IoVectoredBuf, C: IoBuf, S: AsRawFd> OpCode for SendMsg<T, C, S> {
unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
let this = self.get_unchecked_mut();

let slices = this.buffer.as_io_slices();
let slices = this.buffer.io_slices();
let msg = WSAMSG {
name: this.addr.as_ptr() as _,
namelen: this.addr.len(),
Expand Down
38 changes: 31 additions & 7 deletions compio-driver/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,47 @@ impl<T: SetBufInit, O> BufResultExt for BufResult<(usize, O), T> {
}
}

/// Helper trait for [`RecvFrom`] and [`RecvFromVectored`].
impl<T: SetBufInit, C: SetBufInit, O> BufResultExt for BufResult<(usize, usize, O), (T, C)> {
fn map_advanced(self) -> Self {
self.map(
|(init_buffer, init_control, obj), (mut buffer, mut control)| {
unsafe {
buffer.set_buf_init(init_buffer);
control.set_buf_init(init_control);
}
((init_buffer, init_control, obj), (buffer, control))
},
)
}
}

/// Helper trait for [`RecvFrom`], [`RecvFromVectored`] and [`RecvMsg`].
pub trait RecvResultExt {
/// The mapped result.
type RecvFromResult;
type RecvResult;

/// Create [`SockAddr`] if the result is [`Ok`].
fn map_addr(self) -> Self::RecvFromResult;
fn map_addr(self) -> Self::RecvResult;
}

impl<T> RecvResultExt for BufResult<usize, (T, sockaddr_storage, socklen_t)> {
type RecvFromResult = BufResult<(usize, SockAddr), T>;
type RecvResult = BufResult<(usize, SockAddr), T>;

fn map_addr(self) -> Self::RecvResult {
self.map_buffer(|(buffer, addr_buffer, addr_size)| (buffer, addr_buffer, addr_size, 0))
.map_addr()
.map_res(|(res, _, addr)| (res, addr))
}
}

impl<T> RecvResultExt for BufResult<usize, (T, sockaddr_storage, socklen_t, usize)> {
type RecvResult = BufResult<(usize, usize, SockAddr), T>;

fn map_addr(self) -> Self::RecvFromResult {
fn map_addr(self) -> Self::RecvResult {
self.map2(
|res, (buffer, addr_buffer, addr_size)| {
|res, (buffer, addr_buffer, addr_size, len)| {
let addr = unsafe { SockAddr::new(addr_buffer, addr_size) };
((res, addr), buffer)
((res, len, addr), buffer)
},
|(buffer, ..)| buffer,
)
Expand Down
15 changes: 10 additions & 5 deletions compio-driver/src/unix/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,22 +404,27 @@ impl<T: IoVectoredBufMut, C: IoBufMut, S> RecvMsg<T, C, S> {
}

pub(crate) unsafe fn set_msg(&mut self) {
self.slices = self.buffer.as_io_slices_mut();
self.slices = self.buffer.io_slices_mut();

self.msg.msg_name = std::ptr::addr_of_mut!(self.addr) as _;
self.msg.msg_namelen = std::mem::size_of_val(&self.addr) as _;
self.msg.msg_iov = self.slices.as_mut_ptr() as _;
self.msg.msg_iovlen = self.slices.len() as _;
self.msg.msg_control = self.control.as_buf_mut_ptr() as _;
self.msg.msg_controllen = self.control.buf_len() as _;
self.msg.msg_controllen = self.control.buf_capacity() as _;
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S> IntoInner for RecvMsg<T, C, S> {
type Inner = ((T, C), sockaddr_storage, socklen_t);
type Inner = ((T, C), sockaddr_storage, socklen_t, usize);

fn into_inner(self) -> Self::Inner {
((self.buffer, self.control), self.addr, self.msg.msg_namelen)
(
(self.buffer, self.control),
self.addr,
self.msg.msg_namelen,
self.msg.msg_controllen as _,
)
}
}

Expand Down Expand Up @@ -458,7 +463,7 @@ impl<T: IoVectoredBuf, C: IoBuf, S> SendMsg<T, C, S> {
}

pub(crate) unsafe fn set_msg(&mut self) {
self.slices = self.buffer.as_io_slices();
self.slices = self.buffer.io_slices();

self.msg.msg_name = self.addr.as_ptr() as _;
self.msg.msg_namelen = self.addr.len();
Expand Down
12 changes: 3 additions & 9 deletions compio-net/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ impl Socket {
&self,
buffer: T,
control: C,
) -> BufResult<(usize, SockAddr), (T, C)> {
) -> BufResult<(usize, usize, SockAddr), (T, C)> {
self.recv_msg_vectored([buffer], control)
.await
.map_buffer(|([buffer], control)| (buffer, control))
Expand All @@ -271,20 +271,14 @@ impl Socket {
&self,
buffer: T,
control: C,
) -> BufResult<(usize, SockAddr), (T, C)> {
) -> BufResult<(usize, usize, SockAddr), (T, C)> {
let fd = self.to_shared_fd();
let op = RecvMsg::new(fd, buffer, control);
compio_runtime::submit(op)
.await
.into_inner()
.map_addr()
.map(|(init, obj), (mut buffer, control)| {
// SAFETY: The number of bytes received would not bypass the buffer capacity.
unsafe {
buffer.set_buf_init(init);
}
((init, obj), (buffer, control))
})
.map_advanced()
}

pub async fn send_to<T: IoBuf>(&self, buffer: T, addr: &SockAddr) -> BufResult<usize, T> {
Expand Down
8 changes: 4 additions & 4 deletions compio-net/src/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,11 @@ impl UdpSocket {
&self,
buffer: T,
control: C,
) -> BufResult<(usize, SocketAddr), (T, C)> {
) -> BufResult<(usize, usize, SocketAddr), (T, C)> {
self.inner
.recv_msg(buffer, control)
.await
.map_res(|(n, addr)| (n, addr.as_socket().expect("should be SocketAddr")))
.map_res(|(n, m, addr)| (n, m, addr.as_socket().expect("should be SocketAddr")))
}

/// Receives a single datagram message and ancillary data on the socket. On
Expand All @@ -241,11 +241,11 @@ impl UdpSocket {
&self,
buffer: T,
control: C,
) -> BufResult<(usize, SocketAddr), (T, C)> {
) -> BufResult<(usize, usize, SocketAddr), (T, C)> {
self.inner
.recv_msg_vectored(buffer, control)
.await
.map_res(|(n, addr)| (n, addr.as_socket().expect("should be SocketAddr")))
.map_res(|(n, m, addr)| (n, m, addr.as_socket().expect("should be SocketAddr")))
}

/// Sends data on the socket to the given address. On success, returns the
Expand Down
12 changes: 8 additions & 4 deletions compio-net/tests/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,18 @@ async fn send_msg_with_ipv6_ecn() {

active.send_msg(MSG, control, passive_addr).await.unwrap();

let res = passive.recv_msg(Vec::with_capacity(20), [0u8; 32]).await;
assert_eq!(res.0.unwrap().1, active_addr);
assert_eq!(res.1.0, MSG.as_bytes());
let ((_, _, addr), (buffer, control)) = passive
.recv_msg(Vec::with_capacity(20), Vec::with_capacity(32))
.await
.unwrap();
assert_eq!(addr, active_addr);
assert_eq!(buffer, MSG.as_bytes());
unsafe {
let mut iter = CMsgIter::new(&res.1.1);
let mut iter = CMsgIter::new(&control);
let cmsg = iter.next().unwrap();
assert_eq!(cmsg.level(), IPPROTO_IPV6);
assert_eq!(cmsg.ty(), IPV6_TCLASS);
assert_eq!(cmsg.data::<i32>(), &ECN_BITS);
assert!(iter.next().is_none());
}
}

0 comments on commit d24ce51

Please sign in to comment.