Skip to content

Commit

Permalink
Merge pull request #275 from AsakuraMizu/master
Browse files Browse the repository at this point in the history
feat: add ancillary data support
  • Loading branch information
George-Miao authored Jul 16, 2024
2 parents 1aeeb52 + f9b286a commit 56024ab
Show file tree
Hide file tree
Showing 13 changed files with 934 additions and 11 deletions.
153 changes: 148 additions & 5 deletions compio-driver/src/iocp/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use std::{
};

use aligned_array::{Aligned, A8};
use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
use compio_buf::{
BufResult, IntoInner, IoBuf, IoBufMut, IoSlice, IoSliceMut, IoVectoredBuf, IoVectoredBufMut,
};
#[cfg(not(feature = "once_cell_try"))]
use once_cell::sync::OnceCell as OnceLock;
use socket2::SockAddr;
Expand All @@ -25,10 +27,11 @@ use windows_sys::{
},
Networking::WinSock::{
closesocket, setsockopt, shutdown, socklen_t, WSAIoctl, WSARecv, WSARecvFrom, WSASend,
WSASendTo, LPFN_ACCEPTEX, LPFN_CONNECTEX, LPFN_GETACCEPTEXSOCKADDRS, SD_BOTH,
SD_RECEIVE, SD_SEND, SIO_GET_EXTENSION_FUNCTION_POINTER, SOCKADDR, SOCKADDR_STORAGE,
SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, WSAID_ACCEPTEX,
WSAID_CONNECTEX, WSAID_GETACCEPTEXSOCKADDRS,
WSASendMsg, WSASendTo, CMSGHDR, LPFN_ACCEPTEX, LPFN_CONNECTEX,
LPFN_GETACCEPTEXSOCKADDRS, LPFN_WSARECVMSG, SD_BOTH, SD_RECEIVE, SD_SEND,
SIO_GET_EXTENSION_FUNCTION_POINTER, SOCKADDR, SOCKADDR_STORAGE, SOL_SOCKET,
SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, WSABUF, WSAID_ACCEPTEX,
WSAID_CONNECTEX, WSAID_GETACCEPTEXSOCKADDRS, WSAID_WSARECVMSG, WSAMSG,
},
Storage::FileSystem::{FlushFileBuffers, ReadFile, WriteFile},
System::{
Expand Down Expand Up @@ -774,6 +777,146 @@ impl<T: IoVectoredBuf, S: AsRawFd> OpCode for SendToVectored<T, S> {
}
}

static WSA_RECVMSG: OnceLock<LPFN_WSARECVMSG> = OnceLock::new();

/// Receive data and source address with ancillary data into vectored buffer.
pub struct RecvMsg<T: IoVectoredBufMut, C: IoBufMut, S> {
addr: SOCKADDR_STORAGE,
addr_len: socklen_t,
fd: SharedFd<S>,
buffer: T,
control: C,
_p: PhantomPinned,
}

impl<T: IoVectoredBufMut, C: IoBufMut, S> RecvMsg<T, C, S> {
/// Create [`RecvMsg`].
///
/// # Panics
///
/// This function will panic if the control message buffer is misaligned.
pub fn new(fd: SharedFd<S>, buffer: T, control: C) -> Self {
assert!(
control.as_buf_ptr().cast::<CMSGHDR>().is_aligned(),
"misaligned control message buffer"
);
Self {
addr: unsafe { std::mem::zeroed() },
addr_len: std::mem::size_of::<SOCKADDR_STORAGE>() as _,
fd,
buffer,
control,
_p: PhantomPinned,
}
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S> IntoInner for RecvMsg<T, C, S> {
type Inner = ((T, C), SOCKADDR_STORAGE, socklen_t);

fn into_inner(self) -> Self::Inner {
((self.buffer, self.control), self.addr, self.addr_len)
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsg<T, C, S> {
unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
let recvmsg_fn = WSA_RECVMSG
.get_or_try_init(|| get_wsa_fn(self.fd.as_raw_fd(), WSAID_WSARECVMSG))?
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Unsupported, "cannot retrieve WSARecvMsg")
})?;

let this = self.get_unchecked_mut();
let mut slices = this.buffer.as_io_slices_mut();
let mut msg = WSAMSG {
name: &mut this.addr as *mut _ as _,
namelen: this.addr_len,
lpBuffers: slices.as_mut_ptr() as _,
dwBufferCount: slices.len() as _,
Control: std::mem::transmute::<IoSliceMut, WSABUF>(this.control.as_io_slice_mut()),
dwFlags: 0,
};

let mut received = 0;
let res = recvmsg_fn(
this.fd.as_raw_fd() as _,
&mut msg,
&mut received,
optr,
None,
);
winsock_result(res, received)
}

unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
cancel(self.fd.as_raw_fd(), optr)
}
}

/// Send data to specified address accompanied by ancillary data from vectored
/// buffer.
pub struct SendMsg<T: IoVectoredBuf, C: IoBuf, S> {
fd: SharedFd<S>,
buffer: T,
control: C,
addr: SockAddr,
_p: PhantomPinned,
}

impl<T: IoVectoredBuf, C: IoBuf, S> SendMsg<T, C, S> {
/// Create [`SendMsg`].
///
/// # Panics
///
/// This function will panic if the control message buffer is misaligned.
pub fn new(fd: SharedFd<S>, buffer: T, control: C, addr: SockAddr) -> Self {
assert!(
control.as_buf_ptr().cast::<CMSGHDR>().is_aligned(),
"misaligned control message buffer"
);
Self {
fd,
buffer,
control,
addr,
_p: PhantomPinned,
}
}
}

impl<T: IoVectoredBuf, C: IoBuf, S> IntoInner for SendMsg<T, C, S> {
type Inner = (T, C);

fn into_inner(self) -> Self::Inner {
(self.buffer, self.control)
}
}

impl<T: IoVectoredBuf, C: IoBuf, S: AsRawFd> OpCode for SendMsg<T, C, S> {
unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
let this = self.get_unchecked_mut();

let slices = this.buffer.as_io_slices();
let msg = WSAMSG {
name: this.addr.as_ptr() as _,
namelen: this.addr.len(),
lpBuffers: slices.as_ptr() as _,
dwBufferCount: slices.len() as _,
Control: std::mem::transmute::<IoSlice, WSABUF>(this.control.as_io_slice()),
dwFlags: 0,
};

let mut sent = 0;
let res = WSASendMsg(this.fd.as_raw_fd() as _, &msg, 0, &mut sent, optr, None);
winsock_result(res, sent)
}

unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
cancel(self.fd.as_raw_fd(), optr)
}
}

/// Connect a named pipe server.
pub struct ConnectNamedPipe<S> {
pub(crate) fd: SharedFd<S>,
Expand Down
20 changes: 20 additions & 0 deletions compio-driver/src/iour/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,26 @@ impl<T: IoVectoredBuf, S> IntoInner for SendToVectored<T, S> {
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsg<T, C, S> {
fn create_entry(self: Pin<&mut Self>) -> OpEntry {
let this = unsafe { self.get_unchecked_mut() };
unsafe { this.set_msg() };
opcode::RecvMsg::new(Fd(this.fd.as_raw_fd()), &mut this.msg)
.build()
.into()
}
}

impl<T: IoVectoredBuf, C: IoBuf, S: AsRawFd> OpCode for SendMsg<T, C, S> {
fn create_entry(self: Pin<&mut Self>) -> OpEntry {
let this = unsafe { self.get_unchecked_mut() };
unsafe { this.set_msg() };
opcode::SendMsg::new(Fd(this.fd.as_raw_fd()), &this.msg)
.build()
.into()
}
}

impl<S: AsRawFd> OpCode for PollOnce<S> {
fn create_entry(self: Pin<&mut Self>) -> OpEntry {
let flags = match self.interest {
Expand Down
4 changes: 2 additions & 2 deletions compio-driver/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use socket2::SockAddr;
#[cfg(windows)]
pub use crate::sys::op::ConnectNamedPipe;
pub use crate::sys::op::{
Accept, Recv, RecvFrom, RecvFromVectored, RecvVectored, Send, SendTo, SendToVectored,
SendVectored,
Accept, Recv, RecvFrom, RecvFromVectored, RecvMsg, RecvVectored, Send, SendMsg, SendTo,
SendToVectored, SendVectored,
};
#[cfg(unix)]
pub use crate::sys::op::{
Expand Down
41 changes: 41 additions & 0 deletions compio-driver/src/poll/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,47 @@ impl<T: IoVectoredBuf, S> IntoInner for SendToVectored<T, S> {
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> RecvMsg<T, C, S> {
unsafe fn call(&mut self) -> libc::ssize_t {
libc::recvmsg(self.fd.as_raw_fd(), &mut self.msg, 0)
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsg<T, C, S> {
fn pre_submit(self: Pin<&mut Self>) -> io::Result<Decision> {
let this = unsafe { self.get_unchecked_mut() };
unsafe { this.set_msg() };
syscall!(this.call(), wait_readable(this.fd.as_raw_fd()))
}

fn on_event(self: Pin<&mut Self>, event: &Event) -> Poll<io::Result<usize>> {
debug_assert!(event.readable);

let this = unsafe { self.get_unchecked_mut() };
syscall!(break this.call())
}
}

impl<T: IoVectoredBuf, C: IoBuf, S: AsRawFd> SendMsg<T, C, S> {
unsafe fn call(&self) -> libc::ssize_t {
libc::sendmsg(self.fd.as_raw_fd(), &self.msg, 0)
}
}

impl<T: IoVectoredBuf, C: IoBuf, S: AsRawFd> OpCode for SendMsg<T, C, S> {
fn pre_submit(self: Pin<&mut Self>) -> io::Result<Decision> {
let this = unsafe { self.get_unchecked_mut() };
unsafe { this.set_msg() };
syscall!(this.call(), wait_writable(this.fd.as_raw_fd()))
}

fn on_event(self: Pin<&mut Self>, event: &Event) -> Poll<io::Result<usize>> {
debug_assert!(event.writable);

syscall!(break self.call())
}
}

impl<S: AsRawFd> OpCode for PollOnce<S> {
fn pre_submit(self: Pin<&mut Self>) -> io::Result<Decision> {
Ok(Decision::wait_for(self.fd.as_raw_fd(), self.interest))
Expand Down
107 changes: 107 additions & 0 deletions compio-driver/src/unix/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,113 @@ impl<T: IoVectoredBuf, S> IntoInner for SendVectored<T, S> {
}
}

/// Receive data and source address with ancillary data into vectored buffer.
pub struct RecvMsg<T: IoVectoredBufMut, C: IoBufMut, S> {
pub(crate) msg: libc::msghdr,
pub(crate) addr: sockaddr_storage,
pub(crate) fd: SharedFd<S>,
pub(crate) buffer: T,
pub(crate) control: C,
pub(crate) slices: Vec<IoSliceMut>,
_p: PhantomPinned,
}

impl<T: IoVectoredBufMut, C: IoBufMut, S> RecvMsg<T, C, S> {
/// Create [`RecvMsg`].
///
/// # Panics
///
/// This function will panic if the control message buffer is misaligned.
pub fn new(fd: SharedFd<S>, buffer: T, control: C) -> Self {
assert!(
control.as_buf_ptr().cast::<libc::cmsghdr>().is_aligned(),
"misaligned control message buffer"
);
Self {
addr: unsafe { std::mem::zeroed() },
msg: unsafe { std::mem::zeroed() },
fd,
buffer,
control,
slices: vec![],
_p: PhantomPinned,
}
}

pub(crate) unsafe fn set_msg(&mut self) {
self.slices = self.buffer.as_io_slices_mut();

self.msg.msg_name = std::ptr::addr_of_mut!(self.addr) as _;
self.msg.msg_namelen = std::mem::size_of_val(&self.addr) as _;
self.msg.msg_iov = self.slices.as_mut_ptr() as _;
self.msg.msg_iovlen = self.slices.len() as _;
self.msg.msg_control = self.control.as_buf_mut_ptr() as _;
self.msg.msg_controllen = self.control.buf_len() as _;
}
}

impl<T: IoVectoredBufMut, C: IoBufMut, S> IntoInner for RecvMsg<T, C, S> {
type Inner = ((T, C), sockaddr_storage, socklen_t);

fn into_inner(self) -> Self::Inner {
((self.buffer, self.control), self.addr, self.msg.msg_namelen)
}
}

/// Send data to specified address accompanied by ancillary data from vectored
/// buffer.
pub struct SendMsg<T: IoVectoredBuf, C: IoBuf, S> {
pub(crate) msg: libc::msghdr,
pub(crate) fd: SharedFd<S>,
pub(crate) buffer: T,
pub(crate) control: C,
pub(crate) addr: SockAddr,
pub(crate) slices: Vec<IoSlice>,
_p: PhantomPinned,
}

impl<T: IoVectoredBuf, C: IoBuf, S> SendMsg<T, C, S> {
/// Create [`SendMsg`].
///
/// # Panics
///
/// This function will panic if the control message buffer is misaligned.
pub fn new(fd: SharedFd<S>, buffer: T, control: C, addr: SockAddr) -> Self {
assert!(
control.as_buf_ptr().cast::<libc::cmsghdr>().is_aligned(),
"misaligned control message buffer"
);
Self {
msg: unsafe { std::mem::zeroed() },
fd,
buffer,
control,
addr,
slices: vec![],
_p: PhantomPinned,
}
}

pub(crate) unsafe fn set_msg(&mut self) {
self.slices = self.buffer.as_io_slices();

self.msg.msg_name = self.addr.as_ptr() as _;
self.msg.msg_namelen = self.addr.len();
self.msg.msg_iov = self.slices.as_ptr() as _;
self.msg.msg_iovlen = self.slices.len() as _;
self.msg.msg_control = self.control.as_buf_ptr() as _;
self.msg.msg_controllen = self.control.buf_len() as _;
}
}

impl<T: IoVectoredBuf, C: IoBuf, S> IntoInner for SendMsg<T, C, S> {
type Inner = (T, C);

fn into_inner(self) -> Self::Inner {
(self.buffer, self.control)
}
}

/// The interest to poll a file descriptor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Interest {
Expand Down
Loading

0 comments on commit 56024ab

Please sign in to comment.