diff --git a/compio-signal/Cargo.toml b/compio-signal/Cargo.toml index d3e67817..1fd3071a 100644 --- a/compio-signal/Cargo.toml +++ b/compio-signal/Cargo.toml @@ -18,23 +18,33 @@ rustdoc-args = ["--cfg", "docsrs"] # Workspace dependencies compio-runtime = { workspace = true, features = ["event"] } -once_cell = { workspace = true } -slab = { workspace = true } - # Windows specific dependencies [target.'cfg(windows)'.dependencies] compio-driver = { workspace = true } +once_cell = { workspace = true } +slab = { workspace = true } windows-sys = { workspace = true, features = [ "Win32_Foundation", "Win32_System_Console", ] } +# Linux specific dependencies +[target.'cfg(target_os = "linux")'.dependencies] +compio-buf = { workspace = true } +compio-driver = { workspace = true } + # Unix specific dependencies [target.'cfg(unix)'.dependencies] -os_pipe = { workspace = true } libc = { workspace = true } +[target.'cfg(all(unix, not(target_os = "linux")))'.dependencies] +once_cell = { workspace = true } +os_pipe = { workspace = true } +slab = { workspace = true } + [features] +default = [] +io-uring = ["compio-driver/io-uring"] # Nightly features lazy_cell = [] once_cell_try = [] diff --git a/compio-signal/src/lib.rs b/compio-signal/src/lib.rs index 6dff3033..1256fefd 100644 --- a/compio-signal/src/lib.rs +++ b/compio-signal/src/lib.rs @@ -23,6 +23,7 @@ pub mod windows; #[cfg(unix)] +#[cfg_attr(target_os = "linux", path = "linux.rs")] pub mod unix; /// Completes when a "ctrl-c" notification is sent to the process. diff --git a/compio-signal/src/linux.rs b/compio-signal/src/linux.rs new file mode 100644 index 00000000..95c31a3a --- /dev/null +++ b/compio-signal/src/linux.rs @@ -0,0 +1,128 @@ +//! Linux-specific types for signal handling. + +use std::{ + cell::RefCell, collections::HashMap, io, mem::MaybeUninit, os::fd::FromRawFd, ptr::null_mut, + thread_local, +}; + +use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit}; +use compio_driver::{op::Recv, syscall, OwnedFd, SharedFd}; + +thread_local! { + static REG_MAP: RefCell> = RefCell::new(HashMap::new()); +} + +fn sigset(sig: i32) -> io::Result { + let mut set: MaybeUninit = MaybeUninit::uninit(); + syscall!(libc::sigemptyset(set.as_mut_ptr()))?; + syscall!(libc::sigaddset(set.as_mut_ptr(), sig))?; + // SAFETY: sigemptyset initializes the set. + Ok(unsafe { set.assume_init() }) +} + +fn register_signal(sig: i32) -> io::Result { + REG_MAP.with_borrow_mut(|map| { + let count = map.entry(sig).or_default(); + let set = sigset(sig)?; + if *count == 0 { + syscall!(libc::pthread_sigmask(libc::SIG_BLOCK, &set, null_mut()))?; + } + *count += 1; + Ok(set) + }) +} + +fn unregister_signal(sig: i32) -> io::Result { + REG_MAP.with_borrow_mut(|map| { + let count = map.entry(sig).or_default(); + if *count > 0 { + *count -= 1; + } + let set = sigset(sig)?; + if *count == 0 { + syscall!(libc::pthread_sigmask(libc::SIG_UNBLOCK, &set, null_mut()))?; + } + Ok(set) + }) +} + +/// Represents a listener to unix signal event. +#[derive(Debug)] +struct SignalFd { + fd: SharedFd, + sig: i32, +} + +impl SignalFd { + fn new(sig: i32) -> io::Result { + let set = register_signal(sig)?; + let mut flag = libc::SFD_CLOEXEC; + if cfg!(not(feature = "io-uring")) { + flag |= libc::SFD_NONBLOCK; + } + let fd = syscall!(libc::signalfd(-1, &set, flag))?; + let fd = unsafe { OwnedFd::from_raw_fd(fd) }; + Ok(Self { + fd: SharedFd::new(fd), + sig, + }) + } + + async fn wait(self) -> io::Result<()> { + const INFO_SIZE: usize = std::mem::size_of::(); + + struct SignalInfo(MaybeUninit); + + unsafe impl IoBuf for SignalInfo { + fn as_buf_ptr(&self) -> *const u8 { + self.0.as_ptr().cast() + } + + fn buf_len(&self) -> usize { + 0 + } + + fn buf_capacity(&self) -> usize { + INFO_SIZE + } + } + + unsafe impl IoBufMut for SignalInfo { + fn as_buf_mut_ptr(&mut self) -> *mut u8 { + self.0.as_mut_ptr().cast() + } + } + + impl SetBufInit for SignalInfo { + unsafe fn set_buf_init(&mut self, len: usize) { + debug_assert!(len <= INFO_SIZE) + } + } + + let info = SignalInfo(MaybeUninit::::uninit()); + let op = Recv::new(self.fd.clone(), info); + let BufResult(res, op) = compio_runtime::submit(op).await; + let len = res?; + debug_assert_eq!(len, INFO_SIZE); + let info = op.into_inner(); + let info = unsafe { info.0.assume_init() }; + debug_assert_eq!(info.ssi_signo, self.sig as u32); + Ok(()) + } +} + +impl Drop for SignalFd { + fn drop(&mut self) { + unregister_signal(self.sig).ok(); + } +} + +/// Creates a new listener which will receive notifications when the current +/// process receives the specified signal. +/// +/// It sets the signal mask of the current thread. +pub async fn signal(sig: i32) -> io::Result<()> { + let fd = SignalFd::new(sig)?; + fd.wait().await?; + Ok(()) +} diff --git a/compio/Cargo.toml b/compio/Cargo.toml index 921e1692..9451e526 100644 --- a/compio/Cargo.toml +++ b/compio/Cargo.toml @@ -81,6 +81,7 @@ io-uring = [ "compio-driver/io-uring", "compio-fs?/io-uring", "compio-net?/io-uring", + "compio-signal?/io-uring", ] polling = ["compio-driver/polling"] io = ["dep:compio-io"]