Skip to content

Commit

Permalink
📝 port accept_overlapped from miow
Browse files Browse the repository at this point in the history
  • Loading branch information
Xudong-Huang committed Mar 4, 2024
1 parent 9c39e5f commit e93fc80
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 5 deletions.
223 changes: 222 additions & 1 deletion src/io/sys/windows/miow.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! ported from miow crate which is not maintained anymore

use std::net::SocketAddr;
use std::net::{
Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream,
};
use std::os::windows::io::{AsRawHandle, AsRawSocket};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
Expand Down Expand Up @@ -413,6 +415,66 @@ fn socket_addr_to_ptrs(addr: &SocketAddr) -> (SocketAddrCRepr, i32) {
}
}

#[doc(hidden)]
trait NetInt {
fn from_be(i: Self) -> Self;
#[allow(dead_code)]
fn to_be(&self) -> Self;
}
macro_rules! doit {
($($t:ident)*) => ($(impl NetInt for $t {
fn from_be(i: Self) -> Self { <$t>::from_be(i) }
fn to_be(&self) -> Self { <$t>::to_be(*self) }
})*)
}
doit! { i8 i16 i32 i64 isize u8 u16 u32 u64 usize }

// fn hton<I: NetInt>(i: I) -> I { i.to_be() }
fn ntoh<I: NetInt>(i: I) -> I {
I::from_be(i)
}

unsafe fn ptrs_to_socket_addr(ptr: *const SOCKADDR, len: i32) -> Option<SocketAddr> {
if (len as usize) < std::mem::size_of::<i32>() {
return None;
}
match (*ptr).sa_family as _ {
AF_INET if len as usize >= std::mem::size_of::<SOCKADDR_IN>() => {
let b = &*(ptr as *const SOCKADDR_IN);
let ip = ntoh(b.sin_addr.S_un.S_addr);
let ip = Ipv4Addr::new(
(ip >> 24) as u8,
(ip >> 16) as u8,
(ip >> 8) as u8,
ip as u8,
);
Some(SocketAddr::V4(SocketAddrV4::new(ip, ntoh(b.sin_port))))
}
AF_INET6 if len as usize >= std::mem::size_of::<SOCKADDR_IN6>() => {
let b = &*(ptr as *const SOCKADDR_IN6);
let arr = &b.sin6_addr.u.Byte;
let ip = Ipv6Addr::new(
((arr[0] as u16) << 8) | (arr[1] as u16),
((arr[2] as u16) << 8) | (arr[3] as u16),
((arr[4] as u16) << 8) | (arr[5] as u16),
((arr[6] as u16) << 8) | (arr[7] as u16),
((arr[8] as u16) << 8) | (arr[9] as u16),
((arr[10] as u16) << 8) | (arr[11] as u16),
((arr[12] as u16) << 8) | (arr[13] as u16),
((arr[14] as u16) << 8) | (arr[15] as u16),
);
let addr = SocketAddrV6::new(
ip,
ntoh(b.sin6_port),
ntoh(b.sin6_flowinfo),
ntoh(b.Anonymous.sin6_scope_id),
);
Some(SocketAddr::V6(addr))
}
_ => None,
}
}

pub unsafe fn socket_read(
socket: RawSocket,
buf: &mut [u8],
Expand Down Expand Up @@ -526,3 +588,162 @@ pub fn connect_complete(socket: RawSocket) -> io::Result<()> {
Err(io::Error::last_os_error())
}
}

static GETACCEPTEXSOCKADDRS: WsaExtension = WsaExtension {
guid: GUID {
data1: 0xb5367df2,
data2: 0xcbac,
data3: 0x11cf,
data4: [0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92],
},
val: AtomicUsize::new(0),
};

/// A type to represent a buffer in which an accepted socket's address will be
/// stored.
///
/// This type is used with the `accept_overlapped` method on the
/// `TcpListenerExt` trait to provide space for the overlapped I/O operation to
/// fill in the socket addresses upon completion.
#[repr(C)]
pub struct AcceptAddrsBuf {
// For AcceptEx we've got the restriction that the addresses passed in that
// buffer need to be at least 16 bytes more than the maximum address length
// for the protocol in question, so add some extra here and there
local: SOCKADDR_STORAGE,
_pad1: [u8; 16],
remote: SOCKADDR_STORAGE,
_pad2: [u8; 16],
}

impl AcceptAddrsBuf {
/// Creates a new blank buffer ready to be passed to a call to
/// `accept_overlapped`.
pub fn new() -> AcceptAddrsBuf {
unsafe { std::mem::zeroed() }
}

/// Parses the data contained in this address buffer, returning the parsed
/// result if successful.
///
/// This function can be called after a call to `accept_overlapped` has
/// succeeded to parse out the data that was written in.
pub fn parse(&self, socket: &TcpListener) -> io::Result<AcceptAddrs> {
let mut ret = AcceptAddrs {
local: 0 as *mut _,
local_len: 0,
remote: 0 as *mut _,
remote_len: 0,
_data: self,
};
let ptr = GETACCEPTEXSOCKADDRS.get(socket.as_raw_socket() as SOCKET)?;
assert!(ptr != 0);
unsafe {
let get_sockaddrs = std::mem::transmute::<_, LPFN_GETACCEPTEXSOCKADDRS>(ptr).unwrap();
let (a, b, c, d) = self.args();
get_sockaddrs(
a,
b,
c,
d,
&mut ret.local,
&mut ret.local_len,
&mut ret.remote,
&mut ret.remote_len,
);
Ok(ret)
}
}

#[allow(deref_nullptr)]
fn args(&self) -> (*mut std::ffi::c_void, u32, u32, u32) {
let remote_offset = unsafe { &(*(0 as *const AcceptAddrsBuf)).remote as *const _ as usize };
(
self as *const _ as *mut _,
0,
remote_offset as u32,
(std::mem::size_of_val(self) - remote_offset) as u32,
)
}
}

/// The parsed return value of `AcceptAddrsBuf`.
pub struct AcceptAddrs<'a> {
local: *mut SOCKADDR,
local_len: i32,
remote: *mut SOCKADDR,
remote_len: i32,
_data: &'a AcceptAddrsBuf,
}

impl<'a> AcceptAddrs<'a> {
/// Returns the local socket address contained in this buffer.
#[allow(dead_code)]
pub fn local(&self) -> Option<SocketAddr> {
unsafe { ptrs_to_socket_addr(self.local, self.local_len) }
}

/// Returns the remote socket address contained in this buffer.
pub fn remote(&self) -> Option<SocketAddr> {
unsafe { ptrs_to_socket_addr(self.remote, self.remote_len) }
}
}

pub unsafe fn accept_overlapped(
me: RawSocket,
socket: &TcpStream,
addrs: &mut AcceptAddrsBuf,
overlapped: *mut OVERLAPPED,
) -> io::Result<bool> {
static ACCEPTEX: WsaExtension = WsaExtension {
guid: GUID {
data1: 0xb5367df1,
data2: 0xcbac,
data3: 0x11cf,
data4: [0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92],
},
val: AtomicUsize::new(0),
};

let ptr = ACCEPTEX.get(me as SOCKET)?;
assert!(ptr != 0);
let accept_ex = std::mem::transmute::<_, LPFN_ACCEPTEX>(ptr).unwrap();

let mut bytes = 0;
let (a, b, c, d) = (*addrs).args();
let r = accept_ex(
me as SOCKET,
socket.as_raw_socket() as SOCKET,
a,
b,
c,
d,
&mut bytes,
overlapped,
);
let succeeded = if r == TRUE {
true
} else {
last_err()?;
false
};
Ok(succeeded)
}

pub fn accept_complete(me: RawSocket, socket: &TcpStream) -> io::Result<()> {
const SO_UPDATE_ACCEPT_CONTEXT: i32 = 0x700B;
let result = unsafe {
setsockopt(
socket.as_raw_socket() as SOCKET,
SOL_SOCKET as _,
SO_UPDATE_ACCEPT_CONTEXT,
&me as *const _ as *mut _,
std::mem::size_of_val(&me) as i32,
)
};
if result == 0 {
Ok(())
} else {
Err(io::Error::last_os_error())
}
}
12 changes: 8 additions & 4 deletions src/io/sys/windows/net/tcp_listener_accept.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::io;
use std::net::SocketAddr;
use std::os::windows::io::AsRawSocket;

use super::super::miow::{accept_complete, accept_overlapped, AcceptAddrsBuf};
use super::super::{add_socket, co_io_result, EventData};
#[cfg(feature = "io_cancel")]
use crate::coroutine_impl::co_cancel_data;
Expand All @@ -12,7 +13,6 @@ use crate::io::OptionCell;
use crate::net::{TcpListener, TcpStream};
use crate::scheduler::get_scheduler;
use crate::sync::delay_drop::DelayDrop;
use miow::net::{AcceptAddrsBuf, TcpListenerExt};
use windows_sys::Win32::Foundation::*;

pub struct TcpListenerAccept<'a> {
Expand Down Expand Up @@ -49,7 +49,7 @@ impl<'a> TcpListenerAccept<'a> {
co_io_result(&self.io_data, self.is_coroutine)?;
let socket = &self.socket;
let ss = self.ret.take();
let s = socket.accept_complete(&ss).and_then(|_| {
let s = accept_complete(socket.as_raw_socket(), &ss).and_then(|_| {
ss.set_nonblocking(true)?;
add_socket(&ss).map(|io| TcpStream::from_stream(ss, io))
})?;
Expand All @@ -76,8 +76,12 @@ impl<'a> EventSource for TcpListenerAccept<'a> {

// call the overlapped read API
co_try!(s, self.io_data.co.take().expect("can't get co"), unsafe {
self.socket
.accept_overlapped(&self.ret, &mut self.addr, self.io_data.get_overlapped())
accept_overlapped(
self.socket.as_raw_socket(),
&self.ret,
&mut self.addr,
self.io_data.get_overlapped(),
)
});

#[cfg(feature = "io_cancel")]
Expand Down

0 comments on commit e93fc80

Please sign in to comment.