From fafb2584efec33152fbe353d94151fa36004586a Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Sun, 23 Apr 2023 14:07:37 -0600 Subject: [PATCH] refactor(ext/websocket): Remove dep on tungstenite by reworking code (#18812) --- ext/websocket/lib.rs | 57 ++++++++++++++++++++++++----------------- ext/websocket/stream.rs | 15 +++++------ 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs index 71aa66ff38..943b5d47c7 100644 --- a/ext/websocket/lib.rs +++ b/ext/websocket/lib.rs @@ -38,11 +38,12 @@ use std::future::Future; use std::path::PathBuf; use std::rc::Rc; use std::sync::Arc; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; use tokio::net::TcpStream; use tokio_rustls::rustls::RootCertStore; use tokio_rustls::rustls::ServerName; use tokio_rustls::TlsConnector; -use tokio_tungstenite::MaybeTlsStream; use fastwebsockets::CloseCode; use fastwebsockets::FragmentCollector; @@ -129,6 +130,33 @@ pub struct CreateResponse { extensions: String, } +async fn handshake( + cancel_resource: Option>, + request: Request, + socket: S, +) -> Result<(WebSocket, http::Response), AnyError> { + let client = + fastwebsockets::handshake::client(&LocalExecutor, request, socket); + + let (upgraded, response) = if let Some(cancel_resource) = cancel_resource { + client.or_cancel(cancel_resource).await? + } else { + client.await + } + .map_err(|err| { + DomExceptionNetworkError::new(&format!( + "failed to connect to WebSocket: {err}" + )) + })?; + + let upgraded = upgraded.into_inner(); + let stream = + WebSocketStream::new(stream::WsStreamKind::Upgraded(upgraded), None); + let stream = WebSocket::after_handshake(stream, Role::Client); + + Ok((stream, response)) +} + #[op] pub async fn op_ws_create( state: Rc>, @@ -155,7 +183,7 @@ where .borrow_mut() .resource_table .get::(cancel_rid)?; - Some(r) + Some(r.0.clone()) } else { None }; @@ -223,8 +251,8 @@ where let addr = format!("{domain}:{port}"); let tcp_socket = TcpStream::connect(addr).await?; - let socket: MaybeTlsStream = match uri.scheme_str() { - Some("ws") => MaybeTlsStream::Plain(tcp_socket), + let (stream, response) = match uri.scheme_str() { + Some("ws") => handshake(cancel_resource, request, tcp_socket).await?, Some("wss") => { let tls_config = create_client_config( root_cert_store, @@ -236,30 +264,11 @@ where let dnsname = ServerName::try_from(domain.as_str()) .map_err(|_| invalid_hostname(domain))?; let tls_socket = tls_connector.connect(dnsname, tcp_socket).await?; - MaybeTlsStream::Rustls(tls_socket) + handshake(cancel_resource, request, tls_socket).await? } _ => unreachable!(), }; - let client = - fastwebsockets::handshake::client(&LocalExecutor, request, socket); - - let (upgraded, response) = if let Some(cancel_resource) = cancel_resource { - client.or_cancel(cancel_resource.0.to_owned()).await? - } else { - client.await - } - .map_err(|err| { - DomExceptionNetworkError::new(&format!( - "failed to connect to WebSocket: {err}" - )) - })?; - - let inner = MaybeTlsStream::Plain(upgraded.into_inner()); - let stream = - WebSocketStream::new(stream::WsStreamKind::Tungstenite(inner), None); - let stream = WebSocket::after_handshake(stream, Role::Client); - if let Some(cancel_rid) = cancel_handle { state.borrow_mut().resource_table.close(cancel_rid).ok(); } diff --git a/ext/websocket/stream.rs b/ext/websocket/stream.rs index 69c06b7eb7..6f93406f62 100644 --- a/ext/websocket/stream.rs +++ b/ext/websocket/stream.rs @@ -8,11 +8,10 @@ use std::task::Poll; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tokio::io::ReadBuf; -use tokio_tungstenite::MaybeTlsStream; // TODO(bartlomieju): remove this pub(crate) enum WsStreamKind { - Tungstenite(MaybeTlsStream), + Upgraded(Upgraded), Network(NetworkStream), } @@ -54,7 +53,7 @@ impl AsyncRead for WebSocketStream { } match &mut self.stream { WsStreamKind::Network(stream) => Pin::new(stream).poll_read(cx, buf), - WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_read(cx, buf), + WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_read(cx, buf), } } } @@ -67,7 +66,7 @@ impl AsyncWrite for WebSocketStream { ) -> std::task::Poll> { match &mut self.stream { WsStreamKind::Network(stream) => Pin::new(stream).poll_write(cx, buf), - WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_write(cx, buf), + WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_write(cx, buf), } } @@ -77,7 +76,7 @@ impl AsyncWrite for WebSocketStream { ) -> std::task::Poll> { match &mut self.stream { WsStreamKind::Network(stream) => Pin::new(stream).poll_flush(cx), - WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_flush(cx), + WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_flush(cx), } } @@ -87,14 +86,14 @@ impl AsyncWrite for WebSocketStream { ) -> std::task::Poll> { match &mut self.stream { WsStreamKind::Network(stream) => Pin::new(stream).poll_shutdown(cx), - WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_shutdown(cx), + WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_shutdown(cx), } } fn is_write_vectored(&self) -> bool { match &self.stream { WsStreamKind::Network(stream) => stream.is_write_vectored(), - WsStreamKind::Tungstenite(stream) => stream.is_write_vectored(), + WsStreamKind::Upgraded(stream) => stream.is_write_vectored(), } } @@ -107,7 +106,7 @@ impl AsyncWrite for WebSocketStream { WsStreamKind::Network(stream) => { Pin::new(stream).poll_write_vectored(cx, bufs) } - WsStreamKind::Tungstenite(stream) => { + WsStreamKind::Upgraded(stream) => { Pin::new(stream).poll_write_vectored(cx, bufs) } }