diff --git a/Cargo.lock b/Cargo.lock index 1176afee70..40d9057916 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1913,6 +1913,7 @@ dependencies = [ "rustls-tokio-stream", "rustls-webpki", "serde", + "tokio", "webpki-roots", ] @@ -5457,9 +5458,9 @@ dependencies = [ [[package]] name = "rustls-tokio-stream" -version = "0.2.17" +version = "0.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ded7a36e8ac05b8ada77a84c5ceec95361942ee9dedb60a82f93f788a791aae8" +checksum = "c478c030dfd68498e6c59168d9eec4f8bead33152a5f3095ad4bdbdcea09d466" dependencies = [ "futures", "rustls", diff --git a/Cargo.toml b/Cargo.toml index 81953da719..ba5be99bb0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -150,7 +150,7 @@ ring = "^0.17.0" rusqlite = { version = "=0.29.0", features = ["unlock_notify", "bundled"] } rustls = "0.21.11" rustls-pemfile = "1.0.0" -rustls-tokio-stream = "=0.2.17" +rustls-tokio-stream = "=0.2.23" rustls-webpki = "0.101.4" rustyline = "=13.0.0" saffron = "=0.1.0" diff --git a/ext/fetch/lib.rs b/ext/fetch/lib.rs index 3e43370d3b..21ca040277 100644 --- a/ext/fetch/lib.rs +++ b/ext/fetch/lib.rs @@ -46,6 +46,7 @@ use deno_tls::RootCertStoreProvider; use data_url::DataUrl; use deno_tls::TlsKey; use deno_tls::TlsKeys; +use deno_tls::TlsKeysHolder; use http_v02::header::CONTENT_LENGTH; use http_v02::Uri; use reqwest::header::HeaderMap; @@ -80,7 +81,7 @@ pub struct Options { pub request_builder_hook: Option Result>, pub unsafely_ignore_certificate_errors: Option>, - pub client_cert_chain_and_key: Option, + pub client_cert_chain_and_key: TlsKeys, pub file_fetch_handler: Rc, } @@ -101,7 +102,7 @@ impl Default for Options { proxy: None, request_builder_hook: None, unsafely_ignore_certificate_errors: None, - client_cert_chain_and_key: None, + client_cert_chain_and_key: TlsKeys::Null, file_fetch_handler: Rc::new(DefaultFileFetchHandler), } } @@ -205,7 +206,11 @@ pub fn create_client_from_options( unsafely_ignore_certificate_errors: options .unsafely_ignore_certificate_errors .clone(), - client_cert_chain_and_key: options.client_cert_chain_and_key.clone(), + client_cert_chain_and_key: options + .client_cert_chain_and_key + .clone() + .try_into() + .unwrap_or_default(), pool_max_idle_per_host: None, pool_idle_timeout: None, http1: true, @@ -821,7 +826,7 @@ fn default_true() -> bool { pub fn op_fetch_custom_client( state: &mut OpState, #[serde] args: CreateHttpClientArgs, - #[cppgc] tls_keys: &deno_tls::TlsKeys, + #[cppgc] tls_keys: &TlsKeysHolder, ) -> Result where FP: FetchPermissions + 'static, @@ -832,11 +837,6 @@ where permissions.check_net_url(&url, "Deno.createHttpClient()")?; } - let client_cert_chain_and_key = match tls_keys { - TlsKeys::Null => None, - TlsKeys::Static(key) => Some(key.clone()), - }; - let options = state.borrow::(); let ca_certs = args .ca_certs @@ -853,7 +853,7 @@ where unsafely_ignore_certificate_errors: options .unsafely_ignore_certificate_errors .clone(), - client_cert_chain_and_key, + client_cert_chain_and_key: tls_keys.take().try_into().unwrap(), pool_max_idle_per_host: args.pool_max_idle_per_host, pool_idle_timeout: args.pool_idle_timeout.and_then( |timeout| match timeout { @@ -915,7 +915,7 @@ pub fn create_http_client( options.root_cert_store, options.ca_certs, options.unsafely_ignore_certificate_errors, - options.client_cert_chain_and_key, + options.client_cert_chain_and_key.into(), deno_tls::SocketUse::Http, )?; diff --git a/ext/kv/remote.rs b/ext/kv/remote.rs index 88127fc8fa..9d5e099c73 100644 --- a/ext/kv/remote.rs +++ b/ext/kv/remote.rs @@ -16,7 +16,7 @@ use deno_fetch::CreateHttpClientOptions; use deno_tls::rustls::RootCertStore; use deno_tls::Proxy; use deno_tls::RootCertStoreProvider; -use deno_tls::TlsKey; +use deno_tls::TlsKeys; use denokv_remote::MetadataEndpoint; use denokv_remote::Remote; use url::Url; @@ -27,7 +27,7 @@ pub struct HttpOptions { pub root_cert_store_provider: Option>, pub proxy: Option, pub unsafely_ignore_certificate_errors: Option>, - pub client_cert_chain_and_key: Option, + pub client_cert_chain_and_key: TlsKeys, } impl HttpOptions { @@ -135,7 +135,11 @@ impl DatabaseHandler unsafely_ignore_certificate_errors: options .unsafely_ignore_certificate_errors .clone(), - client_cert_chain_and_key: options.client_cert_chain_and_key.clone(), + client_cert_chain_and_key: options + .client_cert_chain_and_key + .clone() + .try_into() + .unwrap(), pool_max_idle_per_host: None, pool_idle_timeout: None, http1: false, diff --git a/ext/net/02_tls.js b/ext/net/02_tls.js index 0b775047f6..e51df7424a 100644 --- a/ext/net/02_tls.js +++ b/ext/net/02_tls.js @@ -6,6 +6,10 @@ import { op_net_accept_tls, op_net_connect_tls, op_net_listen_tls, + op_tls_cert_resolver_create, + op_tls_cert_resolver_poll, + op_tls_cert_resolver_resolve, + op_tls_cert_resolver_resolve_error, op_tls_handshake, op_tls_key_null, op_tls_key_static, @@ -16,6 +20,7 @@ const { Number, ObjectDefineProperty, TypeError, + SymbolFor, } = primordials; import { Conn, Listener } from "ext:deno_net/01_net.js"; @@ -87,9 +92,12 @@ async function connectTls({ keyFile, privateKey, }); + // TODO(mmastrac): We only expose this feature via symbol for now. This should actually be a feature + // in Deno.connectTls, however. + const serverName = arguments[0][serverNameSymbol] ?? null; const { 0: rid, 1: localAddr, 2: remoteAddr } = await op_net_connect_tls( { hostname, port }, - { certFile: deprecatedCertFile, caCerts, alpnProtocols }, + { certFile: deprecatedCertFile, caCerts, alpnProtocols, serverName }, keyPair, ); localAddr.transport = "tcp"; @@ -133,6 +141,10 @@ class TlsListener extends Listener { * interfaces. */ function hasTlsKeyPairOptions(options) { + // TODO(mmastrac): remove this temporary symbol when the API lands + if (options[resolverSymbol] !== undefined) { + return true; + } return (options.cert !== undefined || options.key !== undefined || options.certFile !== undefined || options.keyFile !== undefined || options.privateKey !== undefined || @@ -159,6 +171,11 @@ function loadTlsKeyPair(api, { privateKey = undefined; } + // TODO(mmastrac): remove this temporary symbol when the API lands + if (arguments[1][resolverSymbol] !== undefined) { + return createTlsKeyResolver(arguments[1][resolverSymbol]); + } + // Check for "pem" format if (keyFormat !== undefined && keyFormat !== "pem") { throw new TypeError('If `keyFormat` is specified, it must be "pem"'); @@ -275,6 +292,37 @@ async function startTls( return new TlsConn(rid, remoteAddr, localAddr); } +const resolverSymbol = SymbolFor("unstableSniResolver"); +const serverNameSymbol = SymbolFor("unstableServerName"); + +function createTlsKeyResolver(callback) { + const { 0: resolver, 1: lookup } = op_tls_cert_resolver_create(); + (async () => { + while (true) { + const sni = await op_tls_cert_resolver_poll(lookup); + if (typeof sni !== "string") { + break; + } + try { + const key = await callback(sni); + if (!hasTlsKeyPairOptions(key)) { + op_tls_cert_resolver_resolve_error(lookup, sni, "Invalid key"); + } else { + const resolved = loadTlsKeyPair("Deno.listenTls", key); + op_tls_cert_resolver_resolve(lookup, sni, resolved); + } + } catch (e) { + op_tls_cert_resolver_resolve_error(lookup, sni, e.message); + } + } + })(); + return resolver; +} + +internals.resolverSymbol = resolverSymbol; +internals.serverNameSymbol = serverNameSymbol; +internals.createTlsKeyResolver = createTlsKeyResolver; + export { connectTls, hasTlsKeyPairOptions, diff --git a/ext/net/lib.rs b/ext/net/lib.rs index d137aa315a..fa8074b345 100644 --- a/ext/net/lib.rs +++ b/ext/net/lib.rs @@ -87,6 +87,10 @@ deno_core::extension!(deno_net, ops_tls::op_tls_key_null, ops_tls::op_tls_key_static, ops_tls::op_tls_key_static_from_file

, + ops_tls::op_tls_cert_resolver_create, + ops_tls::op_tls_cert_resolver_poll, + ops_tls::op_tls_cert_resolver_resolve, + ops_tls::op_tls_cert_resolver_resolve_error, ops_tls::op_tls_start

, ops_tls::op_net_connect_tls

, ops_tls::op_net_listen_tls

, diff --git a/ext/net/ops_tls.rs b/ext/net/ops_tls.rs index 487adf3bc7..c529859087 100644 --- a/ext/net/ops_tls.rs +++ b/ext/net/ops_tls.rs @@ -11,6 +11,7 @@ use crate::DefaultTlsOptions; use crate::NetPermissions; use crate::UnsafelyIgnoreCertificateErrors; use deno_core::anyhow::anyhow; +use deno_core::anyhow::bail; use deno_core::error::bad_resource; use deno_core::error::custom_error; use deno_core::error::generic_error; @@ -29,13 +30,18 @@ use deno_core::ResourceId; use deno_tls::create_client_config; use deno_tls::load_certs; use deno_tls::load_private_keys; +use deno_tls::new_resolver; use deno_tls::rustls::Certificate; +use deno_tls::rustls::ClientConnection; use deno_tls::rustls::PrivateKey; use deno_tls::rustls::ServerConfig; use deno_tls::rustls::ServerName; +use deno_tls::ServerConfigProvider; use deno_tls::SocketUse; use deno_tls::TlsKey; +use deno_tls::TlsKeyLookup; use deno_tls::TlsKeys; +use deno_tls::TlsKeysHolder; use rustls_tokio_stream::TlsStreamRead; use rustls_tokio_stream::TlsStreamWrite; use serde::Deserialize; @@ -63,14 +69,26 @@ pub(crate) const TLS_BUFFER_SIZE: Option = pub struct TlsListener { pub(crate) tcp_listener: TcpListener, - pub(crate) tls_config: Arc, + pub(crate) tls_config: Option>, + pub(crate) server_config_provider: Option, } impl TlsListener { pub async fn accept(&self) -> std::io::Result<(TlsStream, SocketAddr)> { let (tcp, addr) = self.tcp_listener.accept().await?; - let tls = - TlsStream::new_server_side(tcp, self.tls_config.clone(), TLS_BUFFER_SIZE); + let tls = if let Some(provider) = &self.server_config_provider { + TlsStream::new_server_side_acceptor( + tcp, + provider.clone(), + TLS_BUFFER_SIZE, + ) + } else { + TlsStream::new_server_side( + tcp, + self.tls_config.clone().unwrap(), + TLS_BUFFER_SIZE, + ) + }; Ok((tls, addr)) } pub fn local_addr(&self) -> std::io::Result { @@ -164,6 +182,7 @@ pub struct ConnectTlsArgs { cert_file: Option, ca_certs: Vec, alpn_protocols: Option>, + server_name: Option, } #[derive(Deserialize)] @@ -179,7 +198,10 @@ pub struct StartTlsArgs { pub fn op_tls_key_null<'s>( scope: &mut v8::HandleScope<'s>, ) -> Result, AnyError> { - Ok(deno_core::cppgc::make_cppgc_object(scope, TlsKeys::Null)) + Ok(deno_core::cppgc::make_cppgc_object( + scope, + TlsKeysHolder::from(TlsKeys::Null), + )) } #[op2] @@ -195,7 +217,7 @@ pub fn op_tls_key_static<'s>( .unwrap(); Ok(deno_core::cppgc::make_cppgc_object( scope, - TlsKeys::Static(TlsKey(cert, key)), + TlsKeysHolder::from(TlsKeys::Static(TlsKey(cert, key))), )) } @@ -224,10 +246,53 @@ where .unwrap(); Ok(deno_core::cppgc::make_cppgc_object( scope, - TlsKeys::Static(TlsKey(cert, key)), + TlsKeysHolder::from(TlsKeys::Static(TlsKey(cert, key))), )) } +#[op2] +pub fn op_tls_cert_resolver_create<'s>( + scope: &mut v8::HandleScope<'s>, +) -> v8::Local<'s, v8::Array> { + let (resolver, lookup) = new_resolver(); + let resolver = deno_core::cppgc::make_cppgc_object( + scope, + TlsKeysHolder::from(TlsKeys::Resolver(resolver)), + ); + let lookup = deno_core::cppgc::make_cppgc_object(scope, lookup); + v8::Array::new_with_elements(scope, &[resolver.into(), lookup.into()]) +} + +#[op2(async)] +#[string] +pub async fn op_tls_cert_resolver_poll( + #[cppgc] lookup: &TlsKeyLookup, +) -> Option { + lookup.poll().await +} + +#[op2(fast)] +pub fn op_tls_cert_resolver_resolve( + #[cppgc] lookup: &TlsKeyLookup, + #[string] sni: String, + #[cppgc] key: &TlsKeysHolder, +) -> Result<(), AnyError> { + let TlsKeys::Static(key) = key.take() else { + bail!("unexpected key type"); + }; + lookup.resolve(sni, Ok(key)); + Ok(()) +} + +#[op2(fast)] +pub fn op_tls_cert_resolver_resolve_error( + #[cppgc] lookup: &TlsKeyLookup, + #[string] sni: String, + #[string] error: String, +) { + lookup.resolve(sni, Err(anyhow!(error))) +} + #[op2] #[serde] pub fn op_tls_start( @@ -287,7 +352,7 @@ where root_cert_store, ca_certs, unsafely_ignore_certificate_errors, - None, + TlsKeys::Null, SocketUse::GeneralSsl, )?; @@ -299,8 +364,7 @@ where let tls_config = Arc::new(tls_config); let tls_stream = TlsStream::new_client_side( tcp_stream, - tls_config, - hostname_dns, + ClientConnection::new(tls_config, hostname_dns)?, TLS_BUFFER_SIZE, ); @@ -320,7 +384,7 @@ pub async fn op_net_connect_tls( state: Rc>, #[serde] addr: IpAddr, #[serde] args: ConnectTlsArgs, - #[cppgc] key_pair: &TlsKeys, + #[cppgc] key_pair: &TlsKeysHolder, ) -> Result<(ResourceId, IpAddr, IpAddr), AnyError> where NP: NetPermissions + 'static, @@ -357,8 +421,12 @@ where .borrow() .borrow::() .root_cert_store()?; - let hostname_dns = ServerName::try_from(&*addr.hostname) - .map_err(|_| invalid_hostname(&addr.hostname))?; + let hostname_dns = if let Some(server_name) = args.server_name { + ServerName::try_from(server_name.as_str()) + } else { + ServerName::try_from(&*addr.hostname) + } + .map_err(|_| invalid_hostname(&addr.hostname))?; let connect_addr = resolve_addr(&addr.hostname, addr.port) .await? .next() @@ -367,15 +435,11 @@ where let local_addr = tcp_stream.local_addr()?; let remote_addr = tcp_stream.peer_addr()?; - let cert_and_key = match key_pair { - TlsKeys::Null => None, - TlsKeys::Static(key) => Some(key.clone()), - }; let mut tls_config = create_client_config( root_cert_store, ca_certs, unsafely_ignore_certificate_errors, - cert_and_key, + key_pair.take(), SocketUse::GeneralSsl, )?; @@ -388,8 +452,7 @@ where let tls_stream = TlsStream::new_client_side( tcp_stream, - tls_config, - hostname_dns, + ClientConnection::new(tls_config, hostname_dns)?, TLS_BUFFER_SIZE, ); @@ -429,7 +492,7 @@ pub fn op_net_listen_tls( state: &mut OpState, #[serde] addr: IpAddr, #[serde] args: ListenTlsArgs, - #[cppgc] keys: &TlsKeys, + #[cppgc] keys: &TlsKeysHolder, ) -> Result<(ResourceId, IpAddr), AnyError> where NP: NetPermissions + 'static, @@ -444,36 +507,44 @@ where .check_net(&(&addr.hostname, Some(addr.port)), "Deno.listenTls()")?; } - let tls_config = ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth(); - - let mut tls_config = match keys { - TlsKeys::Null => Err(anyhow!("Deno.listenTls requires a key")), - TlsKeys::Static(TlsKey(cert, key)) => tls_config - .with_single_cert(cert.clone(), key.clone()) - .map_err(|e| anyhow!(e)), - } - .map_err(|e| { - custom_error("InvalidData", "Error creating TLS certificate").context(e) - })?; - - if let Some(alpn_protocols) = args.alpn_protocols { - tls_config.alpn_protocols = - alpn_protocols.into_iter().map(|s| s.into_bytes()).collect(); - } - let bind_addr = resolve_addr_sync(&addr.hostname, addr.port)? .next() .ok_or_else(|| generic_error("No resolved address found"))?; let tcp_listener = TcpListener::bind_direct(bind_addr, args.reuse_port)?; let local_addr = tcp_listener.local_addr()?; + let alpn = args + .alpn_protocols + .unwrap_or_default() + .into_iter() + .map(|s| s.into_bytes()) + .collect(); + let listener = match keys.take() { + TlsKeys::Null => Err(anyhow!("Deno.listenTls requires a key")), + TlsKeys::Static(TlsKey(cert, key)) => { + let mut tls_config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert, key) + .map_err(|e| anyhow!(e))?; + tls_config.alpn_protocols = alpn; + Ok(TlsListener { + tcp_listener, + tls_config: Some(tls_config.into()), + server_config_provider: None, + }) + } + TlsKeys::Resolver(resolver) => Ok(TlsListener { + tcp_listener, + tls_config: None, + server_config_provider: Some(resolver.into_server_config_provider(alpn)), + }), + } + .map_err(|e| { + custom_error("InvalidData", "Error creating TLS certificate").context(e) + })?; - let tls_listener_resource = NetworkListenerResource::new(TlsListener { - tcp_listener, - tls_config: tls_config.into(), - }); + let tls_listener_resource = NetworkListenerResource::new(listener); let rid = state.resource_table.add(tls_listener_resource); diff --git a/ext/tls/Cargo.toml b/ext/tls/Cargo.toml index 6f587f1010..b809b4ebe8 100644 --- a/ext/tls/Cargo.toml +++ b/ext/tls/Cargo.toml @@ -22,4 +22,5 @@ rustls-pemfile.workspace = true rustls-tokio-stream.workspace = true rustls-webpki.workspace = true serde.workspace = true +tokio.workspace = true webpki-roots.workspace = true diff --git a/ext/tls/lib.rs b/ext/tls/lib.rs index 7e68971e2e..5122264bf1 100644 --- a/ext/tls/lib.rs +++ b/ext/tls/lib.rs @@ -30,6 +30,9 @@ use std::io::Cursor; use std::sync::Arc; use std::time::SystemTime; +mod tls_key; +pub use tls_key::*; + pub type Certificate = rustls::Certificate; pub type PrivateKey = rustls::PrivateKey; pub type RootCertStore = rustls::RootCertStore; @@ -175,7 +178,7 @@ pub fn create_client_config( root_cert_store: Option, ca_certs: Vec>, unsafely_ignore_certificate_errors: Option>, - maybe_cert_chain_and_key: Option, + maybe_cert_chain_and_key: TlsKeys, socket_use: SocketUse, ) -> Result { if let Some(ic_allowlist) = unsafely_ignore_certificate_errors { @@ -189,14 +192,13 @@ pub fn create_client_config( // However it's not really feasible to deduplicate it as the `client_config` instances // are not type-compatible - one wants "client cert", the other wants "transparency policy // or client cert". - let mut client = - if let Some(TlsKey(cert_chain, private_key)) = maybe_cert_chain_and_key { - client_config - .with_client_auth_cert(cert_chain, private_key) - .expect("invalid client key or certificate") - } else { - client_config.with_no_client_auth() - }; + let mut client = match maybe_cert_chain_and_key { + TlsKeys::Static(TlsKey(cert_chain, private_key)) => client_config + .with_client_auth_cert(cert_chain, private_key) + .expect("invalid client key or certificate"), + TlsKeys::Null => client_config.with_no_client_auth(), + TlsKeys::Resolver(_) => unimplemented!(), + }; add_alpn(&mut client, socket_use); return Ok(client); @@ -226,14 +228,13 @@ pub fn create_client_config( root_cert_store }); - let mut client = - if let Some(TlsKey(cert_chain, private_key)) = maybe_cert_chain_and_key { - client_config - .with_client_auth_cert(cert_chain, private_key) - .expect("invalid client key or certificate") - } else { - client_config.with_no_client_auth() - }; + let mut client = match maybe_cert_chain_and_key { + TlsKeys::Static(TlsKey(cert_chain, private_key)) => client_config + .with_client_auth_cert(cert_chain, private_key) + .expect("invalid client key or certificate"), + TlsKeys::Null => client_config.with_no_client_auth(), + TlsKeys::Resolver(_) => unimplemented!(), + }; add_alpn(&mut client, socket_use); Ok(client) @@ -325,15 +326,3 @@ pub fn load_private_keys(bytes: &[u8]) -> Result, AnyError> { Ok(keys) } - -/// A loaded key. -// FUTURE(mmastrac): add resolver enum value to support dynamic SNI -pub enum TlsKeys { - // TODO(mmastrac): We need Option<&T> for cppgc -- this is a workaround - Null, - Static(TlsKey), -} - -/// A TLS certificate/private key pair. -#[derive(Clone, Debug)] -pub struct TlsKey(pub Vec, pub PrivateKey); diff --git a/ext/tls/tls_key.rs b/ext/tls/tls_key.rs new file mode 100644 index 0000000000..18064a91a0 --- /dev/null +++ b/ext/tls/tls_key.rs @@ -0,0 +1,321 @@ +// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license. + +//! These represent the various types of TLS keys we support for both client and server +//! connections. +//! +//! A TLS key will most often be static, and will loaded from a certificate and key file +//! or string. These are represented by `TlsKey`, which is stored in `TlsKeys::Static`. +//! +//! In more complex cases, you may need a `TlsKeyResolver`/`TlsKeyLookup` pair, which +//! requires polling of the `TlsKeyLookup` lookup queue. The underlying channels that used for +//! key lookup can handle closing one end of the pair, in which case they will just +//! attempt to clean up the associated resources. + +use crate::Certificate; +use crate::PrivateKey; +use deno_core::anyhow::anyhow; +use deno_core::error::AnyError; +use deno_core::futures::future::poll_fn; +use deno_core::futures::future::Either; +use deno_core::futures::FutureExt; +use deno_core::unsync::spawn; +use rustls::ServerConfig; +use rustls_tokio_stream::ServerConfigProvider; +use std::cell::RefCell; +use std::collections::HashMap; +use std::fmt::Debug; +use std::future::ready; +use std::future::Future; +use std::io::ErrorKind; +use std::rc::Rc; +use std::sync::Arc; +use tokio::sync::broadcast; +use tokio::sync::mpsc; +use tokio::sync::oneshot; + +type ErrorType = Rc; + +/// A TLS certificate/private key pair. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct TlsKey(pub Vec, pub PrivateKey); + +#[derive(Clone, Debug, Default)] +pub enum TlsKeys { + // TODO(mmastrac): We need Option<&T> for cppgc -- this is a workaround + #[default] + Null, + Static(TlsKey), + Resolver(TlsKeyResolver), +} + +pub struct TlsKeysHolder(RefCell); + +impl TlsKeysHolder { + pub fn take(&self) -> TlsKeys { + std::mem::take(&mut *self.0.borrow_mut()) + } +} + +impl From for TlsKeysHolder { + fn from(value: TlsKeys) -> Self { + TlsKeysHolder(RefCell::new(value)) + } +} + +impl TryInto> for TlsKeys { + type Error = Self; + fn try_into(self) -> Result, Self::Error> { + match self { + Self::Null => Ok(None), + Self::Static(key) => Ok(Some(key)), + Self::Resolver(_) => Err(self), + } + } +} + +impl From> for TlsKeys { + fn from(value: Option) -> Self { + match value { + None => TlsKeys::Null, + Some(key) => TlsKeys::Static(key), + } + } +} + +enum TlsKeyState { + Resolving(broadcast::Receiver>), + Resolved(Result), +} + +struct TlsKeyResolverInner { + resolution_tx: mpsc::UnboundedSender<( + String, + broadcast::Sender>, + )>, + cache: RefCell>, +} + +#[derive(Clone)] +pub struct TlsKeyResolver { + inner: Rc, +} + +impl TlsKeyResolver { + async fn resolve_internal( + &self, + sni: String, + alpn: Vec>, + ) -> Result, AnyError> { + let key = self.resolve(sni).await?; + + let mut tls_config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(key.0, key.1)?; + tls_config.alpn_protocols = alpn; + Ok(tls_config.into()) + } + + pub fn into_server_config_provider( + self, + alpn: Vec>, + ) -> ServerConfigProvider { + let (tx, mut rx) = mpsc::unbounded_channel::<(_, oneshot::Sender<_>)>(); + + // We don't want to make the resolver multi-threaded, but the `ServerConfigProvider` is + // required to be wrapped in an Arc. To fix this, we spawn a task in our current runtime + // to respond to the requests. + spawn(async move { + while let Some((sni, txr)) = rx.recv().await { + _ = txr.send(self.resolve_internal(sni, alpn.clone()).await); + } + }); + + Arc::new(move |hello| { + // Take ownership of the SNI information + let sni = hello.server_name().unwrap_or_default().to_owned(); + let (txr, rxr) = tokio::sync::oneshot::channel::<_>(); + _ = tx.send((sni, txr)); + rxr + .map(|res| match res { + Err(e) => Err(std::io::Error::new(ErrorKind::InvalidData, e)), + Ok(Err(e)) => Err(std::io::Error::new(ErrorKind::InvalidData, e)), + Ok(Ok(res)) => Ok(res), + }) + .boxed() + }) + } +} + +impl Debug for TlsKeyResolver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TlsKeyResolver").finish() + } +} + +pub fn new_resolver() -> (TlsKeyResolver, TlsKeyLookup) { + let (resolution_tx, resolution_rx) = mpsc::unbounded_channel(); + ( + TlsKeyResolver { + inner: Rc::new(TlsKeyResolverInner { + resolution_tx, + cache: Default::default(), + }), + }, + TlsKeyLookup { + resolution_rx: RefCell::new(resolution_rx), + pending: Default::default(), + }, + ) +} + +impl TlsKeyResolver { + /// Resolve the certificate and key for a given host. This immediately spawns a task in the + /// background and is therefore cancellation-safe. + pub fn resolve( + &self, + sni: String, + ) -> impl Future> { + let mut cache = self.inner.cache.borrow_mut(); + let mut recv = match cache.get(&sni) { + None => { + let (tx, rx) = broadcast::channel(1); + cache.insert(sni.clone(), TlsKeyState::Resolving(rx.resubscribe())); + _ = self.inner.resolution_tx.send((sni.clone(), tx)); + rx + } + Some(TlsKeyState::Resolving(recv)) => recv.resubscribe(), + Some(TlsKeyState::Resolved(res)) => { + return Either::Left(ready(res.clone().map_err(|_| anyhow!("Failed")))); + } + }; + drop(cache); + + // Make this cancellation safe + let inner = self.inner.clone(); + let handle = spawn(async move { + let res = recv.recv().await?; + let mut cache = inner.cache.borrow_mut(); + match cache.get(&sni) { + None | Some(TlsKeyState::Resolving(..)) => { + cache.insert(sni, TlsKeyState::Resolved(res.clone())); + } + Some(TlsKeyState::Resolved(..)) => { + // Someone beat us to it + } + } + res.map_err(|_| anyhow!("Failed")) + }); + Either::Right(async move { handle.await? }) + } +} + +pub struct TlsKeyLookup { + #[allow(clippy::type_complexity)] + resolution_rx: RefCell< + mpsc::UnboundedReceiver<( + String, + broadcast::Sender>, + )>, + >, + pending: + RefCell>>>, +} + +impl TlsKeyLookup { + /// Multiple `poll` calls are safe, but this method is not starvation-safe. Generally + /// only one `poll`er should be active at any time. + pub async fn poll(&self) -> Option { + if let Some((sni, sender)) = + poll_fn(|cx| self.resolution_rx.borrow_mut().poll_recv(cx)).await + { + self.pending.borrow_mut().insert(sni.clone(), sender); + Some(sni) + } else { + None + } + } + + /// Resolve a previously polled item. + pub fn resolve(&self, sni: String, res: Result) { + _ = self + .pending + .borrow_mut() + .remove(&sni) + .unwrap() + .send(res.map_err(Rc::new)); + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + use deno_core::unsync::spawn; + use rustls::Certificate; + use rustls::PrivateKey; + + fn tls_key_for_test(sni: &str) -> TlsKey { + TlsKey( + vec![Certificate(format!("{sni}-cert").into_bytes())], + PrivateKey(format!("{sni}-key").into_bytes()), + ) + } + + #[tokio::test] + async fn test_resolve_once() { + let (resolver, lookup) = new_resolver(); + let task = spawn(async move { + while let Some(sni) = lookup.poll().await { + lookup.resolve(sni.clone(), Ok(tls_key_for_test(&sni))); + } + }); + + let key = resolver.resolve("example.com".to_owned()).await.unwrap(); + assert_eq!(tls_key_for_test("example.com"), key); + drop(resolver); + + task.await.unwrap(); + } + + #[tokio::test] + async fn test_resolve_concurrent() { + let (resolver, lookup) = new_resolver(); + let task = spawn(async move { + while let Some(sni) = lookup.poll().await { + lookup.resolve(sni.clone(), Ok(tls_key_for_test(&sni))); + } + }); + + let f1 = resolver.resolve("example.com".to_owned()); + let f2 = resolver.resolve("example.com".to_owned()); + + let key = f1.await.unwrap(); + assert_eq!(tls_key_for_test("example.com"), key); + let key = f2.await.unwrap(); + assert_eq!(tls_key_for_test("example.com"), key); + drop(resolver); + + task.await.unwrap(); + } + + #[tokio::test] + async fn test_resolve_multiple_concurrent() { + let (resolver, lookup) = new_resolver(); + let task = spawn(async move { + while let Some(sni) = lookup.poll().await { + lookup.resolve(sni.clone(), Ok(tls_key_for_test(&sni))); + } + }); + + let f1 = resolver.resolve("example1.com".to_owned()); + let f2 = resolver.resolve("example2.com".to_owned()); + + let key = f1.await.unwrap(); + assert_eq!(tls_key_for_test("example1.com"), key); + let key = f2.await.unwrap(); + assert_eq!(tls_key_for_test("example2.com"), key); + drop(resolver); + + task.await.unwrap(); + } +} diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs index e4df9d3d35..06a75faabd 100644 --- a/ext/websocket/lib.rs +++ b/ext/websocket/lib.rs @@ -23,8 +23,10 @@ use deno_core::ToJsBuffer; use deno_net::raw::NetworkStream; use deno_tls::create_client_config; use deno_tls::rustls::ClientConfig; +use deno_tls::rustls::ClientConnection; use deno_tls::RootCertStoreProvider; use deno_tls::SocketUse; +use deno_tls::TlsKeys; use http::header::CONNECTION; use http::header::UPGRADE; use http::HeaderName; @@ -236,8 +238,7 @@ async fn handshake_http1_wss( ServerName::try_from(domain).map_err(|_| invalid_hostname(domain))?; let mut tls_connector = TlsStream::new_client_side( tcp_socket, - tls_config.into(), - dnsname, + ClientConnection::new(tls_config.into(), dnsname)?, NonZeroUsize::new(65536), ); // If we can bail on an http/1.1 ALPN mismatch here, we can avoid doing extra work @@ -261,8 +262,11 @@ async fn handshake_http2_wss( let dnsname = ServerName::try_from(domain).map_err(|_| invalid_hostname(domain))?; // We need to better expose the underlying errors here - let mut tls_connector = - TlsStream::new_client_side(tcp_socket, tls_config.into(), dnsname, None); + let mut tls_connector = TlsStream::new_client_side( + tcp_socket, + ClientConnection::new(tls_config.into(), dnsname)?, + None, + ); let handshake = tls_connector.handshake().await?; if handshake.alpn.is_none() { bail!("Didn't receive h2 alpn, aborting connection"); @@ -332,7 +336,7 @@ pub fn create_ws_client_config( root_cert_store, vec![], unsafely_ignore_certificate_errors, - None, + TlsKeys::Null, socket_use, ) } diff --git a/runtime/web_worker.rs b/runtime/web_worker.rs index 0124b12a34..8360356940 100644 --- a/runtime/web_worker.rs +++ b/runtime/web_worker.rs @@ -47,6 +47,7 @@ use deno_io::Stdio; use deno_kv::dynamic::MultiBackendDbHandler; use deno_terminal::colors; use deno_tls::RootCertStoreProvider; +use deno_tls::TlsKeys; use deno_web::create_entangled_message_port; use deno_web::serialize_transferables; use deno_web::BlobStore; @@ -477,7 +478,7 @@ impl WebWorker { unsafely_ignore_certificate_errors: options .unsafely_ignore_certificate_errors .clone(), - client_cert_chain_and_key: None, + client_cert_chain_and_key: TlsKeys::Null, proxy: None, }, ), diff --git a/runtime/worker.rs b/runtime/worker.rs index a5fec16e47..1c291c6413 100644 --- a/runtime/worker.rs +++ b/runtime/worker.rs @@ -39,6 +39,7 @@ use deno_http::DefaultHttpPropertyExtractor; use deno_io::Stdio; use deno_kv::dynamic::MultiBackendDbHandler; use deno_tls::RootCertStoreProvider; +use deno_tls::TlsKeys; use deno_web::BlobStore; use log::debug; @@ -450,7 +451,7 @@ impl MainWorker { unsafely_ignore_certificate_errors: options .unsafely_ignore_certificate_errors .clone(), - client_cert_chain_and_key: None, + client_cert_chain_and_key: TlsKeys::Null, proxy: None, }, ), diff --git a/tests/integration/js_unit_tests.rs b/tests/integration/js_unit_tests.rs index 2bf78034e9..cbae4a0b8c 100644 --- a/tests/integration/js_unit_tests.rs +++ b/tests/integration/js_unit_tests.rs @@ -94,6 +94,7 @@ util::unit_test_factory!( text_encoding_test, timers_test, tls_test, + tls_sni_test, truncate_test, tty_color_test, tty_test, @@ -129,7 +130,7 @@ fn js_unit_test(test: String) { .arg("--no-prompt"); // TODO(mmastrac): it would be better to just load a test CA for all tests - let deno = if test == "websocket_test" { + let deno = if test == "websocket_test" || test == "tls_sni_test" { deno.arg("--unsafely-ignore-certificate-errors") } else { deno diff --git a/tests/integration/run_tests.rs b/tests/integration/run_tests.rs index 88ddfb3185..8a24603b32 100644 --- a/tests/integration/run_tests.rs +++ b/tests/integration/run_tests.rs @@ -13,6 +13,7 @@ use deno_core::serde_json::json; use deno_core::url; use deno_fetch::reqwest; use deno_tls::rustls; +use deno_tls::rustls::ClientConnection; use deno_tls::rustls_pemfile; use deno_tls::TlsStream; use pretty_assertions::assert_eq; @@ -5388,8 +5389,11 @@ async fn listen_tls_alpn() { let tcp_stream = tokio::net::TcpStream::connect("localhost:4504") .await .unwrap(); - let mut tls_stream = - TlsStream::new_client_side(tcp_stream, cfg, hostname, None); + let mut tls_stream = TlsStream::new_client_side( + tcp_stream, + ClientConnection::new(cfg, hostname).unwrap(), + None, + ); let handshake = tls_stream.handshake().await.unwrap(); @@ -5437,8 +5441,11 @@ async fn listen_tls_alpn_fail() { let tcp_stream = tokio::net::TcpStream::connect("localhost:4505") .await .unwrap(); - let mut tls_stream = - TlsStream::new_client_side(tcp_stream, cfg, hostname, None); + let mut tls_stream = TlsStream::new_client_side( + tcp_stream, + ClientConnection::new(cfg, hostname).unwrap(), + None, + ); tls_stream.handshake().await.unwrap_err(); diff --git a/tests/unit/tls_sni_test.ts b/tests/unit/tls_sni_test.ts new file mode 100644 index 0000000000..404f8016e3 --- /dev/null +++ b/tests/unit/tls_sni_test.ts @@ -0,0 +1,60 @@ +// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license. +import { assertEquals, assertRejects } from "./test_util.ts"; +// @ts-expect-error TypeScript (as of 3.7) does not support indexing namespaces by symbol +const { resolverSymbol, serverNameSymbol } = Deno[Deno.internal]; + +const cert = Deno.readTextFileSync("tests/testdata/tls/localhost.crt"); +const key = Deno.readTextFileSync("tests/testdata/tls/localhost.key"); +const certEcc = Deno.readTextFileSync("tests/testdata/tls/localhost_ecc.crt"); +const keyEcc = Deno.readTextFileSync("tests/testdata/tls/localhost_ecc.key"); + +Deno.test( + { permissions: { net: true, read: true } }, + async function listenResolver() { + const sniRequests: string[] = []; + const keys: Record = { + "server-1": { cert, key }, + "server-2": { cert: certEcc, key: keyEcc }, + "fail-server-3": { cert: "(invalid)", key: "(bad)" }, + }; + const opts: unknown = { + hostname: "localhost", + port: 0, + [resolverSymbol]: (sni: string) => { + sniRequests.push(sni); + return keys[sni]!; + }, + }; + const listener = Deno.listenTls( + opts, + ); + + for ( + const server of ["server-1", "server-2", "fail-server-3", "fail-server-4"] + ) { + const conn = await Deno.connectTls({ + hostname: "localhost", + [serverNameSymbol]: server, + port: listener.addr.port, + }); + const serverConn = await listener.accept(); + if (server.startsWith("fail-")) { + await assertRejects(async () => await conn.handshake()); + await assertRejects(async () => await serverConn.handshake()); + } else { + await conn.handshake(); + await serverConn.handshake(); + } + conn.close(); + serverConn.close(); + } + + assertEquals(sniRequests, [ + "server-1", + "server-2", + "fail-server-3", + "fail-server-4", + ]); + listener.close(); + }, +);