diff --git a/src/io/sys/windows/miow.rs b/src/io/sys/windows/miow.rs index 8cac75b6..329f06c1 100644 --- a/src/io/sys/windows/miow.rs +++ b/src/io/sys/windows/miow.rs @@ -1,6 +1,8 @@ //! ported from miow crate which is not maintained anymore -use std::net::SocketAddr; +use std::net::{ + Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, +}; use std::os::windows::io::{AsRawHandle, AsRawSocket}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; @@ -413,6 +415,66 @@ fn socket_addr_to_ptrs(addr: &SocketAddr) -> (SocketAddrCRepr, i32) { } } +#[doc(hidden)] +trait NetInt { + fn from_be(i: Self) -> Self; + #[allow(dead_code)] + fn to_be(&self) -> Self; +} +macro_rules! doit { + ($($t:ident)*) => ($(impl NetInt for $t { + fn from_be(i: Self) -> Self { <$t>::from_be(i) } + fn to_be(&self) -> Self { <$t>::to_be(*self) } + })*) +} +doit! { i8 i16 i32 i64 isize u8 u16 u32 u64 usize } + +// fn hton(i: I) -> I { i.to_be() } +fn ntoh(i: I) -> I { + I::from_be(i) +} + +unsafe fn ptrs_to_socket_addr(ptr: *const SOCKADDR, len: i32) -> Option { + if (len as usize) < std::mem::size_of::() { + return None; + } + match (*ptr).sa_family as _ { + AF_INET if len as usize >= std::mem::size_of::() => { + let b = &*(ptr as *const SOCKADDR_IN); + let ip = ntoh(b.sin_addr.S_un.S_addr); + let ip = Ipv4Addr::new( + (ip >> 24) as u8, + (ip >> 16) as u8, + (ip >> 8) as u8, + ip as u8, + ); + Some(SocketAddr::V4(SocketAddrV4::new(ip, ntoh(b.sin_port)))) + } + AF_INET6 if len as usize >= std::mem::size_of::() => { + let b = &*(ptr as *const SOCKADDR_IN6); + let arr = &b.sin6_addr.u.Byte; + let ip = Ipv6Addr::new( + ((arr[0] as u16) << 8) | (arr[1] as u16), + ((arr[2] as u16) << 8) | (arr[3] as u16), + ((arr[4] as u16) << 8) | (arr[5] as u16), + ((arr[6] as u16) << 8) | (arr[7] as u16), + ((arr[8] as u16) << 8) | (arr[9] as u16), + ((arr[10] as u16) << 8) | (arr[11] as u16), + ((arr[12] as u16) << 8) | (arr[13] as u16), + ((arr[14] as u16) << 8) | (arr[15] as u16), + ); + let addr = SocketAddrV6::new( + ip, + ntoh(b.sin6_port), + ntoh(b.sin6_flowinfo), + ntoh(b.Anonymous.sin6_scope_id), + ); + Some(SocketAddr::V6(addr)) + } + _ => None, + } +} + pub unsafe fn socket_read( socket: RawSocket, buf: &mut [u8], @@ -526,3 +588,162 @@ pub fn connect_complete(socket: RawSocket) -> io::Result<()> { Err(io::Error::last_os_error()) } } + +static GETACCEPTEXSOCKADDRS: WsaExtension = WsaExtension { + guid: GUID { + data1: 0xb5367df2, + data2: 0xcbac, + data3: 0x11cf, + data4: [0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92], + }, + val: AtomicUsize::new(0), +}; + +/// A type to represent a buffer in which an accepted socket's address will be +/// stored. +/// +/// This type is used with the `accept_overlapped` method on the +/// `TcpListenerExt` trait to provide space for the overlapped I/O operation to +/// fill in the socket addresses upon completion. +#[repr(C)] +pub struct AcceptAddrsBuf { + // For AcceptEx we've got the restriction that the addresses passed in that + // buffer need to be at least 16 bytes more than the maximum address length + // for the protocol in question, so add some extra here and there + local: SOCKADDR_STORAGE, + _pad1: [u8; 16], + remote: SOCKADDR_STORAGE, + _pad2: [u8; 16], +} + +impl AcceptAddrsBuf { + /// Creates a new blank buffer ready to be passed to a call to + /// `accept_overlapped`. + pub fn new() -> AcceptAddrsBuf { + unsafe { std::mem::zeroed() } + } + + /// Parses the data contained in this address buffer, returning the parsed + /// result if successful. + /// + /// This function can be called after a call to `accept_overlapped` has + /// succeeded to parse out the data that was written in. + pub fn parse(&self, socket: &TcpListener) -> io::Result { + let mut ret = AcceptAddrs { + local: 0 as *mut _, + local_len: 0, + remote: 0 as *mut _, + remote_len: 0, + _data: self, + }; + let ptr = GETACCEPTEXSOCKADDRS.get(socket.as_raw_socket() as SOCKET)?; + assert!(ptr != 0); + unsafe { + let get_sockaddrs = std::mem::transmute::<_, LPFN_GETACCEPTEXSOCKADDRS>(ptr).unwrap(); + let (a, b, c, d) = self.args(); + get_sockaddrs( + a, + b, + c, + d, + &mut ret.local, + &mut ret.local_len, + &mut ret.remote, + &mut ret.remote_len, + ); + Ok(ret) + } + } + + #[allow(deref_nullptr)] + fn args(&self) -> (*mut std::ffi::c_void, u32, u32, u32) { + let remote_offset = unsafe { &(*(0 as *const AcceptAddrsBuf)).remote as *const _ as usize }; + ( + self as *const _ as *mut _, + 0, + remote_offset as u32, + (std::mem::size_of_val(self) - remote_offset) as u32, + ) + } +} + +/// The parsed return value of `AcceptAddrsBuf`. +pub struct AcceptAddrs<'a> { + local: *mut SOCKADDR, + local_len: i32, + remote: *mut SOCKADDR, + remote_len: i32, + _data: &'a AcceptAddrsBuf, +} + +impl<'a> AcceptAddrs<'a> { + /// Returns the local socket address contained in this buffer. + #[allow(dead_code)] + pub fn local(&self) -> Option { + unsafe { ptrs_to_socket_addr(self.local, self.local_len) } + } + + /// Returns the remote socket address contained in this buffer. + pub fn remote(&self) -> Option { + unsafe { ptrs_to_socket_addr(self.remote, self.remote_len) } + } +} + +pub unsafe fn accept_overlapped( + me: RawSocket, + socket: &TcpStream, + addrs: &mut AcceptAddrsBuf, + overlapped: *mut OVERLAPPED, +) -> io::Result { + static ACCEPTEX: WsaExtension = WsaExtension { + guid: GUID { + data1: 0xb5367df1, + data2: 0xcbac, + data3: 0x11cf, + data4: [0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92], + }, + val: AtomicUsize::new(0), + }; + + let ptr = ACCEPTEX.get(me as SOCKET)?; + assert!(ptr != 0); + let accept_ex = std::mem::transmute::<_, LPFN_ACCEPTEX>(ptr).unwrap(); + + let mut bytes = 0; + let (a, b, c, d) = (*addrs).args(); + let r = accept_ex( + me as SOCKET, + socket.as_raw_socket() as SOCKET, + a, + b, + c, + d, + &mut bytes, + overlapped, + ); + let succeeded = if r == TRUE { + true + } else { + last_err()?; + false + }; + Ok(succeeded) +} + +pub fn accept_complete(me: RawSocket, socket: &TcpStream) -> io::Result<()> { + const SO_UPDATE_ACCEPT_CONTEXT: i32 = 0x700B; + let result = unsafe { + setsockopt( + socket.as_raw_socket() as SOCKET, + SOL_SOCKET as _, + SO_UPDATE_ACCEPT_CONTEXT, + &me as *const _ as *mut _, + std::mem::size_of_val(&me) as i32, + ) + }; + if result == 0 { + Ok(()) + } else { + Err(io::Error::last_os_error()) + } +} diff --git a/src/io/sys/windows/net/tcp_listener_accept.rs b/src/io/sys/windows/net/tcp_listener_accept.rs index a8623abf..e79cf9ba 100644 --- a/src/io/sys/windows/net/tcp_listener_accept.rs +++ b/src/io/sys/windows/net/tcp_listener_accept.rs @@ -2,6 +2,7 @@ use std::io; use std::net::SocketAddr; use std::os::windows::io::AsRawSocket; +use super::super::miow::{accept_complete, accept_overlapped, AcceptAddrsBuf}; use super::super::{add_socket, co_io_result, EventData}; #[cfg(feature = "io_cancel")] use crate::coroutine_impl::co_cancel_data; @@ -12,7 +13,6 @@ use crate::io::OptionCell; use crate::net::{TcpListener, TcpStream}; use crate::scheduler::get_scheduler; use crate::sync::delay_drop::DelayDrop; -use miow::net::{AcceptAddrsBuf, TcpListenerExt}; use windows_sys::Win32::Foundation::*; pub struct TcpListenerAccept<'a> { @@ -49,7 +49,7 @@ impl<'a> TcpListenerAccept<'a> { co_io_result(&self.io_data, self.is_coroutine)?; let socket = &self.socket; let ss = self.ret.take(); - let s = socket.accept_complete(&ss).and_then(|_| { + let s = accept_complete(socket.as_raw_socket(), &ss).and_then(|_| { ss.set_nonblocking(true)?; add_socket(&ss).map(|io| TcpStream::from_stream(ss, io)) })?; @@ -76,8 +76,12 @@ impl<'a> EventSource for TcpListenerAccept<'a> { // call the overlapped read API co_try!(s, self.io_data.co.take().expect("can't get co"), unsafe { - self.socket - .accept_overlapped(&self.ret, &mut self.addr, self.io_data.get_overlapped()) + accept_overlapped( + self.socket.as_raw_socket(), + &self.ret, + &mut self.addr, + self.io_data.get_overlapped(), + ) }); #[cfg(feature = "io_cancel")]