Skip to content

Commit

Permalink
Allow custom prompts in .ghci (#93)
Browse files Browse the repository at this point in the history
Previously, `:set prompt` commands in `.ghci` configuration files would
cause `ghcid-ng` to hang on start. This fixes that bug.
  • Loading branch information
9999years authored Sep 28, 2023
1 parent c85df99 commit 5ed6199
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 69 deletions.
9 changes: 8 additions & 1 deletion src/aho_corasick.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ pub trait AhoCorasickExt {
/// Attempt to match at the start of the input.
fn find_at_start(&self, input: &str) -> Option<Match>;

/// Attempt to match anywhere in the input.
fn find_anywhere(&self, input: &str) -> Option<Match>;

/// Build a matcher from the given set of patterns, with anchored matching enabled (matching at
/// the start of the string only).
fn from_anchored_patterns(patterns: impl IntoIterator<Item = impl AsRef<[u8]>>) -> Self;
Expand All @@ -21,9 +24,13 @@ impl AhoCorasickExt for AhoCorasick {
self.find(Input::new(input).anchored(Anchored::Yes))
}

fn find_anywhere(&self, input: &str) -> Option<Match> {
self.find(Input::new(input).anchored(Anchored::No))
}

fn from_anchored_patterns(patterns: impl IntoIterator<Item = impl AsRef<[u8]>>) -> Self {
Self::builder()
.start_kind(StartKind::Anchored)
.start_kind(StartKind::Both)
.build(patterns)
.unwrap()
}
Expand Down
7 changes: 4 additions & 3 deletions src/ghci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,14 @@ impl Ghci {
ret.stderr_handle = stderr;

// Wait for the stdout job to start up.
let messages = ret.stdout.initialize().await?;
ret.process_ghc_messages(messages).await?;
ret.stdout.initialize().await?;

// Perform start-of-session initialization.
ret.stdin
let messages = ret
.stdin
.initialize(&mut ret.stdout, &ret.opts.hooks.after_startup_ghci)
.await?;
ret.process_ghc_messages(messages).await?;

// Sync up for any prompts.
ret.sync().await?;
Expand Down
38 changes: 29 additions & 9 deletions src/ghci/stdin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use tokio::task::JoinSet;
use tracing::instrument;

use crate::haskell_show::HaskellShow;
use crate::incremental_reader::FindAt;
use crate::sync_sentinel::SyncSentinel;

use super::parse::GhcMessage;
Expand All @@ -33,17 +34,32 @@ impl GhciStdin {
/// Write a line on `stdin` and wait for a prompt on stdout.
///
/// The `line` should contain the trailing newline.
///
/// The `find` parameter determines where the prompt can be found in the output line.
#[instrument(skip(self, stdout), level = "debug")]
async fn write_line(
async fn write_line_with_prompt_at(
&mut self,
stdout: &mut GhciStdout,
line: &str,
find: FindAt,
) -> miette::Result<Vec<GhcMessage>> {
self.stdin
.write_all(line.as_bytes())
.await
.into_diagnostic()?;
stdout.prompt(None).await
stdout.prompt(find).await
}

/// Write a line on `stdin` and wait for a prompt on stdout.
///
/// The `line` should contain the trailing newline.
async fn write_line(
&mut self,
stdout: &mut GhciStdout,
line: &str,
) -> miette::Result<Vec<GhcMessage>> {
self.write_line_with_prompt_at(stdout, line, FindAt::LineStart)
.await
}

/// Run a [`GhciCommand`].
Expand All @@ -63,7 +79,7 @@ impl GhciStdin {
.await
.into_diagnostic()?;
self.stdin.write_all(b"\n").await.into_diagnostic()?;
ret.extend(stdout.prompt(None).await?);
ret.extend(stdout.prompt(FindAt::LineStart).await?);
}

Ok(ret)
Expand All @@ -74,10 +90,14 @@ impl GhciStdin {
&mut self,
stdout: &mut GhciStdout,
setup_commands: &[GhciCommand],
) -> miette::Result<()> {
self.set_mode(stdout, Mode::Internal).await?;
self.write_line(stdout, &format!(":set prompt {PROMPT}\n"))
) -> miette::Result<Vec<GhcMessage>> {
// We tell stdout/stderr we're compiling for the first prompt because this includes all the
// module compilation before the first prompt.
self.set_mode(stdout, Mode::Compiling).await?;
let messages = self
.write_line_with_prompt_at(stdout, &format!(":set prompt {PROMPT}\n"), FindAt::Anywhere)
.await?;
self.set_mode(stdout, Mode::Internal).await?;
self.write_line(stdout, &format!(":set prompt-cont {PROMPT}\n"))
.await?;
self.write_line(
Expand All @@ -91,7 +111,7 @@ impl GhciStdin {
self.run_command(stdout, command).await?;
}

Ok(())
Ok(messages)
}

#[instrument(skip_all, level = "debug")]
Expand Down Expand Up @@ -200,15 +220,15 @@ impl GhciStdin {
.write_all(format!(":module + *{module}\n").as_bytes())
.await
.into_diagnostic()?;
stdout.prompt(None).await?;
stdout.prompt(FindAt::LineStart).await?;

self.run_command(stdout, command).await?;

self.stdin
.write_all(format!(":module - *{module}\n").as_bytes())
.await
.into_diagnostic()?;
stdout.prompt(None).await?;
stdout.prompt(FindAt::LineStart).await?;

Ok(())
}
Expand Down
68 changes: 41 additions & 27 deletions src/ghci/stdout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use tokio::sync::oneshot;
use tracing::instrument;

use crate::aho_corasick::AhoCorasickExt;
use crate::incremental_reader::FindAt;
use crate::incremental_reader::IncrementalReader;
use crate::incremental_reader::ReadOpts;
use crate::incremental_reader::WriteBehavior;
use crate::sync_sentinel::SyncSentinel;

Expand Down Expand Up @@ -37,7 +39,7 @@ pub struct GhciStdout {

impl GhciStdout {
#[instrument(skip_all, name = "stdout_initialize", level = "debug")]
pub async fn initialize(&mut self) -> miette::Result<Vec<GhcMessage>> {
pub async fn initialize(&mut self) -> miette::Result<()> {
// Wait for `ghci` to start up. This may involve compiling a bunch of stuff.
let bootup_patterns = AhoCorasick::from_anchored_patterns([
"GHCi, version ",
Expand All @@ -46,36 +48,28 @@ impl GhciStdout {
]);
let data = self
.reader
.read_until(&bootup_patterns, WriteBehavior::Write, &mut self.buffer)
.read_until(&mut ReadOpts {
end_marker: &bootup_patterns,
find: FindAt::LineStart,
writing: WriteBehavior::Write,
buffer: &mut self.buffer,
})
.await?;
tracing::debug!(data, "ghci started, saw version marker");

// We know that we'll get _one_ `ghci> ` prompt on startup.
let init_prompt_patterns = AhoCorasick::from_anchored_patterns(["ghci> "]);
let messages = self.prompt(Some(&init_prompt_patterns)).await?;
tracing::debug!("Saw initial `ghci> ` prompt");

Ok(messages)
Ok(())
}

#[instrument(skip_all, level = "debug")]
pub async fn prompt(
&mut self,
// We usually want this to be `&self.prompt_patterns`, but when we initialize we want to
// pass in a different value. This method takes an `&mut self` reference, so if we try to
// pass in `&self.prompt_patterns` when we call it we get a borrow error because the
// compiler doesn't know we don't mess with `self.prompt_patterns` in here. So we use
// `None` to represent that case and handle the default inline.
prompt_patterns: Option<&AhoCorasick>,
) -> miette::Result<Vec<GhcMessage>> {
let prompt_patterns = prompt_patterns.unwrap_or(&self.prompt_patterns);
pub async fn prompt(&mut self, find: FindAt) -> miette::Result<Vec<GhcMessage>> {
let data = self
.reader
.read_until(
prompt_patterns,
WriteBehavior::NoFinalLine,
&mut self.buffer,
)
.read_until(&mut ReadOpts {
end_marker: &self.prompt_patterns,
find,
writing: WriteBehavior::NoFinalLine,
buffer: &mut self.buffer,
})
.await?;
tracing::debug!(bytes = data.len(), "Got data from ghci");

Expand Down Expand Up @@ -110,12 +104,22 @@ impl GhciStdout {
let sync_pattern = AhoCorasick::from_anchored_patterns([sentinel.to_string()]);
let data = self
.reader
.read_until(&sync_pattern, WriteBehavior::NoFinalLine, &mut self.buffer)
.read_until(&mut ReadOpts {
end_marker: &sync_pattern,
find: FindAt::LineStart,
writing: WriteBehavior::NoFinalLine,
buffer: &mut self.buffer,
})
.await?;
// Then make sure to consume the prompt on the next line, and then we'll be caught up.
let _ = self
.reader
.read_until(&self.prompt_patterns, WriteBehavior::Hide, &mut self.buffer)
.read_until(&mut ReadOpts {
end_marker: &self.prompt_patterns,
find: FindAt::LineStart,
writing: WriteBehavior::Hide,
buffer: &mut self.buffer,
})
.await?;
tracing::debug!(data, "Synced with ghci");

Expand All @@ -127,7 +131,12 @@ impl GhciStdout {
pub async fn show_paths(&mut self) -> miette::Result<ShowPaths> {
let lines = self
.reader
.read_until(&self.prompt_patterns, WriteBehavior::Hide, &mut self.buffer)
.read_until(&mut ReadOpts {
end_marker: &self.prompt_patterns,
find: FindAt::LineStart,
writing: WriteBehavior::Hide,
buffer: &mut self.buffer,
})
.await?;
parse_show_paths(&lines).wrap_err("Failed to parse `:show paths` output")
}
Expand All @@ -136,7 +145,12 @@ impl GhciStdout {
pub async fn show_targets(&mut self, search_paths: &ShowPaths) -> miette::Result<ModuleSet> {
let lines = self
.reader
.read_until(&self.prompt_patterns, WriteBehavior::Hide, &mut self.buffer)
.read_until(&mut ReadOpts {
end_marker: &self.prompt_patterns,
find: FindAt::LineStart,
writing: WriteBehavior::Hide,
buffer: &mut self.buffer,
})
.await?;
let paths = parse_show_targets(search_paths, &lines)
.wrap_err("Failed to parse `:show targets` output")?;
Expand Down
Loading

0 comments on commit 5ed6199

Please sign in to comment.