diff --git a/Cargo.lock b/Cargo.lock index 295fd3e13b..733b18459f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1779,9 +1779,9 @@ dependencies = [ "http", "hyper 0.14.27", "once_cell", + "rustls-tokio-stream", "serde", "tokio", - "tokio-rustls", ] [[package]] @@ -4567,6 +4567,17 @@ dependencies = [ "base64 0.21.4", ] +[[package]] +name = "rustls-tokio-stream" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8101c6e909600a3648a7774cb06837a5f976eb3265736d7135b4c177fa3020b9" +dependencies = [ + "futures", + "rustls", + "tokio", +] + [[package]] name = "rustls-webpki" version = "0.101.7" diff --git a/Cargo.toml b/Cargo.toml index 1f136401e5..61a8dc68f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -128,6 +128,7 @@ ring = "^0.17.0" rusqlite = { version = "=0.29.0", features = ["unlock_notify", "bundled"] } rustls = "0.21.8" rustls-pemfile = "1.0.0" +rustls-tokio-stream = "0.2.4" rustls-webpki = "0.101.4" rustls-native-certs = "0.6.2" webpki-roots = "0.25.2" diff --git a/cli/tests/testdata/run/websocket_test.ts b/cli/tests/testdata/run/websocket_test.ts index d80f03c92a..43db15ccee 100644 --- a/cli/tests/testdata/run/websocket_test.ts +++ b/cli/tests/testdata/run/websocket_test.ts @@ -163,7 +163,7 @@ Deno.test("websocket error", async () => { // Error message got changed because we don't use warp in test_util assertEquals( err.message, - "InvalidData: received corrupt message of type InvalidContentType", + "InvalidData: invalid data", ); promise1.resolve(); }; diff --git a/cli/tests/unit/websocket_test.ts b/cli/tests/unit/websocket_test.ts index b761cd1183..8ae729d42b 100644 --- a/cli/tests/unit/websocket_test.ts +++ b/cli/tests/unit/websocket_test.ts @@ -277,3 +277,51 @@ Deno.test( } }, ); + +Deno.test(async function websocketTlsSocketWorks() { + const cert = await Deno.readTextFile("cli/tests/testdata/tls/localhost.crt"); + const key = await Deno.readTextFile("cli/tests/testdata/tls/localhost.key"); + + const messages: string[] = [], + errors: { server?: Event; client?: Event }[] = []; + const promise = new Promise((okay, nope) => { + const ac = new AbortController(); + const server = Deno.serve({ + handler: (req) => { + const { response, socket } = Deno.upgradeWebSocket(req); + socket.onopen = () => socket.send("ping"); + socket.onmessage = (e) => { + messages.push(e.data); + socket.close(); + }; + socket.onerror = (e) => errors.push({ server: e }); + socket.onclose = () => ac.abort(); + return response; + }, + signal: ac.signal, + hostname: "localhost", + port: servePort, + cert, + key, + }); + setTimeout(() => { + const ws = new WebSocket(`wss://localhost:${servePort}`); + ws.onmessage = (e) => { + messages.push(e.data); + ws.send("pong"); + }; + ws.onerror = (e) => { + errors.push({ client: e }); + nope(); + }; + ws.onclose = () => okay(server.finished); + }, 1000); + }); + + const finished = await promise; + + assertEquals(errors, []); + assertEquals(messages, ["ping", "pong"]); + + await finished; +}); diff --git a/ext/websocket/Cargo.toml b/ext/websocket/Cargo.toml index 7dd7a9afee..da29203c49 100644 --- a/ext/websocket/Cargo.toml +++ b/ext/websocket/Cargo.toml @@ -22,6 +22,6 @@ fastwebsockets = { workspace = true, features = ["upgrade", "unstable-split"] } http.workspace = true hyper = { workspace = true, features = ["backports"] } once_cell.workspace = true +rustls-tokio-stream.workspace = true serde.workspace = true tokio.workspace = true -tokio-rustls.workspace = true diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs index 0f3456eef2..ac40b8304c 100644 --- a/ext/websocket/lib.rs +++ b/ext/websocket/lib.rs @@ -29,6 +29,9 @@ use http::Request; use http::Uri; use hyper::Body; use once_cell::sync::Lazy; +use rustls_tokio_stream::rustls::RootCertStore; +use rustls_tokio_stream::rustls::ServerName; +use rustls_tokio_stream::TlsStream; use serde::Serialize; use std::borrow::Cow; use std::cell::Cell; @@ -36,6 +39,7 @@ use std::cell::RefCell; use std::convert::TryFrom; use std::fmt; use std::future::Future; +use std::num::NonZeroUsize; use std::path::PathBuf; use std::rc::Rc; use std::sync::Arc; @@ -44,9 +48,6 @@ use tokio::io::AsyncWrite; use tokio::io::ReadHalf; use tokio::io::WriteHalf; use tokio::net::TcpStream; -use tokio_rustls::rustls::RootCertStore; -use tokio_rustls::rustls::ServerName; -use tokio_rustls::TlsConnector; use fastwebsockets::CloseCode; use fastwebsockets::FragmentCollectorRead; @@ -284,11 +285,16 @@ where unsafely_ignore_certificate_errors, None, )?; - let tls_connector = TlsConnector::from(Arc::new(tls_config)); let dnsname = ServerName::try_from(domain.as_str()) .map_err(|_| invalid_hostname(domain))?; - let tls_socket = tls_connector.connect(dnsname, tcp_socket).await?; - handshake(cancel_resource, request, tls_socket).await? + let mut tls_connector = TlsStream::new_client_side( + tcp_socket, + tls_config.into(), + dnsname, + NonZeroUsize::new(65536), + ); + let _hs = tls_connector.handshake().await?; + handshake(cancel_resource, request, tls_connector).await? } _ => unreachable!(), };