From 3b5a8053d9b9547deca6817fcac99f4cc1b24323 Mon Sep 17 00:00:00 2001 From: zu1k Date: Sun, 10 Dec 2023 17:58:08 +0800 Subject: [PATCH] feat: Add Connector trait --- Cargo.lock | 11 +++++++++ Cargo.toml | 1 + examples/bind_connect.rs | 52 ++++++++++++++++++++++++++++++++++++++++ src/agent.rs | 15 ++++++++++++ src/connect.rs | 47 ++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 ++ src/stream.rs | 4 ++-- src/unit.rs | 5 ++++ 8 files changed, 135 insertions(+), 2 deletions(-) create mode 100644 examples/bind_connect.rs create mode 100644 src/connect.rs diff --git a/Cargo.lock b/Cargo.lock index a708ff82..f1419fd3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -662,6 +662,16 @@ dependencies = [ "serde", ] +[[package]] +name = "socket2" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +dependencies = [ + "libc", + "windows-sys", +] + [[package]] name = "socks" version = "0.3.4" @@ -811,6 +821,7 @@ dependencies = [ "rustls-webpki", "serde", "serde_json", + "socket2", "socks", "url", "webpki-roots", diff --git a/Cargo.toml b/Cargo.toml index e2e1d116..d2609dff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,6 +59,7 @@ serde = { version = "1", features = ["derive"] } env_logger = "0.10" rustls = { version = "0.22.0" } rustls-pemfile = { version = "2.0" } +socket2 = "0.5" [[example]] name = "cureq" diff --git a/examples/bind_connect.rs b/examples/bind_connect.rs new file mode 100644 index 00000000..57e327c4 --- /dev/null +++ b/examples/bind_connect.rs @@ -0,0 +1,52 @@ +use socket2::{Domain, Socket, Type}; +use std::net::SocketAddr; +use ureq::Connector; + +#[derive(Debug)] +pub(crate) struct BindConnector { + bind_addr: SocketAddr, +} + +impl BindConnector { + pub fn new_bind(bind_addr: SocketAddr) -> Self { + Self { bind_addr } + } +} + +impl Connector for BindConnector { + fn connect(&self, addr: &std::net::SocketAddr) -> std::io::Result { + let socket = Socket::new(Domain::for_address(addr.to_owned()), Type::STREAM, None)?; + socket.bind(&self.bind_addr.into())?; + socket.connect(&addr.to_owned().into())?; + Ok(socket.into()) + } + + fn connect_timeout( + &self, + addr: &std::net::SocketAddr, + timeout: std::time::Duration, + ) -> std::io::Result { + let socket = Socket::new(Domain::for_address(addr.to_owned()), Type::STREAM, None)?; + socket.bind(&self.bind_addr.into())?; + socket.connect_timeout(&addr.to_owned().into(), timeout)?; + Ok(socket.into()) + } +} + +pub fn main() { + let agent = ureq::builder() + .connector(BindConnector::new_bind("127.0.0.1:54321".parse().unwrap())) + .build(); + + let result = agent.get("http://127.0.0.1:8080/").call(); + + match result { + Err(err) => { + println!("{:?}", err); + std::process::exit(1); + } + Ok(response) => { + assert_eq!(response.status(), 200); + } + } +} diff --git a/src/agent.rs b/src/agent.rs index 986d3f92..fbf4da4a 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use std::time::Duration; use url::Url; +use crate::connect::{ArcConnector, StdTcpConnector}; use crate::middleware::Middleware; use crate::pool::ConnectionPool; use crate::proxy::Proxy; @@ -45,6 +46,7 @@ pub struct AgentBuilder { #[cfg(feature = "cookies")] cookie_store: Option, resolver: ArcResolver, + connector: ArcConnector, middleware: Vec>, } @@ -126,6 +128,7 @@ pub(crate) struct AgentState { #[cfg(feature = "cookies")] pub(crate) cookie_tin: CookieTin, pub(crate) resolver: ArcResolver, + pub(crate) connector: ArcConnector, pub(crate) middleware: Vec>, } @@ -271,6 +274,7 @@ impl AgentBuilder { max_idle_connections: DEFAULT_MAX_IDLE_CONNECTIONS, max_idle_connections_per_host: DEFAULT_MAX_IDLE_CONNECTIONS_PER_HOST, resolver: StdResolver.into(), + connector: StdTcpConnector.into(), #[cfg(feature = "cookies")] cookie_store: None, middleware: vec![], @@ -298,6 +302,7 @@ impl AgentBuilder { #[cfg(feature = "cookies")] cookie_tin: CookieTin::new(self.cookie_store.unwrap_or_else(CookieStore::default)), resolver: self.resolver, + connector: self.connector, middleware: self.middleware, }), } @@ -402,6 +407,16 @@ impl AgentBuilder { self } + /// Configures a custom connector to be used by this agent. By default, + /// tcp-connect is done by std::net::TcpStream. This allows you + /// to override that connection with your own alternative. + /// + /// See `examples/bind_connect.rs` for example. + pub fn connector(mut self, connector: impl crate::Connector + 'static) -> Self { + self.connector = connector.into(); + self + } + /// Timeout for the socket connection to be successful. /// If both this and `.timeout()` are both set, `.timeout_connect()` /// takes precedence. diff --git a/src/connect.rs b/src/connect.rs new file mode 100644 index 00000000..475964c6 --- /dev/null +++ b/src/connect.rs @@ -0,0 +1,47 @@ +use std::fmt; +use std::io::Result as IoResult; +use std::net::{SocketAddr, TcpStream}; +use std::sync::Arc; +use std::time::Duration; + +/// A custom Connector to override the default TcpStream connector. +pub trait Connector: Send + Sync { + fn connect(&self, addr: &SocketAddr) -> IoResult { + TcpStream::connect(addr) + } + + fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> IoResult { + TcpStream::connect_timeout(addr, timeout) + } +} + +#[derive(Debug)] +pub(crate) struct StdTcpConnector; + +impl Connector for StdTcpConnector {} + +#[derive(Clone)] +pub(crate) struct ArcConnector(Arc); + +impl From for ArcConnector +where + R: Connector + 'static, +{ + fn from(r: R) -> Self { + Self(Arc::new(r)) + } +} + +impl fmt::Debug for ArcConnector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ArcConnector(...)") + } +} + +impl std::ops::Deref for ArcConnector { + type Target = dyn Connector; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} diff --git a/src/lib.rs b/src/lib.rs index 70d42e49..5a9b67f4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -357,6 +357,7 @@ mod agent; mod body; mod chunked; +mod connect; mod error; mod header; mod middleware; @@ -429,6 +430,7 @@ mod http_crate; pub use crate::agent::Agent; pub use crate::agent::AgentBuilder; pub use crate::agent::RedirectAuthHeaders; +pub use crate::connect::Connector; pub use crate::error::{Error, ErrorKind, OrAnyStatus, Transport}; pub use crate::middleware::{Middleware, MiddlewareNext}; pub use crate::proxy::Proxy; diff --git a/src/stream.rs b/src/stream.rs index 8e7a36ee..ed0d77ad 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -406,9 +406,9 @@ pub(crate) fn connect_host( proto.unwrap(), ) } else if let Some(timeout) = timeout { - TcpStream::connect_timeout(&sock_addr, timeout) + unit.connector().connect_timeout(&sock_addr, timeout) } else { - TcpStream::connect(sock_addr) + unit.connector().connect(&sock_addr) }; if let Ok(stream) = stream { diff --git a/src/unit.rs b/src/unit.rs index 13c3c9a6..ba70d1a3 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -12,6 +12,7 @@ use cookie::Cookie; use crate::agent::RedirectAuthHeaders; use crate::body::{self, BodySize, Payload, SizedReader}; +use crate::connect::ArcConnector; use crate::error::{Error, ErrorKind}; use crate::header; use crate::header::{get_header, Header}; @@ -115,6 +116,10 @@ impl Unit { self.agent.state.resolver.clone() } + pub fn connector(&self) -> ArcConnector { + self.agent.state.connector.clone() + } + #[cfg(test)] pub fn header(&self, name: &str) -> Option<&str> { header::get_header(&self.headers, name)