diff --git a/Cargo.lock b/Cargo.lock
index 50bc478c0d..f4f8cff802 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -638,6 +638,7 @@ name = "cli_tests"
 version = "0.0.0"
 dependencies = [
  "bytes",
+ "chrono",
  "deno_ast",
  "deno_bench_util",
  "deno_cache_dir",
@@ -665,6 +666,8 @@ dependencies = [
  "trust-dns-client",
  "trust-dns-server",
  "url",
+ "uuid",
+ "zeromq",
 ]
 
 [[package]]
diff --git a/Cargo.toml b/Cargo.toml
index f522aa5751..f047740c02 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -176,6 +176,7 @@ tower-lsp = { version = "=0.20.0", features = ["proposed"] }
 url = { version = "2.3.1", features = ["serde", "expose_internals"] }
 uuid = { version = "1.3.0", features = ["v4"] }
 webpki-roots = "0.25.2"
+zeromq = { version = "=0.3.4", default-features = false, features = ["tcp-transport", "tokio-runtime"] }
 zstd = "=0.12.4"
 
 # crypto
diff --git a/cli/Cargo.toml b/cli/Cargo.toml
index 939dfee74f..c252cd9b1a 100644
--- a/cli/Cargo.toml
+++ b/cli/Cargo.toml
@@ -149,7 +149,7 @@ typed-arena = "=2.0.1"
 unicode-width = "0.1"
 uuid = { workspace = true, features = ["serde"] }
 walkdir = "=2.3.2"
-zeromq = { version = "=0.3.4", default-features = false, features = ["tcp-transport", "tokio-runtime"] }
+zeromq.workspace = true
 zstd.workspace = true
 
 [target.'cfg(windows)'.dependencies]
diff --git a/cli/tools/jupyter/server.rs b/cli/tools/jupyter/server.rs
index bfdcdf1e17..f6218956da 100644
--- a/cli/tools/jupyter/server.rs
+++ b/cli/tools/jupyter/server.rs
@@ -396,6 +396,7 @@ impl JupyterServer {
       .with_content(json!({
         "status": "ok",
         "execution_count": self.execution_count,
+        // FIXME: also include user_expressions
       }))
       .send(connection)
       .await?;
diff --git a/tests/Cargo.toml b/tests/Cargo.toml
index 872920bd9a..1c2d8de4b1 100644
--- a/tests/Cargo.toml
+++ b/tests/Cargo.toml
@@ -33,6 +33,7 @@ required-features = ["run"]
 
 [dev-dependencies]
 bytes.workspace = true
+chrono.workspace = true
 deno_ast.workspace = true
 deno_bench_util.workspace = true
 deno_cache_dir = { workspace = true }
@@ -59,6 +60,8 @@ tower-lsp.workspace = true
 trust-dns-client = "=0.22.0"
 trust-dns-server = "=0.22.1"
 url.workspace = true
+uuid = { workspace = true, features = ["serde"] }
+zeromq.workspace = true
 
 [target.'cfg(unix)'.dev-dependencies]
 nix.workspace = true
diff --git a/tests/integration/jupyter_tests.rs b/tests/integration/jupyter_tests.rs
new file mode 100644
index 0000000000..29636f305f
--- /dev/null
+++ b/tests/integration/jupyter_tests.rs
@@ -0,0 +1,535 @@
+// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
+
+use std::process::Output;
+use std::sync::Arc;
+use std::time::Duration;
+
+use bytes::Bytes;
+use test_util::DenoChild;
+use test_util::TestContext;
+use test_util::TestContextBuilder;
+
+use deno_core::anyhow::Result;
+use deno_core::serde_json;
+use deno_core::serde_json::json;
+use deno_core::serde_json::Value;
+use serde::Deserialize;
+use serde::Serialize;
+use tokio::sync::Mutex;
+use tokio::time::timeout;
+use uuid::Uuid;
+use zeromq::SocketRecv;
+use zeromq::SocketSend;
+use zeromq::ZmqMessage;
+
+// for the `utc_now` function
+include!("../../cli/util/time.rs");
+
+/// Jupyter connection file format
+#[derive(Serialize)]
+struct ConnectionSpec {
+  // key used for HMAC signature, if empty, hmac is not used
+  key: String,
+  signature_scheme: String,
+  transport: String,
+  ip: String,
+  hb_port: u16,
+  control_port: u16,
+  shell_port: u16,
+  stdin_port: u16,
+  iopub_port: u16,
+  kernel_name: String,
+}
+
+impl ConnectionSpec {
+  fn endpoint(&self, port: u16) -> String {
+    format!("{}://{}:{}", self.transport, self.ip, port)
+  }
+}
+
+fn pick_unused_port() -> u16 {
+  let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
+  listener.local_addr().unwrap().port()
+}
+
+impl Default for ConnectionSpec {
+  fn default() -> Self {
+    Self {
+      key: "".into(),
+      signature_scheme: "hmac-sha256".into(),
+      transport: "tcp".into(),
+      ip: "127.0.0.1".into(),
+      hb_port: pick_unused_port(),
+      control_port: pick_unused_port(),
+      shell_port: pick_unused_port(),
+      stdin_port: pick_unused_port(),
+      iopub_port: pick_unused_port(),
+      kernel_name: "deno".into(),
+    }
+  }
+}
+
+const DELIMITER: &[u8] = b"<IDS|MSG>";
+
+#[derive(Debug, Clone)]
+struct JupyterMsg {
+  routing_prefix: Vec<String>,
+  signature: String,
+  header: MsgHeader,
+  parent_header: Value,
+  metadata: Value,
+  content: Value,
+  buffers: Vec<Bytes>,
+}
+
+impl Default for JupyterMsg {
+  fn default() -> Self {
+    Self {
+      routing_prefix: vec![Uuid::new_v4().to_string()],
+      signature: "".into(),
+      header: MsgHeader::default(),
+      parent_header: json!({}),
+      metadata: json!({}),
+      content: json!({}),
+      buffers: Vec::new(),
+    }
+  }
+}
+
+#[derive(Serialize, Clone, Debug, Deserialize)]
+struct MsgHeader {
+  msg_id: Uuid,
+  session: Uuid,
+  date: String,
+  username: String,
+  msg_type: String,
+  version: String,
+}
+
+impl MsgHeader {
+  fn to_json(&self) -> Value {
+    serde_json::to_value(self).unwrap()
+  }
+}
+
+impl Default for MsgHeader {
+  fn default() -> Self {
+    Self {
+      msg_id: Uuid::new_v4(),
+      session: Uuid::new_v4(),
+      date: utc_now().to_rfc3339(),
+      username: "test".into(),
+      msg_type: "kernel_info_request".into(),
+      version: "5.3".into(),
+    }
+  }
+}
+
+impl JupyterMsg {
+  fn to_raw(&self) -> ZmqMessage {
+    let mut parts = Vec::new();
+    parts.extend(
+      self
+        .routing_prefix
+        .iter()
+        .map(|uuid| uuid.as_bytes().to_vec().into()),
+    );
+    parts.push(Bytes::from_static(DELIMITER));
+    parts.push(self.signature.clone().into());
+    parts.push(serde_json::to_vec(&self.header).unwrap().into());
+    parts.push(self.parent_header.to_string().into());
+    parts.push(self.metadata.to_string().into());
+    parts.push(self.content.to_string().into());
+    parts.extend(self.buffers.clone());
+    ZmqMessage::try_from(parts).unwrap()
+  }
+
+  fn new(session: Uuid, msg_type: impl AsRef<str>, content: Value) -> Self {
+    Self {
+      header: MsgHeader {
+        session,
+        msg_type: msg_type.as_ref().into(),
+        ..Default::default()
+      },
+      content,
+      ..Default::default()
+    }
+  }
+
+  fn from_raw(msg: ZmqMessage) -> Self {
+    let parts = msg.into_vec();
+    let delimiter =
+      parts.iter().position(|part| part == DELIMITER).unwrap();
+    let routing_prefix = parts[..delimiter]
+      .iter()
+      .map(|part: &Bytes| String::from_utf8_lossy(part.as_ref()).to_string())
+      .collect();
+    let signature = String::from_utf8(parts[delimiter + 1].to_vec())
+      .expect("Failed to parse signature");
+    let header: MsgHeader = serde_json::from_slice(&parts[delimiter + 2])
+      .expect("Failed to parse header");
+    let parent_header: Value =
+      serde_json::from_slice(&parts[delimiter + 3]).unwrap();
+    let metadata: Value =
+      serde_json::from_slice(&parts[delimiter + 4]).unwrap();
+    let content: Value = serde_json::from_slice(&parts[delimiter + 5]).unwrap();
+    let buffers = parts[delimiter + 6..].to_vec();
+    Self {
+      routing_prefix,
+      signature,
+      header,
+      parent_header,
+      metadata,
+      content,
+      buffers,
+    }
+  }
+}
+
+async fn connect_socket<S: zeromq::Socket>(
+  spec: &ConnectionSpec,
+  port: u16,
+) -> S {
+  let addr = spec.endpoint(port);
+  let mut socket = S::new();
+  socket.connect(&addr).await.unwrap();
+  socket
+}
+
+#[derive(Clone)]
+struct JupyterClient {
+  recv_timeout: Duration,
+  session: Uuid,
+  heartbeat: Arc<Mutex<zeromq::ReqSocket>>,
+  control: Arc<Mutex<zeromq::DealerSocket>>,
+  shell: Arc<Mutex<zeromq::DealerSocket>>,
+  io_pub: Arc<Mutex<zeromq::SubSocket>>,
+  stdin: Arc<Mutex<zeromq::DealerSocket>>,
+}
+
+#[derive(Debug, Clone, Copy)]
+enum JupyterChannel {
+  Control,
+  Shell,
+  #[allow(dead_code)]
+  Stdin,
+  IoPub,
+}
+
+use JupyterChannel::*;
+
+impl JupyterClient {
+  async fn new(spec: &ConnectionSpec) -> Self {
+    Self::new_with_timeout(spec, Duration::from_secs(5)).await
+  }
+
+  async fn new_with_timeout(spec: &ConnectionSpec, timeout: Duration) -> Self {
+    let (heartbeat, control, shell, io_pub, stdin) = tokio::join!(
+      connect_socket::<zeromq::ReqSocket>(spec, spec.hb_port),
+      connect_socket::<zeromq::DealerSocket>(spec, spec.control_port),
+      connect_socket::<zeromq::DealerSocket>(spec, spec.shell_port),
+      connect_socket::<zeromq::SubSocket>(spec, spec.iopub_port),
+      connect_socket::<zeromq::DealerSocket>(spec, spec.stdin_port),
+    );
+
+    Self {
+      session: Uuid::new_v4(),
+      heartbeat: Arc::new(Mutex::new(heartbeat)),
+      control: Arc::new(Mutex::new(control)),
+      shell: Arc::new(Mutex::new(shell)),
+      io_pub: Arc::new(Mutex::new(io_pub)),
+      stdin: Arc::new(Mutex::new(stdin)),
+      recv_timeout: timeout,
+    }
+  }
+
+  async fn io_subscribe(&self, topic: &str) -> Result<()> {
+    Ok(self.io_pub.lock().await.subscribe(topic).await?)
+  }
+
+  async fn recv_with_timeout<S: SocketRecv>(
+    &self,
+    s: &mut S,
+  ) -> Result<JupyterMsg> {
+    let msg = timeout(self.recv_timeout, s.recv()).await??;
+    Ok(JupyterMsg::from_raw(msg))
+  }
+
+  async fn send_msg(
+    &self,
+    channel: JupyterChannel,
+    msg: JupyterMsg,
+  ) -> Result<JupyterMsg> {
+    let raw = msg.to_raw();
+    match channel {
+      Control => self.control.lock().await.send(raw).await?,
+      Shell => self.shell.lock().await.send(raw).await?,
+      Stdin => self.stdin.lock().await.send(raw).await?,
+      IoPub => panic!("Cannot send over IOPub"),
+    }
+    Ok(msg)
+  }
+
+  async fn send(
+    &self,
+    channel: JupyterChannel,
+    msg_type: &str,
+    content: Value,
+  ) -> Result<JupyterMsg> {
+    let msg = JupyterMsg::new(self.session, msg_type, content);
+    self.send_msg(channel, msg).await
+  }
+
+  async fn recv(&self, channel: JupyterChannel) -> Result<JupyterMsg> {
+    Ok(match channel {
+      Control => {
+        self
+          .recv_with_timeout(&mut *self.control.lock().await)
+          .await?
+      }
+      Shell => {
+        self
+          .recv_with_timeout(&mut *self.shell.lock().await)
+          .await?
+      }
+      Stdin => {
+        self
+          .recv_with_timeout(&mut *self.stdin.lock().await)
+          .await?
+      }
+      IoPub => {
+        self
+          .recv_with_timeout(&mut *self.io_pub.lock().await)
+          .await?
+      }
+    })
+  }
+
+  async fn send_heartbeat(&self, bytes: impl AsRef<[u8]>) -> Result<()> {
+    Ok(
+      self
+        .heartbeat
+        .lock()
+        .await
+        .send(ZmqMessage::from(bytes.as_ref().to_vec()))
+        .await?,
+    )
+  }
+
+  async fn recv_heartbeat(&self) -> Result<Bytes> {
+    Ok(
+      timeout(self.recv_timeout, self.heartbeat.lock().await.recv())
+        .await??
+        .into_vec()[0]
+        .clone(),
+    )
+  }
+}
+
+async fn wait_or_kill(
+  mut process: DenoChild,
+  wait: Duration,
+) -> Result<Output> {
+  let start = std::time::Instant::now();
+  while start.elapsed() < wait {
+    if process.try_wait()?.is_some() {
+      return Ok(process.wait_with_output()?);
+    }
+    tokio::time::sleep(Duration::from_millis(100)).await;
+  }
+  process.kill()?;
+  Ok(process.wait_with_output()?)
+}
+
+// Wrapper around the Jupyter server process that
+// ensures the process is killed when dropped.
+struct JupyterServerProcess(Option<DenoChild>);
+
+impl JupyterServerProcess {
+  // Wait for the process to exit, or kill it after the given duration.
+  //
+  // Ideally we could use this at the end of each test, but the server
+  // doesn't seem to exit in a reasonable amount of time after getting
+  // a shutdown request.
+  #[allow(dead_code)]
+  async fn wait_or_kill(mut self, wait: Duration) -> Output {
+    wait_or_kill(self.0.take().unwrap(), wait).await.unwrap()
+  }
+}
+
+impl Drop for JupyterServerProcess {
+  fn drop(&mut self) {
+    let Some(mut proc) = self.0.take() else {
+      return;
+    };
+    if proc.try_wait().unwrap().is_some() {
+      // already exited
+      return;
+    }
+    proc.kill().unwrap();
+  }
+}
+
+fn setup_server() -> (TestContext, ConnectionSpec, JupyterServerProcess) {
+  let context = TestContextBuilder::new().use_temp_cwd().build();
+  let conn = ConnectionSpec::default();
+  let conn_file = context.temp_dir().path().join("connection.json");
+  conn_file.write_json(&conn);
+  let process = context
+    .new_command()
+    .piped_output()
+    .args_vec(vec![
+      "jupyter",
+      "--kernel",
+      "--conn",
+      conn_file.to_string().as_str(),
+    ])
+    .spawn()
+    .unwrap();
+  (context, conn, JupyterServerProcess(Some(process)))
+}
+
+async fn setup() -> (TestContext, JupyterClient, JupyterServerProcess) {
+  let (context, conn, process) = setup_server();
+  let client = JupyterClient::new(&conn).await;
+  client.io_subscribe("").await.unwrap();
+
+  (context, client, process)
+}
+
+/// Asserts that the actual value is equal to the expected value, but
+/// only for the keys present in the expected value.
+/// In other words, `assert_eq_subset(json!({"a": 1, "b": 2}), json!({"a": 1}))` would pass.
+#[track_caller]
+fn assert_eq_subset(actual: Value, expected: Value) {
+  match (actual, expected) {
+    (Value::Object(actual), Value::Object(expected)) => {
+      for (k, v) in expected.iter() {
+        let Some(actual_v) = actual.get(k) else {
+          panic!("Key {k:?} not found in actual value ({actual:#?})");
+        };
+        assert_eq_subset(actual_v.clone(), v.clone());
+      }
+    }
+    (Value::Array(actual), Value::Array(expected)) => {
+      for (i, v) in expected.iter().enumerate() {
+        assert_eq_subset(actual[i].clone(), v.clone());
+      }
+    }
+    (actual, expected) => {
+      assert_eq!(actual, expected);
+    }
+  }
+}
+
+#[tokio::test]
+async fn jupyter_heartbeat_echoes() -> Result<()> {
+  let (_ctx, client, _process) = setup().await;
+  client.send_heartbeat(b"ping").await?;
+  let msg = client.recv_heartbeat().await?;
+  assert_eq!(msg, Bytes::from_static(b"ping"));
+
+  Ok(())
+}
+
+#[tokio::test]
+async fn jupyter_kernel_info() -> Result<()> {
+  let (_ctx, client, _process) = setup().await;
+  client
+    .send(Control, "kernel_info_request", json!({}))
+    .await?;
+  let msg = client.recv(Control).await?;
+  assert_eq!(msg.header.msg_type, "kernel_info_reply");
+  assert_eq_subset(
+    msg.content,
+    json!({
+      "status": "ok",
+      "implementation": "Deno kernel",
+      "language_info": {
+        "name": "typescript",
+        "mimetype": "text/x.typescript",
+        "file_extension": ".ts",
+        "pygments_lexer": "typescript",
+        "nb_converter": "script"
+      },
+    }),
+  );
+
+  Ok(())
+}
+
+#[tokio::test]
+async fn jupyter_execute_request() -> Result<()> {
+  let (_ctx, client, _process) = setup().await;
+  let request = client
+    .send(
+      Shell,
+      "execute_request",
+      json!({
+        "silent": false,
+        "store_history": true,
+        "user_expressions": {},
+        "allow_stdin": true,
+        "stop_on_error": false,
+        "code": "console.log(\"asdf\")"
+      }),
+    )
+    .await?;
+  let reply = client.recv(Shell).await?;
+  assert_eq!(reply.header.msg_type, "execute_reply");
+  assert_eq_subset(
+    reply.content,
+    json!({
+      "status": "ok",
+      "execution_count": 1,
+    }),
+  );
+
+  let mut msgs = Vec::new();
+
+  for _ in 0..4 {
+    match client.recv(IoPub).await {
+      Ok(msg) => msgs.push(msg),
+      Err(e) => {
+        if e.downcast_ref::<tokio::time::error::Elapsed>().is_some() {
+          // may timeout if we missed some messages
+          break;
+        }
+        panic!("Error: {:#?}", e);
+      }
+    }
+  }
+
+  let execution_idle = msgs
+    .iter()
+    .find(|msg| {
+      if let Some(state) = msg.content.get("execution_state") {
+        state == "idle"
+      } else {
+        false
+      }
+    })
+    .expect("execution_state idle not found");
+  assert_eq!(execution_idle.parent_header, request.header.to_json());
+  assert_eq_subset(
+    execution_idle.content.clone(),
+    json!({
+      "execution_state": "idle",
+    }),
+  );
+
+  let execution_result = msgs
+    .iter()
+    .find(|msg| msg.header.msg_type == "stream")
+    .expect("stream not found");
+  assert_eq!(execution_result.header.msg_type, "stream");
+  assert_eq!(execution_result.parent_header, request.header.to_json());
+  assert_eq_subset(
+    execution_result.content.clone(),
+    json!({
+      "name": "stdout",
+      "text": "asdf\n", // the trailing newline is added by console.log
+    }),
+  );
+
+  Ok(())
+}
diff --git a/tests/integration/mod.rs b/tests/integration/mod.rs
index 40dd880b54..9253cae32e 100644
--- a/tests/integration/mod.rs
+++ b/tests/integration/mod.rs
@@ -38,6 +38,8 @@ mod install;
 mod js_unit_tests;
 #[path = "jsr_tests.rs"]
 mod jsr;
+#[path = "jupyter_tests.rs"]
+mod jupyter;
 #[path = "lint_tests.rs"]
 mod lint;
 #[path = "lsp_tests.rs"]