diff --git a/Cargo.toml b/Cargo.toml index 0a00f90f..ad7520b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ compio-dispatcher = { path = "./compio-dispatcher", version = "0.3.0" } compio-log = { path = "./compio-log", version = "0.1.0" } compio-tls = { path = "./compio-tls", version = "0.2.0", default-features = false } compio-process = { path = "./compio-process", version = "0.1.0" } +compio-quic = { path = "./compio-quic", version = "0.1.0" } flume = "0.11.0" cfg-if = "1.0.0" diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index da99a0a9..795cf084 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -43,16 +43,27 @@ windows-sys = { workspace = true, features = ["Win32_Networking_WinSock"] } libc = { workspace = true } [dev-dependencies] +compio-dispatcher = { workspace = true } compio-driver = { workspace = true } compio-macros = { workspace = true } +compio-runtime = { workspace = true, features = ["criterion"] } + rand = "0.8.5" rcgen = "0.13.1" socket2 = { workspace = true, features = ["all"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } +criterion = { workspace = true, features = ["async_tokio"] } +quinn = "0.11.3" +tokio = { workspace = true, features = ["rt", "macros"] } + [features] -default = ["webpki-roots"] -futures-io = ["futures-util/io"] +default = [] +io-compat = ["futures-util/io"] platform-verifier = ["dep:rustls-platform-verifier"] native-certs = ["dep:rustls-native-certs"] webpki-roots = ["dep:webpki-roots"] + +[[bench]] +name = "quic" +harness = false diff --git a/compio-quic/benches/quic.rs b/compio-quic/benches/quic.rs new file mode 100644 index 00000000..e318a069 --- /dev/null +++ b/compio-quic/benches/quic.rs @@ -0,0 +1,196 @@ +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Arc, + time::Instant, +}; + +use bytes::Bytes; +use criterion::{criterion_group, criterion_main, Bencher, Criterion, Throughput}; +use futures_util::{stream::FuturesUnordered, StreamExt}; +use rand::{thread_rng, RngCore}; + +criterion_group!(quic, echo); +criterion_main!(quic); + +fn gen_cert() -> ( + rustls::pki_types::CertificateDer<'static>, + rustls::pki_types::PrivateKeyDer<'static>, +) { + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert = cert.der().clone(); + let key_der = key_pair.serialize_der().try_into().unwrap(); + (cert, key_der) +} + +macro_rules! echo_impl { + ($send:ident, $recv:ident) => { + loop { + // These are 32 buffers, for reading approximately 32kB at once + let mut bufs: [Bytes; 32] = std::array::from_fn(|_| Bytes::new()); + + match $recv.read_chunks(&mut bufs).await.unwrap() { + Some(n) => { + $send.write_all_chunks(&mut bufs[..n]).await.unwrap(); + } + None => break, + } + } + + let _ = $send.finish(); + }; +} + +fn echo_compio_quic(b: &mut Bencher, content: &[u8], streams: usize) { + use compio_quic::{ClientBuilder, ServerBuilder}; + + let runtime = compio_runtime::Runtime::new().unwrap(); + b.to_async(runtime).iter_custom(|iter| async move { + let (cert, key_der) = gen_cert(); + let server = ServerBuilder::new_with_single_cert(vec![cert.clone()], key_der) + .unwrap() + .bind("127.0.0.1:0") + .await + .unwrap(); + let client = ClientBuilder::new_with_empty_roots() + .with_custom_certificate(cert) + .unwrap() + .with_no_crls() + .bind("127.0.0.1:0") + .await + .unwrap(); + let addr = server.local_addr().unwrap(); + + let (client_conn, server_conn) = futures_util::join!( + async move { + client + .connect(addr, "localhost", None) + .unwrap() + .await + .unwrap() + }, + async move { server.wait_incoming().await.unwrap().await.unwrap() } + ); + + let start = Instant::now(); + let handle = compio_runtime::spawn(async move { + while let Ok((mut send, mut recv)) = server_conn.accept_bi().await { + compio_runtime::spawn(async move { + echo_impl!(send, recv); + }) + .detach(); + } + }); + for _i in 0..iter { + let mut futures = (0..streams) + .map(|_| async { + let (mut send, mut recv) = client_conn.open_bi_wait().await.unwrap(); + futures_util::join!( + async { + send.write_all(content).await.unwrap(); + send.finish().unwrap(); + }, + async { + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + } + ); + }) + .collect::>(); + while futures.next().await.is_some() {} + } + drop(handle); + start.elapsed() + }) +} + +fn echo_quinn(b: &mut Bencher, content: &[u8], streams: usize) { + use quinn::{ClientConfig, Endpoint, ServerConfig}; + + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + b.to_async(&runtime).iter_custom(|iter| async move { + let (cert, key_der) = gen_cert(); + let server_config = ServerConfig::with_single_cert(vec![cert.clone()], key_der).unwrap(); + let mut roots = rustls::RootCertStore::empty(); + roots.add(cert).unwrap(); + let client_config = ClientConfig::with_root_certificates(Arc::new(roots)).unwrap(); + let server = Endpoint::server( + server_config, + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + ) + .unwrap(); + let mut client = + Endpoint::client(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).unwrap(); + client.set_default_client_config(client_config); + let addr = server.local_addr().unwrap(); + + let (client_conn, server_conn) = futures_util::join!( + async move { client.connect(addr, "localhost").unwrap().await.unwrap() }, + async move { server.accept().await.unwrap().await.unwrap() } + ); + + let start = Instant::now(); + tokio::spawn(async move { + while let Ok((mut send, mut recv)) = server_conn.accept_bi().await { + tokio::spawn(async move { + echo_impl!(send, recv); + }); + } + }); + for _i in 0..iter { + let mut futures = (0..streams) + .map(|_| async { + let (mut send, mut recv) = client_conn.open_bi().await.unwrap(); + tokio::join!( + async { + send.write_all(content).await.unwrap(); + send.finish().unwrap(); + }, + async { + recv.read_to_end(usize::MAX).await.unwrap(); + } + ); + }) + .collect::>(); + while futures.next().await.is_some() {} + } + start.elapsed() + }); +} + +fn echo(c: &mut Criterion) { + let mut rng = thread_rng(); + + let mut large_data = [0u8; 1024 * 1024]; + rng.fill_bytes(&mut large_data); + + let mut small_data = [0u8; 10]; + rng.fill_bytes(&mut small_data); + + let mut group = c.benchmark_group("echo-large-data-1-stream"); + group.throughput(Throughput::Bytes((large_data.len() * 2) as u64)); + + group.bench_function("compio-quic", |b| echo_compio_quic(b, &large_data, 1)); + group.bench_function("quinn", |b| echo_quinn(b, &large_data, 1)); + + group.finish(); + + let mut group = c.benchmark_group("echo-large-data-10-streams"); + group.throughput(Throughput::Bytes((large_data.len() * 10 * 2) as u64)); + + group.bench_function("compio-quic", |b| echo_compio_quic(b, &large_data, 10)); + group.bench_function("quinn", |b| echo_quinn(b, &large_data, 10)); + + group.finish(); + + let mut group = c.benchmark_group("echo-small-data-100-streams"); + group.throughput(Throughput::Bytes((small_data.len() * 10 * 2) as u64)); + + group.bench_function("compio-quic", |b| echo_compio_quic(b, &small_data, 100)); + group.bench_function("quinn", |b| echo_quinn(b, &small_data, 100)); + + group.finish(); +} diff --git a/compio-quic/examples/dispatcher.rs b/compio-quic/examples/dispatcher.rs new file mode 100644 index 00000000..851debcf --- /dev/null +++ b/compio-quic/examples/dispatcher.rs @@ -0,0 +1,75 @@ +use std::num::NonZeroUsize; + +use compio_dispatcher::Dispatcher; +use compio_quic::{ClientBuilder, Endpoint, ServerBuilder}; +use compio_runtime::spawn; +use futures_util::{stream::FuturesUnordered, StreamExt}; + +#[compio_macros::main] +async fn main() { + const THREAD_NUM: usize = 5; + const CLIENT_NUM: usize = 10; + + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert = cert.der().clone(); + let key_der = key_pair.serialize_der().try_into().unwrap(); + + let server_config = ServerBuilder::new_with_single_cert(vec![cert.clone()], key_der) + .unwrap() + .build(); + let client_config = ClientBuilder::new_with_empty_roots() + .with_custom_certificate(cert) + .unwrap() + .with_no_crls() + .build(); + let mut endpoint = Endpoint::server("127.0.0.1:0", server_config) + .await + .unwrap(); + endpoint.default_client_config = Some(client_config); + + spawn({ + let endpoint = endpoint.clone(); + async move { + let mut futures = FuturesUnordered::from_iter((0..CLIENT_NUM).map(|i| { + let endpoint = &endpoint; + async move { + let conn = endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .await + .unwrap(); + let mut send = conn.open_uni().unwrap(); + send.write_all(format!("Hello world {}!", i).as_bytes()) + .await + .unwrap(); + send.finish().unwrap(); + send.stopped().await.unwrap(); + } + })); + while let Some(()) = futures.next().await {} + } + }) + .detach(); + + let dispatcher = Dispatcher::builder() + .worker_threads(NonZeroUsize::new(THREAD_NUM).unwrap()) + .build() + .unwrap(); + let mut handles = FuturesUnordered::new(); + for _i in 0..CLIENT_NUM { + let incoming = endpoint.wait_incoming().await.unwrap(); + let handle = dispatcher + .dispatch(move || async move { + let conn = incoming.await.unwrap(); + let mut recv = conn.accept_uni().await.unwrap(); + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + println!("{}", std::str::from_utf8(&buf).unwrap()); + }) + .unwrap(); + handles.push(handle); + } + while handles.next().await.is_some() {} + dispatcher.join().await.unwrap(); +} diff --git a/compio-quic/examples/server.rs b/compio-quic/examples/server.rs index 9e52e8a9..3a380f88 100644 --- a/compio-quic/examples/server.rs +++ b/compio-quic/examples/server.rs @@ -7,11 +7,12 @@ async fn main() { .with_env_filter(EnvFilter::from_default_env()) .init(); - let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); - let cert_chain = vec![cert.cert.into()]; - let key_der = cert.key_pair.serialize_der().try_into().unwrap(); + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert = cert.der().clone(); + let key_der = key_pair.serialize_der().try_into().unwrap(); - let endpoint = ServerBuilder::new_with_single_cert(cert_chain, key_der) + let endpoint = ServerBuilder::new_with_single_cert(vec![cert], key_der) .unwrap() .with_key_log() .bind("[::1]:4433") diff --git a/compio-quic/src/lib.rs b/compio-quic/src/lib.rs index 9144feb1..d73a3c65 100644 --- a/compio-quic/src/lib.rs +++ b/compio-quic/src/lib.rs @@ -25,7 +25,7 @@ pub use builder::{ClientBuilder, ServerBuilder}; pub use connection::{Connecting, Connection}; pub use endpoint::Endpoint; pub use incoming::{Incoming, IncomingFuture}; -pub use recv_stream::{ReadError, RecvStream}; +pub use recv_stream::{ReadError, ReadExactError, RecvStream}; pub use send_stream::{SendStream, WriteError}; pub(crate) use crate::{ diff --git a/compio-quic/src/recv_stream.rs b/compio-quic/src/recv_stream.rs index 5a12f99a..723c5075 100644 --- a/compio-quic/src/recv_stream.rs +++ b/compio-quic/src/recv_stream.rs @@ -8,7 +8,7 @@ use std::{ use bytes::{BufMut, Bytes}; use compio_buf::{BufResult, IoBufMut}; use compio_io::AsyncRead; -use futures_util::future::poll_fn; +use futures_util::{future::poll_fn, ready}; use quinn_proto::{Chunk, Chunks, ClosedStream, ConnectionError, ReadableError, StreamId, VarInt}; use thiserror::Error; @@ -261,6 +261,23 @@ impl RecvStream { poll_fn(|cx| self.poll_read(cx, &mut buf)).await } + /// Read an exact number of bytes contiguously from the stream. + /// + /// See [`read()`] for details. This operation is *not* cancel-safe. + /// + /// [`read()`]: RecvStream::read + pub async fn read_exact(&mut self, mut buf: impl BufMut) -> Result<(), ReadExactError> { + poll_fn(|cx| { + while buf.has_remaining_mut() { + if ready!(self.poll_read(cx, &mut buf))?.is_none() { + return Poll::Ready(Err(ReadExactError::FinishedEarly(buf.remaining_mut()))); + } + } + Poll::Ready(Ok(())) + }) + .await + } + /// Read the next segment of data. /// /// Yields `None` if the stream was finished. Otherwise, yields a segment of @@ -470,6 +487,17 @@ impl From for io::Error { } } +/// Errors that arise from reading from a stream. +#[derive(Debug, Error, Clone, PartialEq, Eq)] +pub enum ReadExactError { + /// The stream finished before all bytes were read + #[error("stream finished early (expected {0} bytes more)")] + FinishedEarly(usize), + /// A read error occurred + #[error(transparent)] + ReadError(#[from] ReadError), +} + impl AsyncRead for RecvStream { async fn read(&mut self, mut buf: B) -> BufResult { let res = self @@ -485,7 +513,7 @@ impl AsyncRead for RecvStream { } } -#[cfg(feature = "futures-io")] +#[cfg(feature = "io-compat")] impl futures_util::AsyncRead for RecvStream { fn poll_read( self: std::pin::Pin<&mut Self>, diff --git a/compio-quic/src/send_stream.rs b/compio-quic/src/send_stream.rs index bf8fc41a..7801e726 100644 --- a/compio-quic/src/send_stream.rs +++ b/compio-quic/src/send_stream.rs @@ -341,7 +341,7 @@ impl AsyncWrite for SendStream { } } -#[cfg(feature = "futures-io")] +#[cfg(feature = "io-compat")] impl futures_util::AsyncWrite for SendStream { fn poll_write( self: std::pin::Pin<&mut Self>, diff --git a/compio-quic/tests/common/mod.rs b/compio-quic/tests/common/mod.rs index 05fbe3f0..08745b3d 100644 --- a/compio-quic/tests/common/mod.rs +++ b/compio-quic/tests/common/mod.rs @@ -12,15 +12,16 @@ pub fn subscribe() -> DefaultGuard { } pub fn config_pair(transport: Option) -> (ServerConfig, ClientConfig) { - let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); - let cert_chain = vec![cert.cert.der().clone()]; - let key_der = cert.key_pair.serialize_der().try_into().unwrap(); + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert = cert.der().clone(); + let key_der = key_pair.serialize_der().try_into().unwrap(); - let mut server_config = ServerBuilder::new_with_single_cert(cert_chain, key_der) + let mut server_config = ServerBuilder::new_with_single_cert(vec![cert.clone()], key_der) .unwrap() .build(); let mut client_config = ClientBuilder::new_with_empty_roots() - .with_custom_certificate(cert.cert.into()) + .with_custom_certificate(cert) .unwrap() .with_no_crls() .build(); diff --git a/compio-quic/tests/echo.rs b/compio-quic/tests/echo.rs index 1ab51364..69d942ad 100644 --- a/compio-quic/tests/echo.rs +++ b/compio-quic/tests/echo.rs @@ -1,4 +1,7 @@ -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::{ + array, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, +}; use bytes::Bytes; use compio_quic::{Endpoint, RecvStream, SendStream, TransportConfig}; @@ -20,17 +23,7 @@ struct EchoArgs { async fn echo((mut send, mut recv): (SendStream, RecvStream)) { loop { // These are 32 buffers, for reading approximately 32kB at once - #[rustfmt::skip] - let mut bufs = [ - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - ]; + let mut bufs: [Bytes; 32] = array::from_fn(|_| Bytes::new()); match recv.read_chunks(&mut bufs).await.unwrap() { Some(n) => { diff --git a/compio/Cargo.toml b/compio/Cargo.toml index 692d7ac1..8cbb3715 100644 --- a/compio/Cargo.toml +++ b/compio/Cargo.toml @@ -42,6 +42,7 @@ compio-dispatcher = { workspace = true, optional = true } compio-log = { workspace = true } compio-tls = { workspace = true, optional = true } compio-process = { workspace = true, optional = true } +compio-quic = { workspace = true, optional = true } # Shared dev dependencies for all platforms [dev-dependencies] @@ -83,7 +84,7 @@ io-uring = [ ] polling = ["compio-driver/polling"] io = ["dep:compio-io"] -io-compat = ["io", "compio-io/compat"] +io-compat = ["io", "compio-io/compat", "compio-quic/io-compat"] runtime = ["dep:compio-runtime", "dep:compio-fs", "dep:compio-net", "io"] macros = ["dep:compio-macros", "runtime"] event = ["compio-runtime/event", "runtime"] @@ -94,6 +95,7 @@ tls = ["dep:compio-tls"] native-tls = ["tls", "compio-tls/native-tls"] rustls = ["tls", "compio-tls/rustls"] process = ["dep:compio-process"] +quic = ["dep:compio-quic"] all = [ "time", "macros", @@ -102,6 +104,7 @@ all = [ "native-tls", "rustls", "process", + "quic", ] arrayvec = ["compio-buf/arrayvec"] diff --git a/compio/src/lib.rs b/compio/src/lib.rs index 244d8b37..8b6c5c09 100644 --- a/compio/src/lib.rs +++ b/compio/src/lib.rs @@ -41,6 +41,9 @@ pub use compio_macros::*; #[cfg(feature = "process")] #[doc(inline)] pub use compio_process as process; +#[cfg(feature = "quic")] +#[doc(inline)] +pub use compio_quic as quic; #[cfg(feature = "signal")] #[doc(inline)] pub use compio_signal as signal;