Skip to content

Commit

Permalink
feat: Add Connector trait
Browse files Browse the repository at this point in the history
  • Loading branch information
zu1k committed Dec 10, 2023
1 parent 4ddcc3a commit 3b5a805
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 2 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ serde = { version = "1", features = ["derive"] }
env_logger = "0.10"
rustls = { version = "0.22.0" }
rustls-pemfile = { version = "2.0" }
socket2 = "0.5"

[[example]]
name = "cureq"
Expand Down
52 changes: 52 additions & 0 deletions examples/bind_connect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use socket2::{Domain, Socket, Type};
use std::net::SocketAddr;
use ureq::Connector;

#[derive(Debug)]
pub(crate) struct BindConnector {
bind_addr: SocketAddr,
}

impl BindConnector {
pub fn new_bind(bind_addr: SocketAddr) -> Self {
Self { bind_addr }
}
}

impl Connector for BindConnector {
fn connect(&self, addr: &std::net::SocketAddr) -> std::io::Result<std::net::TcpStream> {
let socket = Socket::new(Domain::for_address(addr.to_owned()), Type::STREAM, None)?;
socket.bind(&self.bind_addr.into())?;
socket.connect(&addr.to_owned().into())?;
Ok(socket.into())
}

fn connect_timeout(
&self,
addr: &std::net::SocketAddr,
timeout: std::time::Duration,
) -> std::io::Result<std::net::TcpStream> {
let socket = Socket::new(Domain::for_address(addr.to_owned()), Type::STREAM, None)?;
socket.bind(&self.bind_addr.into())?;
socket.connect_timeout(&addr.to_owned().into(), timeout)?;
Ok(socket.into())
}
}

pub fn main() {
let agent = ureq::builder()
.connector(BindConnector::new_bind("127.0.0.1:54321".parse().unwrap()))
.build();

let result = agent.get("http://127.0.0.1:8080/").call();

match result {
Err(err) => {
println!("{:?}", err);
std::process::exit(1);
}
Ok(response) => {
assert_eq!(response.status(), 200);
}
}
}
15 changes: 15 additions & 0 deletions src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;
use std::time::Duration;
use url::Url;

use crate::connect::{ArcConnector, StdTcpConnector};
use crate::middleware::Middleware;
use crate::pool::ConnectionPool;
use crate::proxy::Proxy;
Expand Down Expand Up @@ -45,6 +46,7 @@ pub struct AgentBuilder {
#[cfg(feature = "cookies")]
cookie_store: Option<CookieStore>,
resolver: ArcResolver,
connector: ArcConnector,
middleware: Vec<Box<dyn Middleware>>,
}

Expand Down Expand Up @@ -126,6 +128,7 @@ pub(crate) struct AgentState {
#[cfg(feature = "cookies")]
pub(crate) cookie_tin: CookieTin,
pub(crate) resolver: ArcResolver,
pub(crate) connector: ArcConnector,
pub(crate) middleware: Vec<Box<dyn Middleware>>,
}

Expand Down Expand Up @@ -271,6 +274,7 @@ impl AgentBuilder {
max_idle_connections: DEFAULT_MAX_IDLE_CONNECTIONS,
max_idle_connections_per_host: DEFAULT_MAX_IDLE_CONNECTIONS_PER_HOST,
resolver: StdResolver.into(),
connector: StdTcpConnector.into(),
#[cfg(feature = "cookies")]
cookie_store: None,
middleware: vec![],
Expand Down Expand Up @@ -298,6 +302,7 @@ impl AgentBuilder {
#[cfg(feature = "cookies")]
cookie_tin: CookieTin::new(self.cookie_store.unwrap_or_else(CookieStore::default)),
resolver: self.resolver,
connector: self.connector,
middleware: self.middleware,
}),
}
Expand Down Expand Up @@ -402,6 +407,16 @@ impl AgentBuilder {
self
}

/// Configures a custom connector to be used by this agent. By default,
/// tcp-connect is done by std::net::TcpStream. This allows you
/// to override that connection with your own alternative.
///
/// See `examples/bind_connect.rs` for example.
pub fn connector(mut self, connector: impl crate::Connector + 'static) -> Self {
self.connector = connector.into();
self
}

/// Timeout for the socket connection to be successful.
/// If both this and `.timeout()` are both set, `.timeout_connect()`
/// takes precedence.
Expand Down
47 changes: 47 additions & 0 deletions src/connect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use std::fmt;
use std::io::Result as IoResult;
use std::net::{SocketAddr, TcpStream};
use std::sync::Arc;
use std::time::Duration;

/// A custom Connector to override the default TcpStream connector.
pub trait Connector: Send + Sync {
fn connect(&self, addr: &SocketAddr) -> IoResult<TcpStream> {
TcpStream::connect(addr)
}

fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> IoResult<TcpStream> {
TcpStream::connect_timeout(addr, timeout)
}
}

#[derive(Debug)]
pub(crate) struct StdTcpConnector;

impl Connector for StdTcpConnector {}

#[derive(Clone)]
pub(crate) struct ArcConnector(Arc<dyn Connector>);

impl<R> From<R> for ArcConnector
where
R: Connector + 'static,
{
fn from(r: R) -> Self {
Self(Arc::new(r))
}
}

impl fmt::Debug for ArcConnector {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ArcConnector(...)")
}
}

impl std::ops::Deref for ArcConnector {
type Target = dyn Connector;

fn deref(&self) -> &Self::Target {
self.0.as_ref()
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@
mod agent;
mod body;
mod chunked;
mod connect;
mod error;
mod header;
mod middleware;
Expand Down Expand Up @@ -429,6 +430,7 @@ mod http_crate;
pub use crate::agent::Agent;
pub use crate::agent::AgentBuilder;
pub use crate::agent::RedirectAuthHeaders;
pub use crate::connect::Connector;
pub use crate::error::{Error, ErrorKind, OrAnyStatus, Transport};
pub use crate::middleware::{Middleware, MiddlewareNext};
pub use crate::proxy::Proxy;
Expand Down
4 changes: 2 additions & 2 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,9 @@ pub(crate) fn connect_host(
proto.unwrap(),
)
} else if let Some(timeout) = timeout {
TcpStream::connect_timeout(&sock_addr, timeout)
unit.connector().connect_timeout(&sock_addr, timeout)
} else {
TcpStream::connect(sock_addr)
unit.connector().connect(&sock_addr)
};

if let Ok(stream) = stream {
Expand Down
5 changes: 5 additions & 0 deletions src/unit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use cookie::Cookie;

use crate::agent::RedirectAuthHeaders;
use crate::body::{self, BodySize, Payload, SizedReader};
use crate::connect::ArcConnector;
use crate::error::{Error, ErrorKind};
use crate::header;
use crate::header::{get_header, Header};
Expand Down Expand Up @@ -115,6 +116,10 @@ impl Unit {
self.agent.state.resolver.clone()
}

pub fn connector(&self) -> ArcConnector {
self.agent.state.connector.clone()
}

#[cfg(test)]
pub fn header(&self, name: &str) -> Option<&str> {
header::get_header(&self.headers, name)
Expand Down

0 comments on commit 3b5a805

Please sign in to comment.