1
0
Fork 0
mirror of https://github.com/denoland/deno.git synced 2025-01-21 21:50:00 -05:00

chore(tests): windows pty tests (#12091)

This commit is contained in:
David Sherret 2021-09-20 22:15:44 -04:00 committed by GitHub
parent 60b68e63f1
commit 0f23d92601
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 563 additions and 174 deletions

42
Cargo.lock generated
View file

@ -597,7 +597,6 @@ dependencies = [
"dprint-plugin-typescript", "dprint-plugin-typescript",
"encoding_rs", "encoding_rs",
"env_logger", "env_logger",
"exec",
"fancy-regex", "fancy-regex",
"flaky_test", "flaky_test",
"fwdansi", "fwdansi",
@ -1183,27 +1182,6 @@ dependencies = [
"winapi 0.2.8", "winapi 0.2.8",
] ]
[[package]]
name = "errno"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa68f2fb9cae9d37c9b2b3584aba698a2e97f72d7aef7b9f7aa71d8b54ce46fe"
dependencies = [
"errno-dragonfly",
"libc",
"winapi 0.3.9",
]
[[package]]
name = "errno-dragonfly"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14ca354e36190500e1e1fb267c647932382b54053c50b14970856c0b00a35067"
dependencies = [
"gcc",
"libc",
]
[[package]] [[package]]
name = "error-code" name = "error-code"
version = "2.3.0" version = "2.3.0"
@ -1214,16 +1192,6 @@ dependencies = [
"str-buf", "str-buf",
] ]
[[package]]
name = "exec"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "886b70328cba8871bfc025858e1de4be16b1d5088f2ba50b57816f4210672615"
dependencies = [
"errno 0.2.7",
"libc",
]
[[package]] [[package]]
name = "fallible-iterator" name = "fallible-iterator"
version = "0.2.0" version = "0.2.0"
@ -1490,12 +1458,6 @@ dependencies = [
"byteorder", "byteorder",
] ]
[[package]]
name = "gcc"
version = "0.3.55"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f5f3913fa0bfe7ee1fd8248b6b9f42a5af4b9d65ec2dd2c3c26132b950ecfc2"
[[package]] [[package]]
name = "generic-array" name = "generic-array"
version = "0.14.4" version = "0.14.4"
@ -2636,7 +2598,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f50f3d255966981eb4e4c5df3e983e6f7d163221f547406d83b6a460ff5c5ee8" checksum = "f50f3d255966981eb4e4c5df3e983e6f7d163221f547406d83b6a460ff5c5ee8"
dependencies = [ dependencies = [
"errno 0.1.8", "errno",
"libc", "libc",
] ]
@ -3881,6 +3843,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-stream", "async-stream",
"atty",
"base64 0.13.0", "base64 0.13.0",
"futures", "futures",
"hyper", "hyper",
@ -3894,6 +3857,7 @@ dependencies = [
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tokio-tungstenite", "tokio-tungstenite",
"winapi 0.3.9",
] ]
[[package]] [[package]]

View file

@ -100,7 +100,6 @@ trust-dns-client = "0.20.3"
trust-dns-server = "0.20.3" trust-dns-server = "0.20.3"
[target.'cfg(unix)'.dev-dependencies] [target.'cfg(unix)'.dev-dependencies]
exec = "0.3.1" # Used in test_raw_tty
nix = "0.22.1" nix = "0.22.1"
[package.metadata.winres] [package.metadata.winres]

View file

@ -2,27 +2,23 @@
use test_util as util; use test_util as util;
#[cfg(unix)]
#[test] #[test]
fn pty_multiline() { fn pty_multiline() {
use std::io::{Read, Write}; util::with_pty(&["repl"], |mut console| {
run_pty_test(|master| { console.write_line("(\n1 + 2\n)");
master.write_all(b"(\n1 + 2\n)\n").unwrap(); console.write_line("{\nfoo: \"foo\"\n}");
master.write_all(b"{\nfoo: \"foo\"\n}\n").unwrap(); console.write_line("`\nfoo\n`");
master.write_all(b"`\nfoo\n`\n").unwrap(); console.write_line("`\n\\`\n`");
master.write_all(b"`\n\\`\n`\n").unwrap(); console.write_line("'{'");
master.write_all(b"'{'\n").unwrap(); console.write_line("'('");
master.write_all(b"'('\n").unwrap(); console.write_line("'['");
master.write_all(b"'['\n").unwrap(); console.write_line("/{/");
master.write_all(b"/{/\n").unwrap(); console.write_line("/\\(/");
master.write_all(b"/\\(/\n").unwrap(); console.write_line("/\\[/");
master.write_all(b"/\\[/\n").unwrap(); console.write_line("console.log(\"{test1} abc {test2} def {{test3}}\".match(/{([^{].+?)}/));");
master.write_all(b"console.log(\"{test1} abc {test2} def {{test3}}\".match(/{([^{].+?)}/));\n").unwrap(); console.write_line("close();");
master.write_all(b"close();\n").unwrap();
let mut output = String::new();
master.read_to_string(&mut output).unwrap();
let output = console.read_all_output();
assert!(output.contains('3')); assert!(output.contains('3'));
assert!(output.contains("{ foo: \"foo\" }")); assert!(output.contains("{ foo: \"foo\" }"));
assert!(output.contains("\"\\nfoo\\n\"")); assert!(output.contains("\"\\nfoo\\n\""));
@ -37,109 +33,85 @@ fn pty_multiline() {
}); });
} }
#[cfg(unix)]
#[test] #[test]
fn pty_unpaired_braces() { fn pty_unpaired_braces() {
use std::io::{Read, Write}; util::with_pty(&["repl"], |mut console| {
run_pty_test(|master| { console.write_line(")");
master.write_all(b")\n").unwrap(); console.write_line("]");
master.write_all(b"]\n").unwrap(); console.write_line("}");
master.write_all(b"}\n").unwrap(); console.write_line("close();");
master.write_all(b"close();\n").unwrap();
let mut output = String::new();
master.read_to_string(&mut output).unwrap();
let output = console.read_all_output();
assert!(output.contains("Unexpected token `)`")); assert!(output.contains("Unexpected token `)`"));
assert!(output.contains("Unexpected token `]`")); assert!(output.contains("Unexpected token `]`"));
assert!(output.contains("Unexpected token `}`")); assert!(output.contains("Unexpected token `}`"));
}); });
} }
#[cfg(unix)]
#[test] #[test]
fn pty_bad_input() { fn pty_bad_input() {
use std::io::{Read, Write}; util::with_pty(&["repl"], |mut console| {
run_pty_test(|master| { console.write_line("'\\u{1f3b5}'[0]");
master.write_all(b"'\\u{1f3b5}'[0]\n").unwrap(); console.write_line("close();");
master.write_all(b"close();\n").unwrap();
let mut output = String::new();
master.read_to_string(&mut output).unwrap();
let output = console.read_all_output();
assert!(output.contains("Unterminated string literal")); assert!(output.contains("Unterminated string literal"));
}); });
} }
#[cfg(unix)]
#[test] #[test]
fn pty_syntax_error_input() { fn pty_syntax_error_input() {
use std::io::{Read, Write}; util::with_pty(&["repl"], |mut console| {
run_pty_test(|master| { console.write_line("('\\u')");
master.write_all(b"('\\u')\n").unwrap(); console.write_line("('");
master.write_all(b"('\n").unwrap(); console.write_line("close();");
master.write_all(b"close();\n").unwrap();
let mut output = String::new();
master.read_to_string(&mut output).unwrap();
let output = console.read_all_output();
assert!(output.contains("Unterminated string constant")); assert!(output.contains("Unterminated string constant"));
assert!(output.contains("Unexpected eof")); assert!(output.contains("Unexpected eof"));
}); });
} }
#[cfg(unix)]
#[test] #[test]
fn pty_complete_symbol() { fn pty_complete_symbol() {
use std::io::{Read, Write}; util::with_pty(&["repl"], |mut console| {
run_pty_test(|master| { console.write_line("Symbol.it\t");
master.write_all(b"Symbol.it\t\n").unwrap(); console.write_line("close();");
master.write_all(b"close();\n").unwrap();
let mut output = String::new();
master.read_to_string(&mut output).unwrap();
let output = console.read_all_output();
assert!(output.contains("Symbol(Symbol.iterator)")); assert!(output.contains("Symbol(Symbol.iterator)"));
}); });
} }
#[cfg(unix)]
#[test] #[test]
fn pty_complete_declarations() { fn pty_complete_declarations() {
use std::io::{Read, Write}; util::with_pty(&["repl"], |mut console| {
run_pty_test(|master| { console.write_line("class MyClass {}");
master.write_all(b"class MyClass {}\n").unwrap(); console.write_line("My\t");
master.write_all(b"My\t\n").unwrap(); console.write_line("let myVar;");
master.write_all(b"let myVar;\n").unwrap(); console.write_line("myV\t");
master.write_all(b"myV\t\n").unwrap(); console.write_line("close();");
master.write_all(b"close();\n").unwrap();
let mut output = String::new();
master.read_to_string(&mut output).unwrap();
let output = console.read_all_output();
assert!(output.contains("> MyClass")); assert!(output.contains("> MyClass"));
assert!(output.contains("> myVar")); assert!(output.contains("> myVar"));
}); });
} }
#[cfg(unix)]
#[test] #[test]
fn pty_complete_primitives() { fn pty_complete_primitives() {
use std::io::{Read, Write}; util::with_pty(&["repl"], |mut console| {
run_pty_test(|master| { console.write_line("let func = function test(){}");
master.write_all(b"let func = function test(){}\n").unwrap(); console.write_line("func.appl\t");
master.write_all(b"func.appl\t\n").unwrap(); console.write_line("let str = ''");
master.write_all(b"let str = ''\n").unwrap(); console.write_line("str.leng\t");
master.write_all(b"str.leng\t\n").unwrap(); console.write_line("false.valueO\t");
master.write_all(b"false.valueO\t\n").unwrap(); console.write_line("5n.valueO\t");
master.write_all(b"5n.valueO\t\n").unwrap(); console.write_line("let num = 5");
master.write_all(b"let num = 5\n").unwrap(); console.write_line("num.toStrin\t");
master.write_all(b"num.toStrin\t\n").unwrap(); console.write_line("close();");
master.write_all(b"close();\n").unwrap();
let mut output = String::new();
master.read_to_string(&mut output).unwrap();
let output = console.read_all_output();
assert!(output.contains("> func.apply")); assert!(output.contains("> func.apply"));
assert!(output.contains("> str.length")); assert!(output.contains("> str.length"));
assert!(output.contains("> 5n.valueOf")); assert!(output.contains("> 5n.valueOf"));
@ -148,17 +120,13 @@ fn pty_complete_primitives() {
}); });
} }
#[cfg(unix)]
#[test] #[test]
fn pty_ignore_symbols() { fn pty_ignore_symbols() {
use std::io::{Read, Write}; util::with_pty(&["repl"], |mut console| {
run_pty_test(|master| { console.write_line("Array.Symbol\t");
master.write_all(b"Array.Symbol\t\n").unwrap(); console.write_line("close();");
master.write_all(b"close();\n").unwrap();
let mut output = String::new();
master.read_to_string(&mut output).unwrap();
let output = console.read_all_output();
assert!(output.contains("undefined")); assert!(output.contains("undefined"));
assert!( assert!(
!output.contains("Uncaught TypeError: Array.Symbol is not a function") !output.contains("Uncaught TypeError: Array.Symbol is not a function")
@ -166,22 +134,6 @@ fn pty_ignore_symbols() {
}); });
} }
#[cfg(unix)]
fn run_pty_test(mut run: impl FnMut(&mut util::pty::fork::Master)) {
use util::pty::fork::*;
let deno_exe = util::deno_exe_path();
let fork = Fork::from_ptmx().unwrap();
if let Ok(mut master) = fork.is_parent() {
run(&mut master);
fork.wait().unwrap();
} else {
std::env::set_var("NO_COLOR", "1");
let err = exec::Command::new(deno_exe).arg("repl").exec();
println!("err {}", err);
unreachable!()
}
}
#[test] #[test]
fn console_log() { fn console_log() {
let (out, err) = util::run_and_collect_output( let (out, err) = util::run_and_collect_output(

View file

@ -335,7 +335,6 @@ itest!(_089_run_allow_list {
output: "089_run_allow_list.ts.out", output: "089_run_allow_list.ts.out",
}); });
#[cfg(unix)]
#[test] #[test]
fn _090_run_permissions_request() { fn _090_run_permissions_request() {
let args = "run --quiet 090_run_permissions_request.ts"; let args = "run --quiet 090_run_permissions_request.ts";
@ -1726,7 +1725,6 @@ mod permissions {
assert!(!err.contains(util::PERMISSION_DENIED_PATTERN)); assert!(!err.contains(util::PERMISSION_DENIED_PATTERN));
} }
#[cfg(unix)]
#[test] #[test]
fn _061_permissions_request() { fn _061_permissions_request() {
let args = "run --quiet 061_permissions_request.ts"; let args = "run --quiet 061_permissions_request.ts";
@ -1742,7 +1740,6 @@ mod permissions {
]); ]);
} }
#[cfg(unix)]
#[test] #[test]
fn _062_permissions_request_global() { fn _062_permissions_request_global() {
let args = "run --quiet 062_permissions_request_global.ts"; let args = "run --quiet 062_permissions_request_global.ts";
@ -1766,7 +1763,6 @@ mod permissions {
output: "064_permissions_revoke_global.ts.out", output: "064_permissions_revoke_global.ts.out",
}); });
#[cfg(unix)]
#[test] #[test]
fn _066_prompt() { fn _066_prompt() {
let args = "run --quiet --unstable 066_prompt.ts"; let args = "run --quiet --unstable 066_prompt.ts";
@ -1861,7 +1857,6 @@ itest!(byte_order_mark {
output: "byte_order_mark.out", output: "byte_order_mark.out",
}); });
#[cfg(unix)]
#[test] #[test]
fn issue9750() { fn issue9750() {
use util::PtyData::*; use util::PtyData::*;

View file

@ -14,6 +14,7 @@ path = "src/test_server.rs"
[dependencies] [dependencies]
anyhow = "1.0.43" anyhow = "1.0.43"
async-stream = "0.3.2" async-stream = "0.3.2"
atty = "0.2.14"
base64 = "0.13.0" base64 = "0.13.0"
futures = "0.3.16" futures = "0.3.16"
hyper = { version = "0.14.12", features = ["server", "http1", "runtime"] } hyper = { version = "0.14.12", features = ["server", "http1", "runtime"] }
@ -29,3 +30,6 @@ tokio-tungstenite = "0.14.0"
[target.'cfg(unix)'.dependencies] [target.'cfg(unix)'.dependencies]
pty = "0.2.2" pty = "0.2.2"
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3.9", features = ["consoleapi", "handleapi", "namedpipeapi", "winbase", "winerror"] }

View file

@ -44,10 +44,8 @@ use tokio_rustls::rustls::{self, Session};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use tokio_tungstenite::accept_async; use tokio_tungstenite::accept_async;
#[cfg(unix)]
pub use pty;
pub mod lsp; pub mod lsp;
pub mod pty;
const PORT: u16 = 4545; const PORT: u16 = 4545;
const TEST_AUTH_TOKEN: &str = "abcdef123456789"; const TEST_AUTH_TOKEN: &str = "abcdef123456789";
@ -1589,62 +1587,97 @@ pub enum PtyData {
Output(&'static str), Output(&'static str),
} }
#[cfg(unix)]
pub fn test_pty2(args: &str, data: Vec<PtyData>) { pub fn test_pty2(args: &str, data: Vec<PtyData>) {
use pty::fork::Fork;
use std::io::BufRead; use std::io::BufRead;
let tests_path = testdata_path(); with_pty(&args.split_whitespace().collect::<Vec<_>>(), |console| {
let fork = Fork::from_ptmx().unwrap(); let mut buf_reader = std::io::BufReader::new(console);
if let Ok(master) = fork.is_parent() { for d in data.iter() {
let mut buf_reader = std::io::BufReader::new(master);
for d in data {
match d { match d {
PtyData::Input(s) => { PtyData::Input(s) => {
println!("INPUT {}", s.escape_debug()); println!("INPUT {}", s.escape_debug());
buf_reader.get_mut().write_all(s.as_bytes()).unwrap(); buf_reader.get_mut().write_text(s);
// Because of tty echo, we should be able to read the same string back. // Because of tty echo, we should be able to read the same string back.
assert!(s.ends_with('\n')); assert!(s.ends_with('\n'));
let mut echo = String::new(); let mut echo = String::new();
buf_reader.read_line(&mut echo).unwrap(); buf_reader.read_line(&mut echo).unwrap();
println!("ECHO: {}", echo.escape_debug()); println!("ECHO: {}", echo.escape_debug());
assert!(echo.starts_with(&s.trim()));
// Windows may also echo the previous line, so only check the end
assert!(normalize_text(&echo).ends_with(&normalize_text(s)));
} }
PtyData::Output(s) => { PtyData::Output(s) => {
let mut line = String::new(); let mut line = String::new();
if s.ends_with('\n') { if s.ends_with('\n') {
buf_reader.read_line(&mut line).unwrap(); buf_reader.read_line(&mut line).unwrap();
} else { } else {
while s != line { // assumes the buffer won't have overlapping virtual terminal sequences
while normalize_text(&line).len() < normalize_text(s).len() {
let mut buf = [0; 64 * 1024]; let mut buf = [0; 64 * 1024];
let _n = buf_reader.read(&mut buf).unwrap(); let bytes_read = buf_reader.read(&mut buf).unwrap();
assert!(bytes_read > 0);
let buf_str = std::str::from_utf8(&buf) let buf_str = std::str::from_utf8(&buf)
.unwrap() .unwrap()
.trim_end_matches(char::from(0)); .trim_end_matches(char::from(0));
line += buf_str; line += buf_str;
assert!(s.starts_with(&line));
} }
} }
println!("OUTPUT {}", line.escape_debug()); println!("OUTPUT {}", line.escape_debug());
assert_eq!(line, s); assert_eq!(normalize_text(&line), normalize_text(s));
} }
} }
} }
});
// This normalization function is not comprehensive
// and may need to updated as new scenarios emerge.
fn normalize_text(text: &str) -> String {
lazy_static! {
static ref MOVE_CURSOR_RIGHT_ONE_RE: Regex =
Regex::new(r"\x1b\[1C").unwrap();
static ref FOUND_SEQUENCES_RE: Regex =
Regex::new(r"(\x1b\]0;[^\x07]*\x07)*(\x08)*(\x1b\[\d+X)*").unwrap();
static ref CARRIAGE_RETURN_RE: Regex =
Regex::new(r"[^\n]*\r([^\n])").unwrap();
}
fork.wait().unwrap(); // any "move cursor right" sequences should just be a space
} else { let text = MOVE_CURSOR_RIGHT_ONE_RE.replace_all(text, " ");
deno_cmd() // replace additional virtual terminal sequences that strip ansi codes doesn't catch
.current_dir(tests_path) let text = FOUND_SEQUENCES_RE.replace_all(&text, "");
.env("NO_COLOR", "1") // strip any ansi codes, which also strips more terminal sequences
.args(args.split_whitespace()) let text = strip_ansi_codes(&text);
.spawn() // get rid of any text that is overwritten with only a carriage return
.unwrap() let text = CARRIAGE_RETURN_RE.replace_all(&text, "$1");
.wait() // finally, trim surrounding whitespace
.unwrap(); text.trim().to_string()
} }
} }
pub fn with_pty(deno_args: &[&str], mut action: impl FnMut(Box<dyn pty::Pty>)) {
if !atty::is(atty::Stream::Stdin) || !atty::is(atty::Stream::Stderr) {
eprintln!("Ignoring non-tty environment.");
return;
}
let deno_dir = new_deno_dir();
let mut env_vars = std::collections::HashMap::new();
env_vars.insert("NO_COLOR".to_string(), "1".to_string());
env_vars.insert(
"DENO_DIR".to_string(),
deno_dir.path().to_string_lossy().to_string(),
);
let pty = pty::create_pty(
&deno_exe_path().to_string_lossy().to_string(),
deno_args,
testdata_path(),
Some(env_vars),
);
action(pty);
}
pub struct WrkOutput { pub struct WrkOutput {
pub latency: f64, pub latency: f64,
pub requests: u64, pub requests: u64,

442
test_util/src/pty.rs Normal file
View file

@ -0,0 +1,442 @@
use std::collections::HashMap;
use std::io::Read;
use std::path::Path;
pub trait Pty: Read {
fn write_text(&mut self, text: &str);
fn write_line(&mut self, text: &str) {
self.write_text(&format!("{}\n", text));
}
/// Reads the output to the EOF.
fn read_all_output(&mut self) -> String {
let mut text = String::new();
self.read_to_string(&mut text).unwrap();
text
}
}
#[cfg(unix)]
pub fn create_pty(
program: impl AsRef<Path>,
args: &[&str],
cwd: impl AsRef<Path>,
env_vars: Option<HashMap<String, String>>,
) -> Box<dyn Pty> {
let fork = pty::fork::Fork::from_ptmx().unwrap();
if fork.is_parent().is_ok() {
Box::new(unix::UnixPty { fork })
} else {
std::process::Command::new(program.as_ref())
.current_dir(cwd)
.args(args)
.envs(env_vars.unwrap_or_default())
.spawn()
.unwrap()
.wait()
.unwrap();
unreachable!();
}
}
#[cfg(unix)]
mod unix {
use std::io::Read;
use std::io::Write;
use super::Pty;
pub struct UnixPty {
pub fork: pty::fork::Fork,
}
impl Drop for UnixPty {
fn drop(&mut self) {
self.fork.wait().unwrap();
}
}
impl Pty for UnixPty {
fn write_text(&mut self, text: &str) {
let mut master = self.fork.is_parent().unwrap();
master.write_all(text.as_bytes()).unwrap();
}
}
impl Read for UnixPty {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let mut master = self.fork.is_parent().unwrap();
master.read(buf)
}
}
}
#[cfg(target_os = "windows")]
pub fn create_pty(
program: impl AsRef<Path>,
args: &[&str],
cwd: impl AsRef<Path>,
env_vars: Option<HashMap<String, String>>,
) -> Box<dyn Pty> {
let pty = windows::WinPseudoConsole::new(
program,
args,
&cwd.as_ref().to_string_lossy().to_string(),
env_vars,
);
Box::new(pty)
}
#[cfg(target_os = "windows")]
mod windows {
use std::collections::HashMap;
use std::io::Read;
use std::io::Write;
use std::path::Path;
use std::ptr;
use std::time::Duration;
use winapi::shared::minwindef::FALSE;
use winapi::shared::minwindef::LPVOID;
use winapi::shared::minwindef::TRUE;
use winapi::shared::winerror::S_OK;
use winapi::um::consoleapi::ClosePseudoConsole;
use winapi::um::consoleapi::CreatePseudoConsole;
use winapi::um::fileapi::ReadFile;
use winapi::um::fileapi::WriteFile;
use winapi::um::handleapi::DuplicateHandle;
use winapi::um::handleapi::INVALID_HANDLE_VALUE;
use winapi::um::namedpipeapi::CreatePipe;
use winapi::um::processthreadsapi::CreateProcessW;
use winapi::um::processthreadsapi::DeleteProcThreadAttributeList;
use winapi::um::processthreadsapi::GetCurrentProcess;
use winapi::um::processthreadsapi::InitializeProcThreadAttributeList;
use winapi::um::processthreadsapi::UpdateProcThreadAttribute;
use winapi::um::processthreadsapi::LPPROC_THREAD_ATTRIBUTE_LIST;
use winapi::um::processthreadsapi::PROCESS_INFORMATION;
use winapi::um::synchapi::WaitForSingleObject;
use winapi::um::winbase::CREATE_UNICODE_ENVIRONMENT;
use winapi::um::winbase::EXTENDED_STARTUPINFO_PRESENT;
use winapi::um::winbase::INFINITE;
use winapi::um::winbase::STARTUPINFOEXW;
use winapi::um::wincontypes::COORD;
use winapi::um::wincontypes::HPCON;
use winapi::um::winnt::DUPLICATE_SAME_ACCESS;
use winapi::um::winnt::HANDLE;
use super::Pty;
macro_rules! assert_win_success {
($expression:expr) => {
let success = $expression;
if success != TRUE {
panic!("{}", std::io::Error::last_os_error().to_string())
}
};
}
pub struct WinPseudoConsole {
stdin_write_handle: WinHandle,
stdout_read_handle: WinHandle,
// keep these alive for the duration of the pseudo console
_process_handle: WinHandle,
_thread_handle: WinHandle,
_attribute_list: ProcThreadAttributeList,
}
impl WinPseudoConsole {
pub fn new(
program: impl AsRef<Path>,
args: &[&str],
cwd: &str,
maybe_env_vars: Option<HashMap<String, String>>,
) -> Self {
// https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session
unsafe {
let mut size: COORD = std::mem::zeroed();
size.X = 800;
size.Y = 500;
let mut console_handle = std::ptr::null_mut();
let (stdin_read_handle, stdin_write_handle) = create_pipe();
let (stdout_read_handle, stdout_write_handle) = create_pipe();
let result = CreatePseudoConsole(
size,
stdin_read_handle.as_raw_handle(),
stdout_write_handle.as_raw_handle(),
0,
&mut console_handle,
);
assert_eq!(result, S_OK);
let mut environment_vars = maybe_env_vars.map(get_env_vars);
let mut attribute_list = ProcThreadAttributeList::new(console_handle);
let mut startup_info: STARTUPINFOEXW = std::mem::zeroed();
startup_info.StartupInfo.cb =
std::mem::size_of::<STARTUPINFOEXW>() as u32;
startup_info.lpAttributeList = attribute_list.as_mut_ptr();
let mut proc_info: PROCESS_INFORMATION = std::mem::zeroed();
let command = format!(
"\"{}\" {}",
program.as_ref().to_string_lossy(),
args.join(" ")
)
.trim()
.to_string();
let mut application_str =
to_windows_str(&program.as_ref().to_string_lossy());
let mut command_str = to_windows_str(&command);
let mut cwd = to_windows_str(cwd);
assert_win_success!(CreateProcessW(
application_str.as_mut_ptr(),
command_str.as_mut_ptr(),
ptr::null_mut(),
ptr::null_mut(),
FALSE,
EXTENDED_STARTUPINFO_PRESENT | CREATE_UNICODE_ENVIRONMENT,
environment_vars
.as_mut()
.map(|v| v.as_mut_ptr() as LPVOID)
.unwrap_or(ptr::null_mut()),
cwd.as_mut_ptr(),
&mut startup_info.StartupInfo,
&mut proc_info,
));
// close the handles that the pseudoconsole now has
drop(stdin_read_handle);
drop(stdout_write_handle);
// start a thread that will close the pseudoconsole on process exit
let thread_handle = WinHandle::new(proc_info.hThread);
std::thread::spawn({
let thread_handle = thread_handle.duplicate();
let console_handle = WinHandle::new(console_handle);
move || {
WaitForSingleObject(thread_handle.as_raw_handle(), INFINITE);
// wait for the reading thread to catch up
std::thread::sleep(Duration::from_millis(200));
// close the console handle which will close the
// stdout pipe for the reader
ClosePseudoConsole(console_handle.into_raw_handle());
}
});
Self {
stdin_write_handle,
stdout_read_handle,
_process_handle: WinHandle::new(proc_info.hProcess),
_thread_handle: thread_handle,
_attribute_list: attribute_list,
}
}
}
}
impl Read for WinPseudoConsole {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
unsafe {
loop {
let mut bytes_read = 0;
let success = ReadFile(
self.stdout_read_handle.as_raw_handle(),
buf.as_mut_ptr() as _,
buf.len() as u32,
&mut bytes_read,
ptr::null_mut(),
);
// ignore zero-byte writes
let is_zero_byte_write = bytes_read == 0 && success == TRUE;
if !is_zero_byte_write {
return Ok(bytes_read as usize);
}
}
}
}
}
impl Pty for WinPseudoConsole {
fn write_text(&mut self, text: &str) {
// windows psuedo console requires a \r\n to do a newline
let newline_re = regex::Regex::new("\r?\n").unwrap();
self
.write_all(newline_re.replace_all(text, "\r\n").as_bytes())
.unwrap();
}
}
impl std::io::Write for WinPseudoConsole {
fn write(&mut self, buffer: &[u8]) -> std::io::Result<usize> {
unsafe {
let mut bytes_written = 0;
assert_win_success!(WriteFile(
self.stdin_write_handle.as_raw_handle(),
buffer.as_ptr() as *const _,
buffer.len() as u32,
&mut bytes_written,
ptr::null_mut(),
));
Ok(bytes_written as usize)
}
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
struct WinHandle {
inner: HANDLE,
}
impl WinHandle {
pub fn new(handle: HANDLE) -> Self {
WinHandle { inner: handle }
}
pub fn duplicate(&self) -> WinHandle {
unsafe {
let process_handle = GetCurrentProcess();
let mut duplicate_handle = ptr::null_mut();
assert_win_success!(DuplicateHandle(
process_handle,
self.inner,
process_handle,
&mut duplicate_handle,
0,
0,
DUPLICATE_SAME_ACCESS,
));
WinHandle::new(duplicate_handle)
}
}
pub fn as_raw_handle(&self) -> HANDLE {
self.inner
}
pub fn into_raw_handle(self) -> HANDLE {
let handle = self.inner;
// skip the drop implementation in order to not close the handle
std::mem::forget(self);
handle
}
}
unsafe impl Send for WinHandle {}
unsafe impl Sync for WinHandle {}
impl Drop for WinHandle {
fn drop(&mut self) {
unsafe {
if !self.inner.is_null() && self.inner != INVALID_HANDLE_VALUE {
winapi::um::handleapi::CloseHandle(self.inner);
}
}
}
}
struct ProcThreadAttributeList {
buffer: Vec<u8>,
}
impl ProcThreadAttributeList {
pub fn new(console_handle: HPCON) -> Self {
unsafe {
// discover size required for the list
let mut size = 0;
let attribute_count = 1;
assert_eq!(
InitializeProcThreadAttributeList(
ptr::null_mut(),
attribute_count,
0,
&mut size
),
FALSE
);
let mut buffer = vec![0u8; size];
let attribute_list_ptr = buffer.as_mut_ptr() as _;
assert_win_success!(InitializeProcThreadAttributeList(
attribute_list_ptr,
attribute_count,
0,
&mut size,
));
const PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE: usize = 0x00020016;
assert_win_success!(UpdateProcThreadAttribute(
attribute_list_ptr,
0,
PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE,
console_handle,
std::mem::size_of::<HPCON>(),
ptr::null_mut(),
ptr::null_mut(),
));
ProcThreadAttributeList { buffer }
}
}
pub fn as_mut_ptr(&mut self) -> LPPROC_THREAD_ATTRIBUTE_LIST {
self.buffer.as_mut_slice().as_mut_ptr() as *mut _
}
}
impl Drop for ProcThreadAttributeList {
fn drop(&mut self) {
unsafe { DeleteProcThreadAttributeList(self.as_mut_ptr()) };
}
}
fn create_pipe() -> (WinHandle, WinHandle) {
unsafe {
let mut read_handle = std::ptr::null_mut();
let mut write_handle = std::ptr::null_mut();
assert_win_success!(CreatePipe(
&mut read_handle,
&mut write_handle,
ptr::null_mut(),
0
));
(WinHandle::new(read_handle), WinHandle::new(write_handle))
}
}
fn to_windows_str(str: &str) -> Vec<u16> {
use std::os::windows::prelude::OsStrExt;
std::ffi::OsStr::new(str)
.encode_wide()
.chain(Some(0))
.collect()
}
fn get_env_vars(env_vars: HashMap<String, String>) -> Vec<u16> {
// See lpEnvironment: https://docs.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-createprocessw
let mut parts = env_vars
.into_iter()
// each environment variable is in the form `name=value\0`
.map(|(key, value)| format!("{}={}\0", key, value))
.collect::<Vec<_>>();
// all strings in an environment block must be case insensitively
// sorted alphabetically by name
// https://docs.microsoft.com/en-us/windows/win32/procthread/changing-environment-variables
parts.sort_by_key(|part| part.to_lowercase());
// the entire block is terminated by NULL (\0)
format!("{}\0", parts.join(""))
.encode_utf16()
.collect::<Vec<_>>()
}
}