Skip to content

Commit

Permalink
feat: add flags related methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Sherlock-Holo committed Jul 1, 2024
1 parent d75b76e commit 9f894ab
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 39 deletions.
53 changes: 42 additions & 11 deletions compio-driver/src/iour/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@ use std::{io, os::fd::FromRawFd, pin::Pin, sync::Arc, task::Poll, time::Duration

use compio_log::{instrument, trace, warn};
use crossbeam_queue::SegQueue;
use io_uring::{
opcode::{AsyncCancel, PollAdd},
types::{Fd, SubmitArgs, Timespec},
IoUring,
};
pub(crate) use libc::{sockaddr_storage, socklen_t};

use crate::{syscall, AsyncifyPool, Entry, Key, OutEntries, ProactorBuilder};

cfg_if::cfg_if! {
if #[cfg(feature = "io-uring-cqe32")] {
use io_uring::cqueue::Entry32 as CEntry;
Expand All @@ -19,15 +28,6 @@ cfg_if::cfg_if! {
use io_uring::squeue::Entry as SEntry;
}
}
use io_uring::{
opcode::{AsyncCancel, PollAdd},
types::{Fd, SubmitArgs, Timespec},
IoUring,
};
pub(crate) use libc::{sockaddr_storage, socklen_t};

use crate::{syscall, AsyncifyPool, Entry, Key, OutEntries, ProactorBuilder};

pub(crate) mod op;

/// The created entry of [`OpCode`].
Expand Down Expand Up @@ -238,6 +238,37 @@ impl Driver {
}
}

pub fn push_flags<T: crate::sys::OpCode + 'static>(
&mut self,
op: &mut Key<T>,
) -> Poll<(io::Result<usize>, u32)> {
instrument!(compio_log::Level::TRACE, "push_flags", ?op);
let user_data = op.user_data();
let op_pin = op.as_op_pin();
trace!("push RawOp");
match op_pin.create_entry() {
OpEntry::Submission(entry) => {
#[allow(clippy::useless_conversion)]
if let Err(err) = self.push_raw(entry.user_data(user_data as _).into()) {
return Poll::Ready((Err(err), 0));
}
Poll::Pending
}
#[cfg(feature = "io-uring-sqe128")]
OpEntry::Submission128(entry) => {
if let Err(err) = self.push_raw(entry.user_data(user_data as _)) {
return Poll::Ready((Err(err), 0));
}
Poll::Pending
}
OpEntry::Blocking => match self.push_blocking(user_data) {
Err(err) => Poll::Ready((Err(err), 0)),
Ok(true) => Poll::Pending,
Ok(false) => Poll::Ready((Err(io::Error::from_raw_os_error(libc::EBUSY)), 0)),
},
}
}

fn push_blocking(&mut self, user_data: usize) -> io::Result<bool> {
let handle = self.handle()?;
let completed = self.pool_completed.clone();
Expand All @@ -247,7 +278,7 @@ impl Driver {
let mut op = unsafe { Key::<dyn crate::sys::OpCode>::new_unchecked(user_data) };
let op_pin = op.as_op_pin();
let res = op_pin.call_blocking();
completed.push(Entry::new(user_data, res));
completed.push(Entry::new(user_data, res, todo!("how to get flags?")));
handle.notify().ok();
})
.is_ok();
Expand Down Expand Up @@ -294,7 +325,7 @@ fn create_entry(entry: CEntry) -> Entry {
} else {
Ok(result as _)
};
Entry::new(entry.user_data() as _, result)
Entry::new(entry.user_data() as _, result, entry.flags())
}

fn timespec(duration: std::time::Duration) -> Timespec {
Expand Down
19 changes: 19 additions & 0 deletions compio-driver/src/key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub(crate) struct RawOp<T: ?Sized> {
// The metadata in `*mut RawOp<dyn OpCode>`
metadata: usize,
result: PushEntry<Option<Waker>, io::Result<usize>>,
flags: u32,
op: T,
}

Expand Down Expand Up @@ -84,6 +85,7 @@ impl<T: OpCode + 'static> Key<T> {
cancelled: false,
metadata: opcode_metadata::<T>(),
result: PushEntry::Pending(None),
flags: 0,
op,
});
unsafe { Self::new_unchecked(Box::into_raw(raw_op) as _) }
Expand Down Expand Up @@ -154,6 +156,10 @@ impl<T: ?Sized> Key<T> {
this.cancelled
}

pub(crate) fn set_flags(&mut self, flags: u32) {
self.as_opaque_mut().flags = flags;
}

/// Whether the op is completed.
pub(crate) fn has_result(&self) -> bool {
self.as_opaque().result.is_ready()
Expand Down Expand Up @@ -189,6 +195,19 @@ impl<T> Key<T> {
let op = unsafe { Box::from_raw(self.user_data as *mut RawOp<T>) };
BufResult(op.result.take_ready().unwrap_unchecked(), op.op)
}

/// Get the inner result and flags if it is completed.
///
/// # Safety
///
/// Call it only when the op is completed, otherwise it is UB.
pub(crate) unsafe fn into_inner_flags(self) -> (BufResult<usize, T>, u32) {
let op = unsafe { Box::from_raw(self.user_data as *mut RawOp<T>) };
(
BufResult(op.result.take_ready().unwrap_unchecked(), op.op),
op.flags,
)
}
}

impl<T: OpCode + ?Sized> Key<T> {
Expand Down
82 changes: 60 additions & 22 deletions compio-driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,35 @@
#![cfg_attr(feature = "once_cell_try", feature(once_cell_try))]
#![warn(missing_docs)]

#[cfg(all(
target_os = "linux",
not(feature = "io-uring"),
not(feature = "polling")
))]
compile_error!("You must choose at least one of these features: [\"io-uring\", \"polling\"]");

use std::{
io,
task::{Poll, Waker},
time::Duration,
};

pub use asyncify::*;
use compio_buf::BufResult;
use compio_log::instrument;

mod key;
pub use fd::*;
pub use key::Key;

pub mod op;
#[cfg(unix)]
#[cfg_attr(docsrs, doc(cfg(all())))]
mod unix;
pub use sys::*;
#[cfg(unix)]
use unix::Overlapped;

mod asyncify;
pub use asyncify::*;
#[cfg(all(
target_os = "linux",
not(feature = "io-uring"),
not(feature = "polling")
))]
compile_error!("You must choose at least one of these features: [\"io-uring\", \"polling\"]");

mod asyncify;
mod fd;
pub use fd::*;

mod key;
pub mod op;
#[cfg(unix)]
#[cfg_attr(docsrs, doc(cfg(all())))]
mod unix;
cfg_if::cfg_if! {
if #[cfg(windows)] {
#[path = "iocp/mod.rs"]
Expand All @@ -53,8 +50,6 @@ cfg_if::cfg_if! {
}
}

pub use sys::*;

#[cfg(windows)]
#[macro_export]
#[doc(hidden)]
Expand Down Expand Up @@ -272,6 +267,24 @@ impl Proactor {
}
}

/// Push an operation into the driver, and return the unique key, called
/// user-defined data, associated with it.
pub fn push_flags<T: OpCode + 'static>(
&mut self,
op: T,
) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
let mut op = self.driver.create_op(op);
match self.driver.push_flags(&mut op) {
Poll::Pending => PushEntry::Pending(op),
Poll::Ready((res, flags)) => {
op.set_result(res);
op.set_flags(flags);
// SAFETY: just completed.
PushEntry::Ready(unsafe { op.into_inner_flags() })
}
}
}

/// Poll the driver and get completed entries.
/// You need to call [`Proactor::pop`] to get the pushed operations.
pub fn poll(
Expand Down Expand Up @@ -300,6 +313,21 @@ impl Proactor {
}
}

/// Get the pushed operations from the completion entries.
///
/// # Panics
/// This function will panic if the requested operation has not been
/// completed.
pub fn pop_flags<T>(&mut self, op: Key<T>) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
instrument!(compio_log::Level::DEBUG, "pop_flags", ?op);
if op.has_result() {
// SAFETY: completed.
PushEntry::Ready(unsafe { op.into_inner_flags() })
} else {
PushEntry::Pending(op)
}
}

/// Update the waker of the specified op.
pub fn update_waker<T>(&mut self, op: &mut Key<T>, waker: Waker) {
op.set_waker(waker);
Expand All @@ -322,18 +350,27 @@ impl AsRawFd for Proactor {
pub(crate) struct Entry {
user_data: usize,
result: io::Result<usize>,
flags: u32,
}

impl Entry {
pub(crate) fn new(user_data: usize, result: io::Result<usize>) -> Self {
Self { user_data, result }
pub(crate) fn new(user_data: usize, result: io::Result<usize>, flags: u32) -> Self {
Self {
user_data,
result,
flags,
}
}

/// The user-defined data returned by [`Proactor::push`].
pub fn user_data(&self) -> usize {
self.user_data
}

pub fn flags(&self) -> u32 {
self.flags
}

/// The result of the operation.
pub fn into_result(self) -> io::Result<usize> {
self.result
Expand All @@ -357,6 +394,7 @@ impl<E: Extend<usize>> Extend<Entry> for OutEntries<'_, E> {
self.entries.extend(iter.into_iter().filter_map(|e| {
let user_data = e.user_data();
let mut op = unsafe { Key::<()>::new_unchecked(user_data) };
op.set_flags(e.flags());
if op.set_result(e.into_result()) {
// SAFETY: completed and cancelled.
let _ = unsafe { op.into_box() };
Expand Down
50 changes: 44 additions & 6 deletions compio-runtime/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,21 @@ use compio_driver::{
use compio_log::{debug, instrument};
use crossbeam_queue::SegQueue;
use futures_util::{future::Either, FutureExt};
use send_wrapper::SendWrapper;
use smallvec::SmallVec;

#[cfg(feature = "time")]
use crate::runtime::time::{TimerFuture, TimerRuntime};
use crate::{
runtime::op::{OpFlagsFuture, OpFuture},
BufResult,
};

pub(crate) mod op;
#[cfg(feature = "time")]
pub(crate) mod time;

mod send_wrapper;
use send_wrapper::SendWrapper;

#[cfg(feature = "time")]
use crate::runtime::time::{TimerFuture, TimerRuntime};
use crate::{runtime::op::OpFuture, BufResult};

scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);

/// Type alias for `Task<Result<T, Box<dyn Any + Send>>>`, which resolves to an
Expand Down Expand Up @@ -231,6 +233,13 @@ impl Runtime {
self.driver.borrow_mut().push(op)
}

fn submit_flags_raw<T: OpCode + 'static>(
&self,
op: T,
) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
self.driver.borrow_mut().push_flags(op)
}

/// Submit an operation to the runtime.
///
/// You only need this when authoring your own [`OpCode`].
Expand All @@ -241,6 +250,22 @@ impl Runtime {
}
}

/// Submit an operation to the runtime.
///
/// The difference between [`Runtime::submit`] is this method will return
/// the flags
///
/// You only need this when authoring your own [`OpCode`].
pub fn submit_flags<T: OpCode + 'static>(
&self,
op: T,
) -> impl Future<Output = (BufResult<usize, T>, u32)> {
match self.submit_flags_raw(op) {
PushEntry::Pending(user_data) => Either::Left(OpFlagsFuture::new(user_data)),
PushEntry::Ready(res) => Either::Right(ready(res)),
}
}

#[cfg(feature = "time")]
pub(crate) fn create_timer(&self, delay: std::time::Duration) -> impl Future<Output = ()> {
let mut timer_runtime = self.timer_runtime.borrow_mut();
Expand Down Expand Up @@ -273,6 +298,19 @@ impl Runtime {
})
}

pub(crate) fn poll_task_flags<T: OpCode>(
&self,
cx: &mut Context,
op: Key<T>,
) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
instrument!(compio_log::Level::DEBUG, "poll_task", ?op);
let mut driver = self.driver.borrow_mut();
driver.pop_flags(op).map_pending(|mut k| {
driver.update_waker(&mut k, cx.waker().clone());
k
})
}

#[cfg(feature = "time")]
pub(crate) fn poll_timer(&self, cx: &mut Context, key: usize) -> Poll<()> {
instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
Expand Down
Loading

0 comments on commit 9f894ab

Please sign in to comment.