diff --git a/Cargo.lock b/Cargo.lock index 3ff195e977..10d7f335c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1028,8 +1028,11 @@ dependencies = [ "deno_websocket", "flate2", "fly-accept-encoding", + "httparse", "hyper", + "memmem", "mime", + "once_cell", "percent-encoding", "phf", "pin-project", @@ -2670,6 +2673,12 @@ dependencies = [ "libc", ] +[[package]] +name = "memmem" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a64a92489e2744ce060c349162be1c5f33c6969234104dbd99ddb5feb08b8c15" + [[package]] name = "memoffset" version = "0.6.5" diff --git a/Cargo.toml b/Cargo.toml index e05b2193ae..88fc8d2ce9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -96,11 +96,13 @@ fs3 = "0.5.0" futures = "0.3.21" hex = "0.4" http = "0.2.9" +httparse = "1.8.0" hyper = "0.14.18" indexmap = { version = "1.9.2", features = ["serde"] } libc = "0.2.126" log = "=0.4.17" lsp-types = "=0.93.2" # used by tower-lsp and "proposed" feature is unstable in patch releases +memmem = "0.1.1" notify = "=5.0.0" num-bigint = { version = "0.4", features = ["rand"] } once_cell = "1.17.1" diff --git a/ext/http/01_http.js b/ext/http/01_http.js index 5b78bb2f2f..7224df3c5c 100644 --- a/ext/http/01_http.js +++ b/ext/http/01_http.js @@ -17,6 +17,7 @@ import { _flash, fromInnerRequest, newInnerRequest, + toInnerRequest, } from "ext:deno_fetch/23_request.js"; import { AbortController } from "ext:deno_web/03_abort_signal.js"; import { @@ -61,6 +62,7 @@ const { } = primordials; const connErrorSymbol = Symbol("connError"); +const streamRid = Symbol("streamRid"); const _deferred = Symbol("upgradeHttpDeferred"); class HttpConn { @@ -135,6 +137,7 @@ class HttpConn { body !== null ? 
new InnerBody(body) : null, false, ); + innerRequest[streamRid] = streamRid; const abortController = new AbortController(); const request = fromInnerRequest( innerRequest, @@ -471,6 +474,12 @@ function upgradeHttp(req) { return req[_deferred].promise; } +async function upgradeHttpRaw(req, tcpConn) { + const inner = toInnerRequest(req); + const res = await core.opAsync("op_http_upgrade_early", inner[streamRid]); + return new TcpConn(res, tcpConn.remoteAddr, tcpConn.localAddr); +} + const spaceCharCode = StringPrototypeCharCodeAt(" ", 0); const tabCharCode = StringPrototypeCharCodeAt("\t", 0); const commaCharCode = StringPrototypeCharCodeAt(",", 0); @@ -545,4 +554,4 @@ function buildCaseInsensitiveCommaValueFinder(checkText) { internals.buildCaseInsensitiveCommaValueFinder = buildCaseInsensitiveCommaValueFinder; -export { _ws, HttpConn, upgradeHttp, upgradeWebSocket }; +export { _ws, HttpConn, upgradeHttp, upgradeHttpRaw, upgradeWebSocket }; diff --git a/ext/http/Cargo.toml b/ext/http/Cargo.toml index 92dde5ee42..605ffa1264 100644 --- a/ext/http/Cargo.toml +++ b/ext/http/Cargo.toml @@ -27,8 +27,11 @@ deno_core.workspace = true deno_websocket.workspace = true flate2.workspace = true fly-accept-encoding = "0.2.0" +httparse.workspace = true hyper = { workspace = true, features = ["server", "stream", "http1", "http2", "runtime"] } +memmem.workspace = true mime = "0.3.16" +once_cell.workspace = true percent-encoding.workspace = true phf = { version = "0.10", features = ["macros"] } pin-project.workspace = true diff --git a/ext/http/lib.rs b/ext/http/lib.rs index 20436a82d3..289e7bf0f9 100644 --- a/ext/http/lib.rs +++ b/ext/http/lib.rs @@ -32,6 +32,7 @@ use deno_core::RcRef; use deno_core::Resource; use deno_core::ResourceId; use deno_core::StringOrBuffer; +use deno_core::WriteOutcome; use deno_core::ZeroCopyBuf; use deno_websocket::ws_create_server_stream; use flate2::write::GzEncoder; @@ -65,15 +66,18 @@ use std::sync::Arc; use std::task::Context; use std::task::Poll; 
use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; use tokio::io::AsyncWrite; use tokio::io::AsyncWriteExt; use tokio::task::spawn_local; +use websocket_upgrade::WebSocketUpgrade; use crate::reader_stream::ExternallyAbortableReaderStream; use crate::reader_stream::ShutdownHandle; pub mod compressible; mod reader_stream; +mod websocket_upgrade; deno_core::extension!( deno_http, @@ -86,6 +90,7 @@ deno_core::extension!( op_http_write_resource, op_http_shutdown, op_http_websocket_accept_header, + op_http_upgrade_early, op_http_upgrade_websocket, ], esm = ["01_http.js"], @@ -938,6 +943,192 @@ fn op_http_websocket_accept_header(key: String) -> Result<String, AnyError> { Ok(base64::encode(digest)) } +struct EarlyUpgradeSocket(AsyncRefCell<EarlyUpgradeSocketInner>, CancelHandle); + +enum EarlyUpgradeSocketInner { + PreResponse( + Rc<HttpStreamResource>, + WebSocketUpgrade, + // Readers need to block in this state, so they can wait here for the broadcast. + tokio::sync::broadcast::Sender< + Rc<AsyncRefCell<tokio::io::ReadHalf<hyper::upgrade::Upgraded>>>, + >, + ), + PostResponse( + Rc<AsyncRefCell<tokio::io::ReadHalf<hyper::upgrade::Upgraded>>>, + Rc<AsyncRefCell<tokio::io::WriteHalf<hyper::upgrade::Upgraded>>>, + ), +} + +impl EarlyUpgradeSocket { + /// Gets a reader without holding the lock. + async fn get_reader( + self: Rc<Self>, + ) -> Result< + Rc<AsyncRefCell<tokio::io::ReadHalf<hyper::upgrade::Upgraded>>>, + AnyError, + > { + let mut borrow = RcRef::map(self.clone(), |x| &x.0).borrow_mut().await; + let cancel = RcRef::map(self, |x| &x.1); + let inner = &mut *borrow; + match inner { + EarlyUpgradeSocketInner::PreResponse(_, _, tx) => { + let mut rx = tx.subscribe(); + // Ensure we're not borrowing self here + drop(borrow); + Ok( + rx.recv() + .map_err(AnyError::from) + .try_or_cancel(&cancel) + .await?, + ) + } + EarlyUpgradeSocketInner::PostResponse(rx, _) => Ok(rx.clone()), + } + } + + async fn read(self: Rc<Self>, data: &mut [u8]) -> Result<usize, AnyError> { + let reader = self.clone().get_reader().await?; + let cancel = RcRef::map(self, |x| &x.1); + Ok( + reader + .borrow_mut() + .await + .read(data) + .try_or_cancel(&cancel) + .await?, + ) + } + + /// Write all the data provided, only holding the lock while we see if the connection needs to be + /// upgraded.
+ async fn write_all(self: Rc<Self>, buf: &[u8]) -> Result<(), AnyError> { + let mut borrow = RcRef::map(self.clone(), |x| &x.0).borrow_mut().await; + let cancel = RcRef::map(self, |x| &x.1); + let inner = &mut *borrow; + match inner { + EarlyUpgradeSocketInner::PreResponse(stream, upgrade, rx_tx) => { + if let Some((resp, extra)) = upgrade.write(buf)? { + let new_wr = HttpResponseWriter::Closed; + let mut old_wr = + RcRef::map(stream.clone(), |r| &r.wr).borrow_mut().await; + let response_tx = match replace(&mut *old_wr, new_wr) { + HttpResponseWriter::Headers(response_tx) => response_tx, + _ => return Err(http_error("response headers already sent")), + }; + + if response_tx.send(resp).is_err() { + stream.conn.closed().await?; + return Err(http_error("connection closed while sending response")); + }; + + let mut old_rd = + RcRef::map(stream.clone(), |r| &r.rd).borrow_mut().await; + let new_rd = HttpRequestReader::Closed; + let upgraded = match replace(&mut *old_rd, new_rd) { + HttpRequestReader::Headers(request) => { + hyper::upgrade::on(request) + .map_err(AnyError::from) + .try_or_cancel(&cancel) + .await? + } + _ => { + return Err(http_error("response already started")); + } + }; + + let (rx, tx) = tokio::io::split(upgraded); + let rx = Rc::new(AsyncRefCell::new(rx)); + let tx = Rc::new(AsyncRefCell::new(tx)); + + // Take the tx and rx lock before we allow anything else to happen because we want to control + // the order of reads and writes. + let mut tx_lock = tx.clone().borrow_mut().await; + let rx_lock = rx.clone().borrow_mut().await; + + // Allow all the pending readers to go now. We still have the lock on inner, so no more + // pending readers can show up. We intentionally ignore errors here, as there may be + // nobody waiting on a read. + _ = rx_tx.send(rx.clone()); + + // We swap out inner here, so once the lock is gone, readers will acquire rx directly. + // We also fully release our lock.
+ *inner = EarlyUpgradeSocketInner::PostResponse(rx, tx); + drop(borrow); + + // We've updated inner and unlocked it, reads are free to go in-order. + drop(rx_lock); + + // If we had extra data after the response, write that to the upgraded connection + if !extra.is_empty() { + tx_lock.write_all(&extra).try_or_cancel(&cancel).await?; + } + } + } + EarlyUpgradeSocketInner::PostResponse(_, tx) => { + let tx = tx.clone(); + drop(borrow); + tx.borrow_mut() + .await + .write_all(buf) + .try_or_cancel(&cancel) + .await?; + } + }; + Ok(()) + } +} + +impl Resource for EarlyUpgradeSocket { + fn name(&self) -> Cow<str> { + "upgradedHttpConnection".into() + } + + deno_core::impl_readable_byob!(); + + fn write( + self: Rc<Self>, + buf: BufView, + ) -> AsyncResult<WriteOutcome> { + Box::pin(async move { + let nwritten = buf.len(); + Self::write_all(self, &buf).await?; + Ok(WriteOutcome::Full { nwritten }) + }) + } + + fn write_all(self: Rc<Self>, buf: BufView) -> AsyncResult<()> { + Box::pin(async move { Self::write_all(self, &buf).await }) + } + + fn close(self: Rc<Self>) { + self.1.cancel() + } +} + +#[op] +async fn op_http_upgrade_early( + state: Rc<RefCell<OpState>>, + rid: ResourceId, +) -> Result<ResourceId, AnyError> { + let stream = state + .borrow_mut() + .resource_table + .get::<HttpStreamResource>(rid)?; + let resources = &mut state.borrow_mut().resource_table; + let (tx, _rx) = tokio::sync::broadcast::channel(1); + let socket = EarlyUpgradeSocketInner::PreResponse( + stream, + WebSocketUpgrade::default(), + tx, + ); + let rid = resources.add(EarlyUpgradeSocket( + AsyncRefCell::new(socket), + CancelHandle::new(), + )); + Ok(rid) +} + struct UpgradedStream(hyper::upgrade::Upgraded); impl tokio::io::AsyncRead for UpgradedStream { fn poll_read( diff --git a/ext/http/websocket_upgrade.rs b/ext/http/websocket_upgrade.rs new file mode 100644 index 0000000000..042a467219 --- /dev/null +++ b/ext/http/websocket_upgrade.rs @@ -0,0 +1,333 @@ +// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license.
+ +use bytes::Bytes; +use bytes::BytesMut; +use deno_core::error::AnyError; +use httparse::Status; +use hyper::http::HeaderName; +use hyper::http::HeaderValue; +use hyper::Body; +use hyper::Response; +use memmem::Searcher; +use memmem::TwoWaySearcher; +use once_cell::sync::OnceCell; + +use crate::http_error; + +/// Given a buffer that ends in `\n\n` or `\r\n\r\n`, returns a parsed [`Response`]. +fn parse_response( + header_bytes: &[u8], +) -> Result<(usize, Response<Body>), AnyError> { + let mut headers = [httparse::EMPTY_HEADER; 16]; + let status = httparse::parse_headers(header_bytes, &mut headers)?; + match status { + Status::Complete((index, parsed)) => { + let mut resp = Response::builder().status(101).body(Body::empty())?; + for header in parsed.iter() { + resp.headers_mut().append( + HeaderName::from_bytes(header.name.as_bytes())?, + HeaderValue::from_str(std::str::from_utf8(header.value)?)?, + ); + } + Ok((index, resp)) + } + _ => Err(http_error("invalid headers")), + } +} + +/// Find a newline in a slice. +fn find_newline(slice: &[u8]) -> Option<usize> { + for (i, byte) in slice.iter().enumerate() { + if *byte == b'\n' { + return Some(i); + } + } + None +} + +/// WebSocket upgrade state machine states. +#[derive(Default)] +enum WebSocketUpgradeState { + #[default] + Initial, + StatusLine, + Headers, + Complete, +} + +static HEADER_SEARCHER: OnceCell<TwoWaySearcher<'static>> = OnceCell::new(); +static HEADER_SEARCHER2: OnceCell<TwoWaySearcher<'static>> = OnceCell::new(); + +#[derive(Default)] +pub struct WebSocketUpgrade { + state: WebSocketUpgradeState, + buf: BytesMut, +} + +impl WebSocketUpgrade { + /// Ensures that the status line starts with "HTTP/1.1 101 " which matches all of the node.js + /// WebSocket libraries that are known. We don't care about the trailing status text.
+ fn validate_status(&self, status: &[u8]) -> Result<(), AnyError> { + if status.starts_with(b"HTTP/1.1 101 ") { + Ok(()) + } else { + Err(http_error("invalid HTTP status line")) + } + } + + /// Writes bytes to our upgrade buffer, returning [`Ok(None)`] if we need to keep feeding it data, + /// [`Ok(Some(Response))`] if we got a valid upgrade header, or [`Err`] if something went badly. + pub fn write( + &mut self, + bytes: &[u8], + ) -> Result<Option<(Response<Body>, Bytes)>, AnyError> { + use WebSocketUpgradeState::*; + + match self.state { + Initial => { + if let Some(index) = find_newline(bytes) { + let (status, rest) = bytes.split_at(index + 1); + self.validate_status(status)?; + + // Fast path for the most common node.js WebSocket libraries that use \r\n as the + // separator between header lines and send the whole response in one packet. + if rest.ends_with(b"\r\n\r\n") { + let (index, response) = parse_response(rest)?; + if index == rest.len() { + return Ok(Some((response, Bytes::default()))); + } else { + let bytes = Bytes::copy_from_slice(&rest[index..]); + return Ok(Some((response, bytes))); + } + } + + self.state = Headers; + self.write(rest) + } else { + self.state = StatusLine; + self.buf.extend_from_slice(bytes); + Ok(None) + } + } + StatusLine => { + if let Some(index) = find_newline(bytes) { + let (status, rest) = bytes.split_at(index + 1); + self.buf.extend_from_slice(status); + self.validate_status(&self.buf)?; + self.buf.clear(); + // Recursively process this write + self.state = Headers; + self.write(rest) + } else { + self.buf.extend_from_slice(bytes); + Ok(None) + } + } + Headers => { + self.buf.extend_from_slice(bytes); + let header_searcher = + HEADER_SEARCHER.get_or_init(|| TwoWaySearcher::new(b"\r\n\r\n")); + let header_searcher2 = + HEADER_SEARCHER2.get_or_init(|| TwoWaySearcher::new(b"\n\n")); + if let Some(..)
= header_searcher.search_in(&self.buf) { + let (index, response) = parse_response(&self.buf)?; + let mut buf = std::mem::take(&mut self.buf); + self.state = Complete; + Ok(Some((response, buf.split_off(index).freeze()))) + } else if let Some(..) = header_searcher2.search_in(&self.buf) { + let (index, response) = parse_response(&self.buf)?; + let mut buf = std::mem::take(&mut self.buf); + self.state = Complete; + Ok(Some((response, buf.split_off(index).freeze()))) + } else { + Ok(None) + } + } + Complete => { + Err(http_error("attempted to write to completed upgrade buffer")) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + type ExpectedResponseAndHead = Option<(Response<Body>, &'static [u8])>; + + fn assert_response( + result: Result<Option<(Response<Body>, Bytes)>, AnyError>, + expected: Result<ExpectedResponseAndHead, &'static str>, + chunk_info: Option<(usize, usize)>, + ) { + let formatted = format!("{result:?}"); + match expected { + Ok(Some((resp1, remainder1))) => match result { + Ok(Some((resp2, remainder2))) => { + assert_eq!(format!("{resp1:?}"), format!("{resp2:?}")); + if let Some((byte_len, chunk_size)) = chunk_info { + // We need to compute how many bytes should be in the trailing data + + // We know how many bytes of header data we had + let last_packet_header_size = + (byte_len - remainder1.len() + chunk_size - 1) % chunk_size + 1; + + // Which means we can compute how much was in the remainder + let remaining = + (chunk_size - last_packet_header_size).min(remainder1.len()); + + assert_eq!(remainder1[..remaining], remainder2); + } else { + assert_eq!(remainder1, remainder2); + } + } + _ => panic!("Expected Ok(Some(...)), was {formatted}"), + }, + Ok(None) => assert!( + result.ok().unwrap().is_none(), + "Expected Ok(None), was {formatted}", + ), + Err(e) => assert_eq!( + e, + result.err().map(|e| format!("{e:?}")).unwrap_or_default(), + "Expected error, was {formatted}", + ), + } + } + + fn validate_upgrade_all_at_once( + s: &str, + expected: Result<ExpectedResponseAndHead, &'static str>, + ) { + let mut upgrade =
WebSocketUpgrade::default(); + let res = upgrade.write(s.as_bytes()); + + assert_response(res, expected, None); + } + + fn validate_upgrade_chunks( + s: &str, + size: usize, + expected: Result<ExpectedResponseAndHead, &'static str>, + ) { + let chunk_info = Some((s.as_bytes().len(), size)); + let mut upgrade = WebSocketUpgrade::default(); + let mut result = Ok(None); + for chunk in s.as_bytes().chunks(size) { + result = upgrade.write(chunk); + if let Ok(Some(..)) = &result { + assert_response(result, expected, chunk_info); + return; + } + } + assert_response(result, expected, chunk_info); + } + + fn validate_upgrade( + s: &str, + expected: fn() -> Result<ExpectedResponseAndHead, &'static str>, + ) { + validate_upgrade_all_at_once(s, expected()); + validate_upgrade_chunks(s, 1, expected()); + validate_upgrade_chunks(s, 2, expected()); + validate_upgrade_chunks(s, 10, expected()); + + // Replace \n with \r\n, but only in headers + let (headers, trailing) = s.split_once("\n\n").unwrap(); + let s = headers.replace('\n', "\r\n") + "\r\n\r\n" + trailing; + let s = s.as_ref(); + + validate_upgrade_all_at_once(s, expected()); + validate_upgrade_chunks(s, 1, expected()); + validate_upgrade_chunks(s, 2, expected()); + validate_upgrade_chunks(s, 10, expected()); + } + + #[test] + fn upgrade1() { + validate_upgrade( + "HTTP/1.1 101 Switching Protocols\nConnection: Upgrade\n\n", + || { + let mut expected = + Response::builder().status(101).body(Body::empty()).unwrap(); + expected.headers_mut().append( + HeaderName::from_static("connection"), + HeaderValue::from_static("Upgrade"), + ); + Ok(Some((expected, b""))) + }, + ); + } + + #[test] + fn upgrade_trailing() { + validate_upgrade( + "HTTP/1.1 101 Switching Protocols\nConnection: Upgrade\n\ntrailing data", + || { + let mut expected = + Response::builder().status(101).body(Body::empty()).unwrap(); + expected.headers_mut().append( + HeaderName::from_static("connection"), + HeaderValue::from_static("Upgrade"), + ); + Ok(Some((expected, b"trailing data"))) + }, + ); + } + + #[test] + fn
upgrade_trailing_with_newlines() { + validate_upgrade( + "HTTP/1.1 101 Switching Protocols\nConnection: Upgrade\n\ntrailing data\r\n\r\n", + || { + let mut expected = + Response::builder().status(101).body(Body::empty()).unwrap(); + expected.headers_mut().append( + HeaderName::from_static("connection"), + HeaderValue::from_static("Upgrade"), + ); + Ok(Some((expected, b"trailing data\r\n\r\n"))) + }, + ); + } + + #[test] + fn upgrade2() { + validate_upgrade( + "HTTP/1.1 101 Switching Protocols\nConnection: Upgrade\nOther: 123\n\n", + || { + let mut expected = + Response::builder().status(101).body(Body::empty()).unwrap(); + expected.headers_mut().append( + HeaderName::from_static("connection"), + HeaderValue::from_static("Upgrade"), + ); + expected.headers_mut().append( + HeaderName::from_static("other"), + HeaderValue::from_static("123"), + ); + Ok(Some((expected, b""))) + }, + ); + } + + #[test] + fn upgrade_invalid_status() { + validate_upgrade("HTTP/1.1 200 OK\nConnection: Upgrade\n\n", || { + Err("invalid HTTP status line") + }); + } + + #[test] + fn upgrade_too_many_headers() { + let headers = (0..20) + .map(|i| format!("h{i}: {i}")) + .collect::<Vec<_>>() + .join("\n"); + validate_upgrade( + &format!("HTTP/1.1 101 Switching Protocols\n{headers}\n\n"), + || Err("too many headers"), + ); + } +}