diff --git a/cli/dts/lib.deno.unstable.d.ts b/cli/dts/lib.deno.unstable.d.ts index 6b7755ee51..442f5d7d47 100644 --- a/cli/dts/lib.deno.unstable.d.ts +++ b/cli/dts/lib.deno.unstable.d.ts @@ -1172,6 +1172,7 @@ declare interface WorkerOptions { declare interface WebSocketStreamOptions { protocols?: string[]; signal?: AbortSignal; + headers?: HeadersInit; } declare interface WebSocketConnection { diff --git a/cli/tests/testdata/websocketstream_test.ts b/cli/tests/testdata/websocketstream_test.ts index 1198c41641..b43b901396 100644 --- a/cli/tests/testdata/websocketstream_test.ts +++ b/cli/tests/testdata/websocketstream_test.ts @@ -3,6 +3,7 @@ import { assert, assertEquals, + assertNotEquals, assertRejects, assertThrows, unreachable, @@ -137,3 +138,59 @@ Deno.test("aborting immediately with a primitive as reason throws that primitive (e) => assertEquals(e, "Some string"), ); }); + +Deno.test("headers", async () => { + const listener = Deno.listen({ port: 4501 }); + const promise = (async () => { + const httpConn = Deno.serveHttp(await listener.accept()); + const { request, respondWith } = (await httpConn.nextRequest())!; + assertEquals(request.headers.get("x-some-header"), "foo"); + const { + response, + socket, + } = Deno.upgradeWebSocket(request); + socket.onopen = () => socket.close(); + await respondWith(response); + })(); + + const ws = new WebSocketStream("ws://localhost:4501", { + headers: [["x-some-header", "foo"]], + }); + await promise; + await ws.closed; + listener.close(); +}); + +Deno.test("forbidden headers", async () => { + const forbiddenHeaders = [ + "sec-websocket-accept", + "sec-websocket-extensions", + "sec-websocket-key", + "sec-websocket-protocol", + "sec-websocket-version", + "upgrade", + "connection", + ]; + + const listener = Deno.listen({ port: 4501 }); + const promise = (async () => { + const httpConn = Deno.serveHttp(await listener.accept()); + const { request, respondWith } = (await httpConn.nextRequest())!; + for (const header of request.headers.keys()) { + assertNotEquals(header, "foo"); + } + const { + response, + socket, + } = Deno.upgradeWebSocket(request); + socket.onopen = () => socket.close(); + await respondWith(response); + })(); + + const ws = new WebSocketStream("ws://localhost:4501", { + headers: forbiddenHeaders.map((header) => [header, "foo"]), + }); + await promise; + await ws.closed; + listener.close(); +}); diff --git a/ext/websocket/02_websocketstream.js b/ext/websocket/02_websocketstream.js index 8b032d1c20..d0a4e055d7 100644 --- a/ext/websocket/02_websocketstream.js +++ b/ext/websocket/02_websocketstream.js @@ -39,6 +39,10 @@ key: "signal", converter: webidl.converters.AbortSignal, }, + { + key: "headers", + converter: webidl.converters.HeadersInit, + }, ], ); webidl.converters.WebSocketCloseInfo = webidl.createDictionaryConverter( @@ -139,6 +143,7 @@ ? ArrayPrototypeJoin(options.protocols, ", ") : "", cancelHandle: cancelRid, + headers: [...new Headers(options.headers).entries()], }), (create) => { options.signal?.[remove](abort); diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs index 4796eddc65..544423066f 100644 --- a/ext/websocket/lib.rs +++ b/ext/websocket/lib.rs @@ -1,6 +1,7 @@ // Copyright 2018-2021 the Deno authors. All rights reserved. MIT license. use deno_core::error::invalid_hostname; +use deno_core::error::type_error; use deno_core::error::AnyError; use deno_core::futures::stream::SplitSink; use deno_core::futures::stream::SplitStream; @@ -11,6 +12,7 @@ use deno_core::op_async; use deno_core::op_sync; use deno_core::url; use deno_core::AsyncRefCell; +use deno_core::ByteString; use deno_core::CancelFuture; use deno_core::CancelHandle; use deno_core::Extension; @@ -20,6 +22,8 @@ use deno_core::Resource; use deno_core::ResourceId; use deno_core::ZeroCopyBuf; use deno_tls::create_client_config; +use http::header::HeaderName; +use http::HeaderValue; use http::Method; use http::Request; use http::Uri; @@ -215,6 +219,7 @@ pub struct CreateArgs { url: String, protocols: String, cancel_handle: Option, + headers: Option>, } #[derive(Serialize)] @@ -267,6 +272,30 @@ where request = request.header("Sec-WebSocket-Protocol", args.protocols); } + if let Some(headers) = args.headers { + for (key, value) in headers { + let name = HeaderName::from_bytes(&key) + .map_err(|err| type_error(err.to_string()))?; + let v = HeaderValue::from_bytes(&value) + .map_err(|err| type_error(err.to_string()))?; + + let is_disallowed_header = matches!( + name, + http::header::HOST + | http::header::SEC_WEBSOCKET_ACCEPT + | http::header::SEC_WEBSOCKET_EXTENSIONS + | http::header::SEC_WEBSOCKET_KEY + | http::header::SEC_WEBSOCKET_PROTOCOL + | http::header::SEC_WEBSOCKET_VERSION + | http::header::UPGRADE + | http::header::CONNECTION + ); + if !is_disallowed_header { + request = request.header(name, v); + } + } + } + let request = request.body(())?; let domain = &uri.host().unwrap().to_string(); let port = &uri.port_u16().unwrap_or(match uri.scheme_str() {