From 5ed61990da6492b7a21f410ecdf872c0e1dedde4 Mon Sep 17 00:00:00 2001 From: Rebecca Turner Date: Thu, 28 Sep 2023 10:07:04 -0700 Subject: [PATCH] Allow custom prompts in `.ghci` (#93) Previously, `:set prompt` commands in `.ghci` configuration files would cause `ghcid-ng` to hang on start. This fixes that bug. --- src/aho_corasick.rs | 9 +- src/ghci/mod.rs | 7 +- src/ghci/stdin.rs | 38 ++++++-- src/ghci/stdout.rs | 68 ++++++++------ src/incremental_reader.rs | 188 ++++++++++++++++++++++++++++++++------ tests/dot_ghci.rs | 29 ++++++ 6 files changed, 270 insertions(+), 69 deletions(-) create mode 100644 tests/dot_ghci.rs diff --git a/src/aho_corasick.rs b/src/aho_corasick.rs index 347bad2f..f77d5e3f 100644 --- a/src/aho_corasick.rs +++ b/src/aho_corasick.rs @@ -11,6 +11,9 @@ pub trait AhoCorasickExt { /// Attempt to match at the start of the input. fn find_at_start(&self, input: &str) -> Option; + /// Attempt to match anywhere in the input. + fn find_anywhere(&self, input: &str) -> Option; + /// Build a matcher from the given set of patterns, with anchored matching enabled (matching at /// the start of the string only). fn from_anchored_patterns(patterns: impl IntoIterator>) -> Self; @@ -21,9 +24,13 @@ impl AhoCorasickExt for AhoCorasick { self.find(Input::new(input).anchored(Anchored::Yes)) } + fn find_anywhere(&self, input: &str) -> Option { + self.find(Input::new(input).anchored(Anchored::No)) + } + fn from_anchored_patterns(patterns: impl IntoIterator>) -> Self { Self::builder() - .start_kind(StartKind::Anchored) + .start_kind(StartKind::Both) .build(patterns) .unwrap() } diff --git a/src/ghci/mod.rs b/src/ghci/mod.rs index 5c7bfc16..4fe31dcf 100644 --- a/src/ghci/mod.rs +++ b/src/ghci/mod.rs @@ -253,13 +253,14 @@ impl Ghci { ret.stderr_handle = stderr; // Wait for the stdout job to start up. - let messages = ret.stdout.initialize().await?; - ret.process_ghc_messages(messages).await?; + ret.stdout.initialize().await?; // Perform start-of-session initialization. - ret.stdin + let messages = ret + .stdin .initialize(&mut ret.stdout, &ret.opts.hooks.after_startup_ghci) .await?; + ret.process_ghc_messages(messages).await?; // Sync up for any prompts. ret.sync().await?; diff --git a/src/ghci/stdin.rs b/src/ghci/stdin.rs index cfbf3d7d..b5dd7948 100644 --- a/src/ghci/stdin.rs +++ b/src/ghci/stdin.rs @@ -10,6 +10,7 @@ use tokio::task::JoinSet; use tracing::instrument; use crate::haskell_show::HaskellShow; +use crate::incremental_reader::FindAt; use crate::sync_sentinel::SyncSentinel; use super::parse::GhcMessage; @@ -33,17 +34,32 @@ impl GhciStdin { /// Write a line on `stdin` and wait for a prompt on stdout. /// /// The `line` should contain the trailing newline. + /// + /// The `find` parameter determines where the prompt can be found in the output line. #[instrument(skip(self, stdout), level = "debug")] - async fn write_line( + async fn write_line_with_prompt_at( &mut self, stdout: &mut GhciStdout, line: &str, + find: FindAt, ) -> miette::Result> { self.stdin .write_all(line.as_bytes()) .await .into_diagnostic()?; - stdout.prompt(None).await + stdout.prompt(find).await + } + + /// Write a line on `stdin` and wait for a prompt on stdout. + /// + /// The `line` should contain the trailing newline. + async fn write_line( + &mut self, + stdout: &mut GhciStdout, + line: &str, + ) -> miette::Result> { + self.write_line_with_prompt_at(stdout, line, FindAt::LineStart) + .await } /// Run a [`GhciCommand`]. @@ -63,7 +79,7 @@ impl GhciStdin { .await .into_diagnostic()?; self.stdin.write_all(b"\n").await.into_diagnostic()?; - ret.extend(stdout.prompt(None).await?); + ret.extend(stdout.prompt(FindAt::LineStart).await?); } Ok(ret) @@ -74,10 +90,14 @@ impl GhciStdin { &mut self, stdout: &mut GhciStdout, setup_commands: &[GhciCommand], - ) -> miette::Result<()> { - self.set_mode(stdout, Mode::Internal).await?; - self.write_line(stdout, &format!(":set prompt {PROMPT}\n")) + ) -> miette::Result> { + // We tell stdout/stderr we're compiling for the first prompt because this includes all the + // module compilation before the first prompt. + self.set_mode(stdout, Mode::Compiling).await?; + let messages = self + .write_line_with_prompt_at(stdout, &format!(":set prompt {PROMPT}\n"), FindAt::Anywhere) .await?; + self.set_mode(stdout, Mode::Internal).await?; self.write_line(stdout, &format!(":set prompt-cont {PROMPT}\n")) .await?; self.write_line( @@ -91,7 +111,7 @@ impl GhciStdin { self.run_command(stdout, command).await?; } - Ok(()) + Ok(messages) } #[instrument(skip_all, level = "debug")] @@ -200,7 +220,7 @@ impl GhciStdin { .write_all(format!(":module + *{module}\n").as_bytes()) .await .into_diagnostic()?; - stdout.prompt(None).await?; + stdout.prompt(FindAt::LineStart).await?; self.run_command(stdout, command).await?; @@ -208,7 +228,7 @@ impl GhciStdin { .write_all(format!(":module - *{module}\n").as_bytes()) .await .into_diagnostic()?; - stdout.prompt(None).await?; + stdout.prompt(FindAt::LineStart).await?; Ok(()) } diff --git a/src/ghci/stdout.rs b/src/ghci/stdout.rs index 23e6defc..1aba4bd5 100644 --- a/src/ghci/stdout.rs +++ b/src/ghci/stdout.rs @@ -8,7 +8,9 @@ use tokio::sync::oneshot; use tracing::instrument; use crate::aho_corasick::AhoCorasickExt; +use crate::incremental_reader::FindAt; use crate::incremental_reader::IncrementalReader; +use crate::incremental_reader::ReadOpts; use crate::incremental_reader::WriteBehavior; use crate::sync_sentinel::SyncSentinel; @@ -37,7 +39,7 @@ pub struct GhciStdout { impl GhciStdout { #[instrument(skip_all, name = "stdout_initialize", level = "debug")] - pub async fn initialize(&mut self) -> miette::Result> { + pub async fn initialize(&mut self) -> miette::Result<()> { // Wait for `ghci` to start up. This may involve compiling a bunch of stuff. let bootup_patterns = AhoCorasick::from_anchored_patterns([ "GHCi, version ", @@ -46,36 +48,28 @@ impl GhciStdout { ]); let data = self .reader - .read_until(&bootup_patterns, WriteBehavior::Write, &mut self.buffer) + .read_until(&mut ReadOpts { + end_marker: &bootup_patterns, + find: FindAt::LineStart, + writing: WriteBehavior::Write, + buffer: &mut self.buffer, + }) .await?; tracing::debug!(data, "ghci started, saw version marker"); - // We know that we'll get _one_ `ghci> ` prompt on startup. - let init_prompt_patterns = AhoCorasick::from_anchored_patterns(["ghci> "]); - let messages = self.prompt(Some(&init_prompt_patterns)).await?; - tracing::debug!("Saw initial `ghci> ` prompt"); - - Ok(messages) + Ok(()) } #[instrument(skip_all, level = "debug")] - pub async fn prompt( - &mut self, - // We usually want this to be `&self.prompt_patterns`, but when we initialize we want to - // pass in a different value. This method takes an `&mut self` reference, so if we try to - // pass in `&self.prompt_patterns` when we call it we get a borrow error because the - // compiler doesn't know we don't mess with `self.prompt_patterns` in here. So we use - // `None` to represent that case and handle the default inline. - prompt_patterns: Option<&AhoCorasick>, - ) -> miette::Result> { - let prompt_patterns = prompt_patterns.unwrap_or(&self.prompt_patterns); + pub async fn prompt(&mut self, find: FindAt) -> miette::Result> { let data = self .reader - .read_until( - prompt_patterns, - WriteBehavior::NoFinalLine, - &mut self.buffer, - ) + .read_until(&mut ReadOpts { + end_marker: &self.prompt_patterns, + find, + writing: WriteBehavior::NoFinalLine, + buffer: &mut self.buffer, + }) .await?; tracing::debug!(bytes = data.len(), "Got data from ghci"); @@ -110,12 +104,22 @@ impl GhciStdout { let sync_pattern = AhoCorasick::from_anchored_patterns([sentinel.to_string()]); let data = self .reader - .read_until(&sync_pattern, WriteBehavior::NoFinalLine, &mut self.buffer) + .read_until(&mut ReadOpts { + end_marker: &sync_pattern, + find: FindAt::LineStart, + writing: WriteBehavior::NoFinalLine, + buffer: &mut self.buffer, + }) .await?; // Then make sure to consume the prompt on the next line, and then we'll be caught up. let _ = self .reader - .read_until(&self.prompt_patterns, WriteBehavior::Hide, &mut self.buffer) + .read_until(&mut ReadOpts { + end_marker: &self.prompt_patterns, + find: FindAt::LineStart, + writing: WriteBehavior::Hide, + buffer: &mut self.buffer, + }) .await?; tracing::debug!(data, "Synced with ghci"); @@ -127,7 +131,12 @@ impl GhciStdout { pub async fn show_paths(&mut self) -> miette::Result { let lines = self .reader - .read_until(&self.prompt_patterns, WriteBehavior::Hide, &mut self.buffer) + .read_until(&mut ReadOpts { + end_marker: &self.prompt_patterns, + find: FindAt::LineStart, + writing: WriteBehavior::Hide, + buffer: &mut self.buffer, + }) .await?; parse_show_paths(&lines).wrap_err("Failed to parse `:show paths` output") } @@ -136,7 +145,12 @@ impl GhciStdout { pub async fn show_targets(&mut self, search_paths: &ShowPaths) -> miette::Result { let lines = self .reader - .read_until(&self.prompt_patterns, WriteBehavior::Hide, &mut self.buffer) + .read_until(&mut ReadOpts { + end_marker: &self.prompt_patterns, + find: FindAt::LineStart, + writing: WriteBehavior::Hide, + buffer: &mut self.buffer, + }) .await?; let paths = parse_show_targets(search_paths, &lines) .wrap_err("Failed to parse `:show targets` output")?; diff --git a/src/incremental_reader.rs b/src/incremental_reader.rs index 29c771f0..4aace738 100644 --- a/src/incremental_reader.rs +++ b/src/incremental_reader.rs @@ -71,14 +71,9 @@ where /// /// TODO: Should this even use `aho_corasick`? Might be overkill, and with the automaton /// construction cost it might not even be more efficient. - pub async fn read_until( - &mut self, - end_marker: &AhoCorasick, - writing: WriteBehavior, - buffer: &mut [u8], - ) -> miette::Result { + pub async fn read_until(&mut self, opts: &mut ReadOpts<'_>) -> miette::Result { loop { - if let Some(lines) = self.try_read_until(end_marker, writing, buffer).await? { + if let Some(lines) = self.try_read_until(opts).await? { return Ok(lines); } } @@ -89,30 +84,28 @@ where /// returned. Otherwise, nothing is returned. pub async fn try_read_until( &mut self, - end_marker: &AhoCorasick, - writing: WriteBehavior, - buffer: &mut [u8], + opts: &mut ReadOpts<'_>, ) -> miette::Result> { - if let Some(chunk) = self.take_chunk_from_buffer(end_marker) { + if let Some(chunk) = self.take_chunk_from_buffer(opts) { tracing::trace!(data = chunk.len(), "Got data from buffer"); return Ok(Some(chunk)); } - match self.reader.read(buffer).await { + match self.reader.read(opts.buffer).await { Ok(0) => { // EOF Err(miette!("End-of-file reached")) } Ok(n) => { - let decoded = std::str::from_utf8(&buffer[..n]) + let decoded = std::str::from_utf8(&opts.buffer[..n]) .into_diagnostic() .wrap_err_with(|| { format!( "Read invalid UTF-8: {:?}", - String::from_utf8_lossy(&buffer[..n]) + String::from_utf8_lossy(&opts.buffer[..n]) ) })?; - match self.consume_str(decoded, end_marker, writing).await? { + match self.consume_str(decoded, opts).await? { Some(lines) => { tracing::trace!(data = decoded, "Decoded data"); tracing::trace!(lines = lines.len(), "Got chunk"); @@ -138,8 +131,7 @@ where async fn consume_str( &mut self, mut data: &str, - end_marker: &AhoCorasick, - writing: WriteBehavior, + opts: &ReadOpts<'_>, ) -> miette::Result> { // Proof of this function's corectness: just trust me @@ -155,11 +147,11 @@ where ret = match ret { Some(lines) => Some(lines), None => { - match end_marker.find_at_start(&self.line) { + match opts.find(opts.end_marker, &self.line) { Some(_match) => { // If we found an `end_marker` in `self.line`, our chunk is // `self.lines`. - Some(self.take_lines(writing).await?) + Some(self.take_lines(opts.writing).await?) } None => None, } @@ -179,22 +171,22 @@ where // We already have a chunk to return, so we can just add the current // line to `self.lines` and continue to process the remaining data in // `rest`. - self.finish_line(writing).await?; + self.finish_line(opts.writing).await?; Some(lines) } None => { // We don't have a chunk to return yet, so check for an `end_marker`. - match end_marker.find_at_start(&self.line) { + match opts.find(opts.end_marker, &self.line) { Some(_match) => { // If we found an `end_marker` in `self.line`, our chunk is // `self.lines`. - Some(self.take_lines(writing).await?) + Some(self.take_lines(opts.writing).await?) } None => { // We didn't find an `end_marker`, so add the current line to // `self.lines` and continue to process the remaining data in // `rest. - self.finish_line(writing).await?; + self.finish_line(opts.writing).await?; None } } @@ -260,12 +252,12 @@ where /// seen, the lines before the marker are returned. Otherwise, nothing is returned. /// /// Does _not_ read from the underlying reader. - fn take_chunk_from_buffer(&mut self, end_marker: &AhoCorasick) -> Option { + fn take_chunk_from_buffer(&mut self, opts: &ReadOpts<'_>) -> Option { // Do any of the lines in `self.lines` start with `end_marker`? if let Some(span) = self .lines .line_spans() - .find(|span| end_marker.find_at_start(span.as_str()).is_some()) + .find(|span| opts.find(opts.end_marker, span.as_str()).is_some()) { // Suppose this is our original `self.lines`, with newlines indicated by `|`: // @@ -290,7 +282,7 @@ where } // Does the current line in `self.line` start with `end_marker`? - if end_marker.find_at_start(&self.line).is_some() { + if opts.find(opts.end_marker, &self.line).is_some() { let chunk = std::mem::replace( &mut self.lines, String::with_capacity(VEC_BUFFER_CAPACITY * LINE_BUFFER_CAPACITY), @@ -336,6 +328,38 @@ pub enum WriteBehavior { Hide, } +/// Determines where an [`IncrementalReader`] matches an [`AhoCorasick`] end marker. +#[derive(Clone, Copy, Debug)] +pub enum FindAt { + /// Match only at the start of a line. + LineStart, + /// Match anywhere in a line. + Anywhere, +} + +/// Options for performing a read from an [`IncrementalReader`]. +#[derive(Debug)] +pub struct ReadOpts<'a> { + /// The end marker to look for. + pub end_marker: &'a AhoCorasick, + /// Where the end marker should be looked for. + pub find: FindAt, + /// How to write output to the wrapped writer. + pub writing: WriteBehavior, + /// A buffer to read input into. This is used to avoid allocating additional buffers; no + /// particular constraints are placed on this buffer. + pub buffer: &'a mut [u8], +} + +impl<'a> ReadOpts<'a> { + fn find(&self, marker: &AhoCorasick, input: &str) -> Option { + match self.find { + FindAt::LineStart => marker.find_at_start(input), + FindAt::Anywhere => marker.find_anywhere(input), + } + } +} + #[cfg(test)] mod tests { use indoc::indoc; @@ -365,7 +389,51 @@ mod tests { assert_eq!( reader - .read_until(&end_marker, WriteBehavior::Hide, &mut buffer) + .read_until(&mut ReadOpts { + end_marker: &end_marker, + find: FindAt::LineStart, + writing: WriteBehavior::Hide, + buffer: &mut buffer, + }) + .await + .unwrap(), + indoc!( + " + Build profile: -w ghc-9.6.1 -O0 + In order, the following will be built (use -v for more details): + - mwb-0 (lib:test-dev) (ephemeral targets) + Preprocessing library 'test-dev' for mwb-0.. + " + ) + ); + } + + /// Same as `test_read_until` but with `FindAt::Anywhere`. + #[tokio::test] + async fn test_read_until_find_anywhere() { + let fake_reader = FakeReader::with_str_chunks([indoc!( + "Build profile: -w ghc-9.6.1 -O0 + In order, the following will be built (use -v for more details): + - mwb-0 (lib:test-dev) (ephemeral targets) + Preprocessing library 'test-dev' for mwb-0.. + GHCi, version 9.6.1: https://www.haskell.org/ghc/ :? for help + Loaded GHCi configuration from .ghci-mwb + Ok, 5699 modules loaded. + ghci> " + )]); + + let mut reader = IncrementalReader::new(fake_reader).with_writer(tokio::io::sink()); + let end_marker = AhoCorasick::from_anchored_patterns(["https://www.haskell.org/ghc/"]); + let mut buffer = vec![0; LINE_BUFFER_CAPACITY]; + + assert_eq!( + reader + .read_until(&mut ReadOpts { + end_marker: &end_marker, + find: FindAt::Anywhere, + writing: WriteBehavior::Hide, + buffer: &mut buffer, + }) .await .unwrap(), indoc!( @@ -404,7 +472,64 @@ mod tests { assert_eq!( reader - .read_until(&end_marker, WriteBehavior::Hide, &mut buffer) + .read_until(&mut ReadOpts { + end_marker: &end_marker, + find: FindAt::LineStart, + writing: WriteBehavior::Hide, + buffer: &mut buffer + }) + .await + .unwrap(), + indoc!( + " + Build profile: -w ghc-9.6.1 -O0 + In order, the following will be built (use -v for more details): + - mwb-0 (lib:test-dev) (ephemeral targets) + Preprocessing library 'test-dev' for mwb-0.. + " + ) + ); + + eprintln!("{:?}", reader.buffer()); + assert_eq!( + reader.buffer(), + indoc!( + "Loaded GHCi configuration from .ghci-mwb + Ok, 5699 modules loaded. + ghci> " + ) + ); + } + + #[tokio::test] + async fn test_read_until_with_data_in_buffer_find_anywhere() { + let fake_reader = FakeReader::default(); + let mut reader = IncrementalReader::new(fake_reader).with_writer(tokio::io::sink()); + reader + .push_to_buffer(indoc!( + " + Build profile: -w ghc-9.6.1 -O0 + In order, the following will be built (use -v for more details): + - mwb-0 (lib:test-dev) (ephemeral targets) + Preprocessing library 'test-dev' for mwb-0.. + GHCi, version 9.6.1: https://www.haskell.org/ghc/ :? for help + Loaded GHCi configuration from .ghci-mwb + Ok, 5699 modules loaded. + ghci> " + )) + .await; + + let end_marker = AhoCorasick::from_anchored_patterns(["https://www.haskell.org/ghc/"]); + let mut buffer = vec![0; LINE_BUFFER_CAPACITY]; + + assert_eq!( + reader + .read_until(&mut ReadOpts { + end_marker: &end_marker, + find: FindAt::Anywhere, + writing: WriteBehavior::Hide, + buffer: &mut buffer + }) .await .unwrap(), indoc!( @@ -462,7 +587,12 @@ mod tests { assert_eq!( reader - .read_until(&end_marker, WriteBehavior::Hide, &mut buffer) + .read_until(&mut ReadOpts { + end_marker: &end_marker, + find: FindAt::LineStart, + writing: WriteBehavior::Hide, + buffer: &mut buffer + }) .await .unwrap(), indoc!( diff --git a/tests/dot_ghci.rs b/tests/dot_ghci.rs new file mode 100644 index 00000000..726d9989 --- /dev/null +++ b/tests/dot_ghci.rs @@ -0,0 +1,29 @@ +use test_harness::fs; +use test_harness::test; +use test_harness::GhcidNgBuilder; + +use indoc::indoc; + +/// Test that `ghcid-ng` can run with a custom prompt in `.ghci`. +#[test] +async fn can_run_with_custom_ghci_prompt() { + let mut session = GhcidNgBuilder::new("tests/data/simple") + .before_start(|project| async move { + fs::write( + project.join(".ghci"), + indoc!( + r#" + :set prompt "λ " + :set prompt-cont "│ " + "# + ), + ) + .await?; + Ok(()) + }) + .start() + .await + .expect("ghcid-ng starts"); + + session.wait_until_ready().await.unwrap(); +}