diff --git a/ext/websocket/01_websocket.js b/ext/websocket/01_websocket.js index e0db51f4be..3a3f4bc902 100644 --- a/ext/websocket/01_websocket.js +++ b/ext/websocket/01_websocket.js @@ -429,6 +429,7 @@ class WebSocket extends EventTarget { const rid = this[_rid]; while (this[_readyState] !== CLOSED) { const kind = await op_ws_next_event(rid); + /* close the connection if read was cancelled, and we didn't get a close frame */ if ( (this[_readyState] == CLOSING) && @@ -442,6 +443,10 @@ class WebSocket extends EventTarget { break; } + if (kind == null) { + break; + } + switch (kind) { case 0: { /* string */ diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs index deb424c9be..25af67a7c2 100644 --- a/ext/websocket/lib.rs +++ b/ext/websocket/lib.rs @@ -16,6 +16,7 @@ use deno_core::url; use deno_core::AsyncMutFuture; use deno_core::AsyncRefCell; use deno_core::ByteString; +use deno_core::CancelFuture; use deno_core::CancelHandle; use deno_core::CancelTryFuture; use deno_core::JsBuffer; @@ -552,6 +553,7 @@ pub struct ServerWebSocket { string: Cell>, ws_read: AsyncRefCell>>, ws_write: AsyncRefCell>>, + cancel_handle: Rc, } impl ServerWebSocket { @@ -566,6 +568,7 @@ impl ServerWebSocket { string: Cell::new(None), ws_read: AsyncRefCell::new(FragmentCollectorRead::new(ws_read)), ws_write: AsyncRefCell::new(ws_write), + cancel_handle: CancelHandle::new_rc(), } } @@ -769,7 +772,7 @@ pub async fn op_ws_close( let Ok(resource) = state .borrow_mut() .resource_table - .get::(rid) + .take::(rid) else { return Ok(()); }; @@ -784,6 +787,8 @@ pub async fn op_ws_close( }); resource.closed.set(true); + + resource.cancel_handle.cancel(); let lock = resource.reserve_lock(); resource.write_frame(lock, frame).await } @@ -826,19 +831,19 @@ pub fn op_ws_get_error(state: &mut OpState, #[smi] rid: ResourceId) -> String { pub async fn op_ws_next_event( state: Rc>, #[smi] rid: ResourceId, -) -> u16 { +) -> Option { let Ok(resource) = state .borrow_mut() .resource_table .get::(rid) else { // op_ws_get_error will correctly handle a bad resource - return MessageKind::Error as u16; + return Some(MessageKind::Error as u16); }; // If there's a pending error, this always returns error if resource.errored.get() { - return MessageKind::Error as u16; + return Some(MessageKind::Error as u16); } let mut ws = RcRef::map(&resource, |r| &r.ws_read).borrow_mut().await; @@ -847,19 +852,26 @@ pub async fn op_ws_next_event( let writer = writer.clone(); async move { writer.borrow_mut().await.write_frame(frame).await } }; + let cancel_handle = resource.cancel_handle.clone(); loop { - let res = ws.read_frame(&mut sender).await; + let Ok(res) = ws + .read_frame(&mut sender) + .or_cancel(cancel_handle.clone()) + .await + else { + return None; + }; let val = match res { Ok(val) => val, Err(err) => { // No message was received, socket closed while we waited. // Report closed status to JavaScript. if resource.closed.get() { - return MessageKind::ClosedDefault as u16; + return Some(MessageKind::ClosedDefault as u16); } resource.set_error(Some(err.to_string())); - return MessageKind::Error as u16; + return Some(MessageKind::Error as u16); } }; @@ -867,22 +879,22 @@ pub async fn op_ws_next_event( OpCode::Text => match String::from_utf8(val.payload.to_vec()) { Ok(s) => { resource.string.set(Some(s)); - MessageKind::Text as u16 + Some(MessageKind::Text as u16) } Err(_) => { resource.set_error(Some("Invalid string data".into())); - MessageKind::Error as u16 + Some(MessageKind::Error as u16) } }, OpCode::Binary => { resource.buffer.set(Some(val.payload.to_vec())); - MessageKind::Binary as u16 + Some(MessageKind::Binary as u16) } OpCode::Close => { // Close reason is returned through error if val.payload.len() < 2 { resource.set_error(None); - MessageKind::ClosedDefault as u16 + Some(MessageKind::ClosedDefault as u16) } else { let close_code = CloseCode::from(u16::from_be_bytes([ val.payload[0], @@ -890,10 +902,10 @@ pub async fn op_ws_next_event( ])); let reason = String::from_utf8(val.payload[2..].to_vec()).ok(); resource.set_error(reason); - close_code.into() + Some(close_code.into()) } } - OpCode::Pong => MessageKind::Pong as u16, + OpCode::Pong => Some(MessageKind::Pong as u16), OpCode::Continuation | OpCode::Ping => { continue; } diff --git a/tests/unit/websocket_test.ts b/tests/unit/websocket_test.ts index e5e4b1a7a7..ce13548359 100644 --- a/tests/unit/websocket_test.ts +++ b/tests/unit/websocket_test.ts @@ -822,3 +822,12 @@ Deno.test("send to a closed socket", async () => { }; await promise; }); + +Deno.test(async function websocketDoesntLeak() { + const { promise, resolve } = Promise.withResolvers(); + const ws = new WebSocket(new URL("ws://localhost:4242/")); + assertEquals(ws.url, "ws://localhost:4242/"); + ws.onopen = () => resolve(); + await promise; + ws.close(); +});