Skip to content

Commit

Permalink
📝 port udp overlapped api from miow
Browse files Browse the repository at this point in the history
  • Loading branch information
Xudong-Huang committed Mar 4, 2024
1 parent e93fc80 commit d3298c4
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 13 deletions.
100 changes: 92 additions & 8 deletions src/io/sys/windows/miow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use windows_sys::Win32::Networking::WinSock::*;
use windows_sys::Win32::System::Threading::INFINITE;
use windows_sys::Win32::System::IO::*;

#[allow(clippy::upper_case_acronyms)]
type BOOL = i32;
const TRUE: BOOL = 1;
const FALSE: BOOL = 0;
Expand Down Expand Up @@ -287,7 +288,7 @@ impl WsaExtension {
if prev != 0 && !cfg!(debug_assertions) {
return Ok(prev);
}
let mut ret = 0 as usize;
let mut ret = 0;
let mut bytes = 0;

// https://github.com/microsoft/win32metadata/issues/671
Expand All @@ -302,7 +303,7 @@ impl WsaExtension {
&mut ret as *mut _ as *mut _,
std::mem::size_of_val(&ret) as u32,
&mut bytes,
0 as *mut _,
std::ptr::null_mut(),
None,
)
};
Expand Down Expand Up @@ -501,7 +502,7 @@ pub unsafe fn socket_write(
buf: &[u8],
overlapped: *mut OVERLAPPED,
) -> io::Result<Option<usize>> {
let mut buf = slice2buf(buf);
let buf = slice2buf(buf);
let mut bytes_written = 0;
// Note here that we capture the number of bytes written. The
// documentation on MSDN, however, states:
Expand All @@ -521,7 +522,7 @@ pub unsafe fn socket_write(
// [1]: https://github.com/carllerche/mio/pull/520#issuecomment-273983823
let r = WSASend(
socket as SOCKET,
&mut buf,
&buf,
1,
&mut bytes_written,
0,
Expand Down Expand Up @@ -578,7 +579,7 @@ pub fn connect_complete(socket: RawSocket) -> io::Result<()> {
socket as SOCKET,
SOL_SOCKET as _,
SO_UPDATE_CONNECT_CONTEXT,
0 as *mut _,
std::ptr::null_mut(),
0,
)
};
Expand Down Expand Up @@ -630,9 +631,9 @@ impl AcceptAddrsBuf {
/// succeeded to parse out the data that was written in.
pub fn parse(&self, socket: &TcpListener) -> io::Result<AcceptAddrs> {
let mut ret = AcceptAddrs {
local: 0 as *mut _,
local: std::ptr::null_mut(),
local_len: 0,
remote: 0 as *mut _,
remote: std::ptr::null_mut(),
remote_len: 0,
_data: self,
};
Expand All @@ -657,7 +658,8 @@ impl AcceptAddrsBuf {

#[allow(deref_nullptr)]
fn args(&self) -> (*mut std::ffi::c_void, u32, u32, u32) {
let remote_offset = unsafe { &(*(0 as *const AcceptAddrsBuf)).remote as *const _ as usize };
let remote_offset =
unsafe { &(*(std::ptr::null::<AcceptAddrsBuf>())).remote as *const _ as usize };
(
self as *const _ as *mut _,
0,
Expand Down Expand Up @@ -747,3 +749,85 @@ pub fn accept_complete(me: RawSocket, socket: &TcpStream) -> io::Result<()> {
Err(io::Error::last_os_error())
}
}

/// A type to represent a buffer in which a socket address will be stored.
///
/// This type is used with the `recv_from_overlapped` function on the
/// `UdpSocketExt` trait to provide space for the overlapped I/O operation to
/// fill in the address upon completion.
#[derive(Clone, Copy)]
pub struct SocketAddrBuf {
buf: SOCKADDR_STORAGE,
len: i32,
}

impl SocketAddrBuf {
/// Creates a new blank socket address buffer.
///
/// This should be used before a call to `recv_from_overlapped` overlapped
/// to create an instance to pass down.
pub fn new() -> SocketAddrBuf {
SocketAddrBuf {
buf: unsafe { std::mem::zeroed() },
len: std::mem::size_of::<SOCKADDR_STORAGE>() as i32,
}
}

/// Parses this buffer to return a standard socket address.
///
/// This function should be called after the buffer has been filled in with
/// a call to `recv_from_overlapped` being completed. It will interpret the
/// address filled in and return the standard socket address type.
///
/// If an error is encountered then `None` is returned.
#[allow(clippy::wrong_self_convention)]
pub fn to_socket_addr(&self) -> Option<SocketAddr> {
unsafe { ptrs_to_socket_addr(&self.buf as *const _ as *const _, self.len) }
}
}

pub unsafe fn recv_from_overlapped(
socket: RawSocket,
buf: &mut [u8],
addr: *mut SocketAddrBuf,
overlapped: *mut OVERLAPPED,
) -> io::Result<Option<usize>> {
let buf = slice2buf(buf);
let mut flags = 0;
let mut received_bytes: u32 = 0;
let r = WSARecvFrom(
socket as SOCKET,
&buf,
1,
&mut received_bytes,
&mut flags,
&mut (*addr).buf as *mut _ as *mut _,
&mut (*addr).len,
overlapped,
None,
);
cvt(r, received_bytes)
}

pub unsafe fn send_to_overlapped(
socket: RawSocket,
buf: &[u8],
addr: &SocketAddr,
overlapped: *mut OVERLAPPED,
) -> io::Result<Option<usize>> {
let (addr_buf, addr_len) = socket_addr_to_ptrs(addr);
let buf = slice2buf(buf);
let mut sent_bytes = 0;
let r = WSASendTo(
socket as SOCKET,
&buf,
1,
&mut sent_bytes,
0,
addr_buf.as_ptr() as *const _,
addr_len,
overlapped,
None,
);
cvt(r, sent_bytes)
}
5 changes: 3 additions & 2 deletions src/io/sys/windows/net/udp_recv_from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::os::windows::io::AsRawSocket;
#[cfg(feature = "io_timeout")]
use std::time::Duration;

use super::super::miow::{recv_from_overlapped, SocketAddrBuf};
use super::super::{co_io_result, EventData};
#[cfg(feature = "io_cancel")]
use crate::coroutine_impl::co_cancel_data;
Expand All @@ -13,7 +14,6 @@ use crate::io::cancel::CancelIoData;
use crate::net::UdpSocket;
use crate::scheduler::get_scheduler;
use crate::sync::delay_drop::DelayDrop;
use miow::net::{SocketAddrBuf, UdpSocketExt};
use windows_sys::Win32::Foundation::*;

pub struct UdpRecvFrom<'a> {
Expand Down Expand Up @@ -64,7 +64,8 @@ impl<'a> EventSource for UdpRecvFrom<'a> {
self.io_data.co = Some(co);
// call the overlapped read API
co_try!(s, self.io_data.co.take().expect("can't get co"), unsafe {
self.socket.recv_from_overlapped(
recv_from_overlapped(
self.socket.as_raw_socket(),
self.buf,
&mut self.addr,
self.io_data.get_overlapped(),
Expand Down
10 changes: 7 additions & 3 deletions src/io/sys/windows/net/udp_send_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ use std::os::windows::io::AsRawSocket;
#[cfg(feature = "io_timeout")]
use std::time::Duration;

use super::super::miow::send_to_overlapped;
use super::super::{co_io_result, EventData};
use crate::coroutine_impl::{is_coroutine, CoroutineImpl, EventSource};
use crate::net::UdpSocket;
use crate::scheduler::get_scheduler;
use miow::net::UdpSocketExt;
use windows_sys::Win32::Foundation::*;

pub struct UdpSendTo<'a> {
Expand Down Expand Up @@ -58,8 +58,12 @@ impl<'a> EventSource for UdpSendTo<'a> {
self.io_data.co = Some(co);
// call the overlapped read API
co_try!(s, self.io_data.co.take().expect("can't get co"), unsafe {
self.socket
.send_to_overlapped(self.buf, &self.addr, self.io_data.get_overlapped())
send_to_overlapped(
self.socket.as_raw_socket(),
self.buf,
&self.addr,
self.io_data.get_overlapped(),
)
});
}
}

0 comments on commit d3298c4

Please sign in to comment.