Skip to content

Commit

Permalink
📝 peek implementation for unix (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xudong-Huang committed Feb 25, 2024
1 parent 5e559f5 commit dd8ff74
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ socket2 = { version = "0.5", features = ["all"] }
may_queue = { version = "0.1", path = "may_queue" }

[target.'cfg(unix)'.dependencies]
nix = { version = "0.28", features = ["event"] }
nix = { version = "0.28", features = ["event", "socket"] }
libc = "0.2"

[target.'cfg(windows)'.dependencies]
Expand Down
31 changes: 31 additions & 0 deletions src/io/sys/unix/co_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@ use std::time::Duration;

use self::io_impl::co_io_err::Error;
use self::io_impl::net as net_impl;
use super::from_nix_error;
use crate::io as io_impl;
#[cfg(feature = "io_timeout")]
use crate::sync::atomic_dur::AtomicDuration;
use crate::yield_now::yield_with_io;

use nix::sys::socket::{recv, MsgFlags};

fn set_nonblocking<T: AsRawFd>(fd: &T, nb: bool) -> io::Result<()> {
unsafe {
let fd = fd.as_raw_fd();
Expand Down Expand Up @@ -148,6 +151,34 @@ impl<T: AsRawFd> CoIo<T> {
self.write_timeout.store(dur);
Ok(())
}

/// Receives data on the socket from the remote address to which it is
/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.io.reset();
// this is an earlier return try for nonblocking read
// it's useful for server but not necessary for client
match recv(self.io.fd, buf, MsgFlags::MSG_PEEK) {
Ok(n) => return Ok(n),
Err(e) => {
if e == nix::errno::Errno::EAGAIN {
// do nothing
} else {
return Err(from_nix_error(e));
}
}
}

let mut reader = net_impl::SocketPeek::new(
self,
buf,
#[cfg(feature = "io_timeout")]
self.read_timeout.get(),
);
yield_with_io(&reader, reader.is_coroutine);
reader.done()
}
}

impl<T: AsRawFd + Read> Read for CoIo<T> {
Expand Down
2 changes: 2 additions & 0 deletions src/io/sys/unix/net/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod socket_peek;
mod socket_read;
mod socket_write;
mod socket_write_vectored;
Expand All @@ -10,6 +11,7 @@ mod unix_recv_from;
mod unix_send_to;
mod unix_stream_connect;

pub use self::socket_peek::SocketPeek;
pub use self::socket_read::SocketRead;
pub use self::socket_write::SocketWrite;
pub use self::socket_write_vectored::SocketWriteVectored;
Expand Down
102 changes: 102 additions & 0 deletions src/io/sys/unix/net/socket_peek.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use std::io;
use std::sync::atomic::Ordering;
#[cfg(feature = "io_timeout")]
use std::time::Duration;

use super::super::{co_io_result, from_nix_error, IoData};
#[cfg(feature = "io_cancel")]
use crate::coroutine_impl::co_cancel_data;
use crate::coroutine_impl::{is_coroutine, CoroutineImpl, EventSource};
use crate::io::AsIoData;
use crate::yield_now::yield_with_io;

use nix::sys::socket::{recv, MsgFlags};

pub struct SocketPeek<'a> {
io_data: &'a IoData,
buf: &'a mut [u8],
#[cfg(feature = "io_timeout")]
timeout: Option<Duration>,
pub(crate) is_coroutine: bool,
}

impl<'a> SocketPeek<'a> {
pub fn new<T: AsIoData>(
s: &'a T,
buf: &'a mut [u8],
#[cfg(feature = "io_timeout")] timeout: Option<Duration>,
) -> Self {
SocketPeek {
io_data: s.as_io_data(),
buf,
#[cfg(feature = "io_timeout")]
timeout,
is_coroutine: is_coroutine(),
}
}

pub fn done(&mut self) -> io::Result<usize> {
loop {
co_io_result(self.is_coroutine)?;

// clear the io_flag
self.io_data.io_flag.store(false, Ordering::Relaxed);

// finish the read operation
match recv(self.io_data.fd, self.buf, MsgFlags::MSG_PEEK) {
Ok(n) => return Ok(n),
Err(e) => {
if e == nix::errno::Errno::EAGAIN {
// do nothing
} else {
return Err(from_nix_error(e));
}
}
}

if self.io_data.io_flag.load(Ordering::Relaxed) {
continue;
}

// the result is still WouldBlock, need to try again
yield_with_io(self, self.is_coroutine);
}
}
}

impl<'a> EventSource for SocketPeek<'a> {
fn subscribe(&mut self, co: CoroutineImpl) {
#[cfg(feature = "io_cancel")]
let cancel = co_cancel_data(&co);
let io_data = self.io_data;

#[cfg(feature = "io_timeout")]
if let Some(dur) = self.timeout {
crate::scheduler::get_scheduler()
.get_selector()
.add_io_timer(self.io_data, dur);
}

// after register the coroutine, it's possible that other thread run it immediately
// and cause the process after it invalid, this is kind of user and kernel competition
// so we need to delay the drop of the EventSource, that's why _g is here
unsafe { io_data.co.unsync_store(co) };
// till here the io may be done in other thread

// there is event, re-run the coroutine
if io_data.io_flag.load(Ordering::Acquire) {
#[allow(clippy::needless_return)]
return io_data.fast_schedule();
}

#[cfg(feature = "io_cancel")]
{
// register the cancel io data
cancel.set_io((*io_data).clone());
// re-check the cancel status
if cancel.is_canceled() {
unsafe { cancel.cancel() };
}
}
}
}
33 changes: 33 additions & 0 deletions src/net/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,39 @@ impl TcpStream {
write_timeout: AtomicDuration::new(None),
}
}

/// Receives data on the socket from the remote address to which it is
/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
#[cfg(unix)]
{
self._io.reset();
// this is an earlier return try for nonblocking read
// it's useful for server but not necessary for client
match self.sys.peek(buf) {
Ok(n) => return Ok(n),
Err(e) => {
// raw_os_error is faster than kind
let raw_err = e.raw_os_error();
if raw_err == Some(libc::EAGAIN) || raw_err == Some(libc::EWOULDBLOCK) {
// do nothing here
} else {
return Err(e);
}
}
}
}

let mut reader = net_impl::SocketPeek::new(
self,
buf,
#[cfg(feature = "io_timeout")]
self.read_timeout.get(),
);
yield_with_io(&reader, reader.is_coroutine);
reader.done()
}
}

impl Read for TcpStream {
Expand Down
49 changes: 49 additions & 0 deletions src/os/unix/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,13 @@ impl UnixStream {
self.0.write_timeout()
}

/// Receives data on the socket from the remote address to which it is
/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.0.peek(buf)
}

/// Returns the value of the `SO_ERROR` option.
///
/// # Examples
Expand Down Expand Up @@ -856,6 +863,13 @@ impl UnixDatagram {
reader.done()
}

/// Receives data on the socket from the remote address to which it is
/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.0.peek(buf)
}

/// Sends data on the socket to the specified address.
///
/// On success, returns the number of bytes written.
Expand Down Expand Up @@ -1394,4 +1408,39 @@ mod test {
fn abstract_namespace_not_allowed() {
assert!(UnixStream::connect("\0asdf").is_err());
}

#[test]
fn socket_peek() {
let msg1 = b"hello";
let msg2 = b"world!";

let (s1, s2) = or_panic!(UnixDatagram::pair());
or_panic!(s2.send(msg1));

let mut buf = [0; 5];
or_panic!(s1.peek(&mut buf));
assert_eq!(&msg1[..], &buf[..]);
buf.copy_from_slice(&[0; 5]);
or_panic!(s1.peek(&mut buf));
assert_eq!(&msg1[..], &buf[..]);
buf.copy_from_slice(&[0; 5]);
or_panic!(s1.peek(&mut buf));
assert_eq!(&msg1[..], &buf[..]);

or_panic!(s2.send(msg2));
let mut buf = [0; 11];
let n = s1.peek(&mut buf).unwrap();
assert_eq!(n, 5);
assert_eq!(&buf[0..n], &msg1[..]);
let n = s1.recv(&mut buf).unwrap();
assert_eq!(n, 5);
let n = s1.peek(&mut buf).unwrap();
assert_eq!(n, 6);
assert_eq!(&buf[0..n], &msg2[..]);
let n = s1.recv(&mut buf).unwrap();
assert_eq!(n, 6);
// // this would block until there is some data
// let n = s1.peek(&mut buf).unwrap();
// assert_eq!(n, 0);
}
}

0 comments on commit dd8ff74

Please sign in to comment.