diff --git a/ext/fetch/dns.rs b/ext/fetch/dns.rs index e233021400..2c0af89dd4 100644 --- a/ext/fetch/dns.rs +++ b/ext/fetch/dns.rs @@ -3,6 +3,7 @@ use std::future::Future; use std::io; use std::net::SocketAddr; use std::pin::Pin; +use std::sync::Arc; use std::task::Poll; use std::task::{self}; use std::vec; @@ -19,6 +20,16 @@ pub enum Resolver { Gai(GaiResolver), /// hickory-resolver's userspace resolver. Hickory(hickory_resolver::Resolver), + /// A custom resolver that implements `Resolve`. + Custom(Arc), +} + +/// Alias for `Future` type returned by a custom DNS resolver. +pub type Resolving = + Pin> + Send>>; + +pub trait Resolve: Send + Sync + std::fmt::Debug { + fn resolve(&self, name: Name) -> Resolving; } impl Default for Resolver { @@ -107,7 +118,43 @@ impl Service for Resolver { Ok(iter) }) } + Resolver::Custom(resolver) => { + let resolver = resolver.clone(); + tokio::spawn(async move { resolver.resolve(name).await }) + } }; ResolveFut { inner: task } } } + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use super::*; + + // A resolver that resolves any name into the same address. + #[derive(Debug)] + struct DebugResolver(SocketAddr); + + impl Resolve for DebugResolver { + fn resolve(&self, _name: Name) -> Resolving { + let addr = self.0; + Box::pin(async move { Ok(vec![addr].into_iter()) }) + } + } + + #[tokio::test] + async fn custom_dns_resolver() { + let mut resolver = Resolver::Custom(Arc::new(DebugResolver( + "127.0.0.1:8080".parse().unwrap(), + ))); + let mut addr = resolver + .call(Name::from_str("foo.com").unwrap()) + .await + .unwrap(); + + let addr = addr.next().unwrap(); + assert_eq!(addr, "127.0.0.1:8080".parse().unwrap()); + } +}