diff --git a/Cargo.lock b/Cargo.lock index c4a5784..83126b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -209,6 +209,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "console" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width 0.1.14", + "windows-sys 0.52.0", +] + [[package]] name = "crossbeam-channel" version = "0.5.13" @@ -218,6 +231,16 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + [[package]] name = "crossbeam-epoch" version = "0.9.18" @@ -249,7 +272,9 @@ dependencies = [ "bytes", "clap", "domain", + "indicatif", "lexopt", + "rayon", "tempfile", "test_bin", ] @@ -257,7 +282,7 @@ dependencies = [ [[package]] name = "domain" version = "0.10.3" -source = "git+https://github.com/NLnetLabs/domain?branch=support-zonefile-fmt-with-padding#32fb700604d48fd3483a9aea1fa3fcfd71318379" +source = "git+https://github.com/NLnetLabs/domain?branch=report-signing-progress#bc14779780e6f6b3ce4c317526335ce3ff3f334b" dependencies = [ "bytes", "futures-util", @@ -276,6 +301,18 @@ dependencies = [ "tracing", ] +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "errno" version = "0.3.9" @@ -403,6 +440,19 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "indicatif" +version = "0.17.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width 0.2.0", + "web-time", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -418,6 +468,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "lexopt" version = "0.3.0" @@ -509,6 +565,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "object" version = "0.36.5" @@ -619,6 +681,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +[[package]] +name = "portable-atomic" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" + [[package]] name = "powerfmt" version = "0.2.0" @@ -706,6 +774,26 @@ dependencies = [ "bitflags", ] +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.7" @@ -993,6 +1081,18 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + +[[package]] +name = "unicode-width" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" + [[package]] name = "untrusted" version = "0.9.0" @@ -1091,6 +1191,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index ce0fe59..cbdfe2c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ path = "src/bin/ldns.rs" [dependencies] bytes = { version = "1.1", default-features = false } clap = { version = "4.3.4", features = ["cargo", "derive"] } -domain = { git = "https://github.com/NLnetLabs/domain", branch = "support-zonefile-fmt-with-padding", features = [ +domain = { git = "https://github.com/NLnetLabs/domain", branch = "report-signing-progress", features = [ "bytes", "openssl", "ring", @@ -21,7 +21,9 @@ domain = { git = "https://github.com/NLnetLabs/domain", branch = "support-zonefi "unstable-zonetree", "zonefile", ] } +indicatif = { version = "0.17.9" } lexopt = "0.3.0" +rayon = "1.10.0" [dev-dependencies] test_bin = "0.4.0" diff --git a/src/commands/key2ds.rs b/src/commands/key2ds.rs index e14f8db..77b596c 100644 --- a/src/commands/key2ds.rs +++ b/src/commands/key2ds.rs @@ -140,7 +140,7 @@ impl Key2ds { })?; // We only care about records in a zonefile - let Entry::Record(record) = entry else { + let Entry::Record(record, _) = entry else { continue; }; diff --git a/src/commands/signzone.rs b/src/commands/signzone.rs index 8c9ebed..6cc2699 100644 --- a/src/commands/signzone.rs +++ b/src/commands/signzone.rs @@ -31,7 +31,11 @@ use domain::validate::Key; use domain::zonefile::inplace::{self, Entry}; use domain::zonetree::types::StoredRecordData; use domain::zonetree::{StoredName, StoredRecord}; +use indicatif::{ + HumanBytes, HumanDuration, ProgressBar, ProgressDrawTarget, ProgressState, ProgressStyle, +}; use lexopt::Arg; +use rayon::slice::ParallelSliceMut; use crate::env::{Env, Stream}; use crate::error::Error; @@ -42,6 +46,8 @@ use super::{parse_os, parse_os_with, LdnsCommand}; //------------ Constants ----------------------------------------------------- const FOUR_WEEKS: u32 = 2419200; +const PROGRESS_STYLE: &str = + "{msg} {spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({eta})"; //------------ SignZone ------------------------------------------------------ @@ -195,6 +201,14 @@ pub struct SignZone { #[arg(short = 'M', default_value_t = false)] no_require_keys_match_apex: bool, + /// Show progress bars. + #[arg(long = "progress", default_value_t = false)] + progress: bool, + + /// Show verbose output. + #[arg(long = "verbose", default_value_t = false)] + verbose: bool, + // ----------------------------------------------------------------------- // Original ldns-signzone positional arguments in position order: // ----------------------------------------------------------------------- @@ -354,6 +368,8 @@ impl LdnsCommand for SignZone { zonefile_path, key_paths, no_require_keys_match_apex: false, + progress: false, + verbose: false, }) } } @@ -412,12 +428,14 @@ impl SignZone { .map_err(|err| format!("Cannot write to {out_file}: {err}"))? }; - let mut writer = if out_file.as_os_str() == "-" { - FileOrStdout::Stdout(env.stdout()) + let (mut writer, mut log) = if out_file.as_os_str() == "-" { + // Don't allow logging to stdout when the output zone will be + // written to stdout. + (FileOrStdout::Stdout(env.stdout()), None) } else { let file = File::create(env.in_cwd(&out_file))?; let file = BufWriter::new(file); - FileOrStdout::File(file) + (FileOrStdout::File(file), Some(env.stdout())) }; // Import the specified keys and check that the keys are all for the same zone. @@ -430,7 +448,7 @@ impl SignZone { let mut first_key_owner = None; for key_path in &self.key_paths { - let key = Self::load_key_pair(key_path)?; + let key = self.load_key_pair(key_path, &mut log)?; if !self.no_require_keys_match_apex { if first_key_owner.is_none() { @@ -462,7 +480,8 @@ impl SignZone { } // Read the zone file. - let mut records = self.load_zone(&env, first_key_owner.as_ref())?; + let mut records = self.load_zone(&env, first_key_owner.as_ref(), &mut log)?; + let n_records = records.len(); // Change the SOA serial. if self.set_soa_serial_to_epoch_time { @@ -473,6 +492,21 @@ impl SignZone { let (apex, ttl) = Self::find_apex(&records).unwrap(); // Hash the zone with NSEC or NSEC3. + let pb = if matches!(&log, Some(log) if log.is_terminal() && self.progress) { + let pb = ProgressBar::new(0).with_message("Hashing"); + pb.set_draw_target(ProgressDrawTarget::stderr_with_hz(1)); + pb.set_style( + ProgressStyle::with_template(PROGRESS_STYLE) + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| { + write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap() + }) + .progress_chars("#>-"), + ); + Some(pb) + } else { + None + }; let hashes = if self.use_nsec3 { let params = Nsec3param::new(self.algorithm, 0, self.iterations, self.salt.clone()); let Nsec3Records { @@ -480,13 +514,26 @@ impl SignZone { param, hashes, } = records - .nsec3s::<_, BytesMut>( + .nsec3s::<_, BytesMut, _>( &apex, ttl, params, opt_out, !self.do_not_add_keys_to_zone, self.extra_comments, + |inc_pos, inc_len, new_phase| { + if let Some(pb) = &pb { + if inc_len > 0 { + pb.inc_length(inc_len as u64); + } + if inc_pos > 0 { + pb.inc(inc_pos as u64); + } + if let Some(new_phase) = new_phase { + pb.set_message(new_phase); + } + } + }, ) .unwrap(); records.extend(recs.into_iter().map(Record::from_record)); @@ -498,8 +545,39 @@ impl SignZone { None }; + if let Some(pb) = pb { + let len = pb.length().unwrap(); + let elapsed = pb.elapsed(); + pb.finish_and_clear(); + if self.verbose { + if let Some(log) = &mut log { + writeln!( + log, + "Hashed {n_records} records in {len} steps in {}", + HumanDuration(elapsed) + ); + } + } + } + // Sign the zone unless disabled. if signing_mode == SigningMode::HashAndSign { + let n_records = records.len(); + let pb = if matches!(&log, Some(log) if log.is_terminal() && self.progress) { + let pb = ProgressBar::new(0).with_message("Signing"); + pb.set_draw_target(ProgressDrawTarget::stderr_with_hz(1)); + pb.set_style( + ProgressStyle::with_template(PROGRESS_STYLE) + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| { + write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap() + }) + .progress_chars("#>-"), + ); + Some(pb) + } else { + None + }; let extra_records = records .sign( &apex, @@ -507,11 +585,57 @@ impl SignZone { self.inception, keys.as_slice(), !self.do_not_add_keys_to_zone, + |inc_pos, inc_len, new_phase| { + if let Some(pb) = &pb { + if inc_len > 0 { + pb.inc_length(inc_len as u64); + } + if inc_pos > 0 { + pb.inc(inc_pos as u64); + } + if let Some(new_phase) = new_phase { + pb.set_message(new_phase); + } + } + }, ) .unwrap(); records.extend(extra_records.into_iter().map(Record::from_record)); + + if let Some(pb) = pb { + let len = pb.length().unwrap(); + let elapsed = pb.elapsed(); + pb.finish_and_clear(); + if self.verbose { + if let Some(log) = &mut log { + writeln!( + log, + "Signed {n_records} records in {len} steps in {}", + HumanDuration(elapsed) + ); + } + } + } } + let n_records = records.len(); + let pb = if matches!(&log, Some(log) if log.is_terminal() && self.progress) { + let num_families = records.families().count(); + let pb = ProgressBar::new(num_families as u64); + pb.set_draw_target(ProgressDrawTarget::stderr_with_hz(1)); + pb.set_style( + ProgressStyle::with_template(PROGRESS_STYLE) + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| { + write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap() + }) + .progress_chars("#>-"), + ); + Some(pb) + } else { + None + }; + // Output the resulting zone, with comments if enabled. if self.extra_comments { writer.write_fmt(format_args!(";; Zone: {}\n;\n", apex.owner()))?; @@ -531,6 +655,9 @@ impl SignZone { if self.extra_comments { writer.write_str(";\n")?; } + if let Some(pb) = &pb { + pb.inc(1); + } } let hashes_ref = hashes.as_ref(); @@ -550,6 +677,9 @@ impl SignZone { // data so it's indiividual are not the records themselves. let mut families; let family_iter: AnyFamiliesIter = if self.extra_comments && hashes_ref.is_some() { + if let Some(pb) = &pb { + pb.set_message("Applying diagnostic ordering"); + } families = records.families().collect::>(); let Some(hashes) = hashes_ref else { unreachable!(); @@ -585,6 +715,9 @@ impl SignZone { records.families().into() }; + if let Some(pb) = &pb { + pb.set_message("Saving"); + } for family in family_iter { if let Some(hashes_ref) = hashes_ref { // If this is family contains an NSEC3 RR and the number of @@ -647,6 +780,25 @@ impl SignZone { } } } + + if let Some(pb) = &pb { + pb.inc(1); + } + } + + if let Some(pb) = pb { + let len = pb.length().unwrap(); + let elapsed = pb.elapsed(); + pb.finish_and_clear(); + if self.verbose { + if let Some(log) = &mut log { + writeln!( + log, + "Saved {n_records} records in {len} steps in {}", + HumanDuration(elapsed) + ); + } + } } Ok(()) @@ -656,7 +808,14 @@ impl SignZone { &self, env: &impl Env, expected_apex: Option<&Name>, - ) -> Result, Error> { + log: &mut Option>, + ) -> Result, Error> { + if self.verbose { + if let Some(log) = log { + writeln!(log, "Loading zone from '{}'", self.zonefile_path.display()); + } + } + // Don't use Zonefile::load() as it knows nothing about the size of // the original file so uses default allocation which allocates more // bytes than are needed. Instead control the allocation size based on @@ -666,16 +825,37 @@ impl SignZone { let mut buf = inplace::Zonefile::with_capacity(zone_file_len as usize).writer(); std::io::copy(&mut zone_file, &mut buf)?; let mut reader = buf.into_inner(); - let mut records = SortedRecords::new(); + let n_records = reader.len() as u64; if let Some(origin) = &self.origin { reader.set_origin(origin.clone()); } + let pb = if matches!(log, Some(log) if log.is_terminal() && self.progress) { + let pb = ProgressBar::new(n_records).with_message("Parsing"); + pb.set_draw_target(ProgressDrawTarget::stderr_with_hz(1)); + pb.set_style( + ProgressStyle::with_template(PROGRESS_STYLE) + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| { + write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap() + }) + .progress_chars("#>-"), + ); + Some(pb) + } else { + None + }; + + // Push records to an unsorted vec, then sort at the end, as this is faster than + // sorting one record at a time. + let mut records = Vec::>::new(); + for entry in reader { let entry = entry.map_err(|err| format!("Invalid zone file: {err}"))?; + match entry { - Entry::Record(record) => { + Entry::Record(record, pos) => { let record: StoredRecord = record.flatten_into(); if let Some(expected_apex) = expected_apex { if record.rtype() == Rtype::SOA && record.owner() != expected_apex { @@ -687,22 +867,47 @@ impl SignZone { } } - records.insert(record).map_err(|record| { - format!("Invalid zone file: Duplicate record detected: {record:?}") - })?; + records.push(record); + if let Some(pb) = &pb { + pb.set_position(pos as u64); + } } Entry::Include { .. } => { - return Err(Error::from( - "Invalid zone file: $INCLUDE directive is not supported", - )); + return Err("Invalid zone file: $INCLUDE directive is not supported".into()); + } + } + } + + if let Some(pb) = &pb { + pb.set_message("Sorting"); + } + + // Use a multi-threaded parallel sorter to sort our unsorted vec into + // a `SortedRecords` type. + let records = SortedRecords::<_, _, MultiThreadedSorter>::from(records); + + if let Some(pb) = pb { + let len = pb.length().unwrap(); + let elapsed = pb.elapsed(); + pb.finish_and_clear(); + if self.verbose { + if let Some(log) = log { + writeln!( + log, + "Loaded {len} records from {} [{} bytes] in {}", + self.zonefile_path.display(), + HumanBytes(zone_file_len), + HumanDuration(elapsed) + ); } } } + Ok(records) } fn find_apex( - records: &SortedRecords, + records: &SortedRecords, ) -> Result<(FamilyName>, Ttl), Error> { let soa = match records.find_soa() { Some(soa) => soa, @@ -726,7 +931,11 @@ impl SignZone { } fn bump_soa_serial( - records: &mut SortedRecords, ZoneRecordData>>, + records: &mut SortedRecords< + Name, + ZoneRecordData>, + MultiThreadedSorter, + >, ) -> Result<(), Error> { let Some(old_soa_rr) = records.find_soa() else { return Err(Error::from("Error reading zonefile: missing SOA record")); @@ -778,11 +987,25 @@ impl SignZone { /// However, this function is not strict about the format of the prefix, it /// will attempt to load files with suffixes '.key' and '.private' irrespective /// of the format of the rest of the path. - fn load_key_pair(key_path: &Path) -> Result, Error> { + fn load_key_pair( + &self, + key_path: &Path, + log: &mut Option>, + ) -> Result, Error> { let key_path_str = key_path.to_string_lossy(); let public_key_path = PathBuf::from(format!("{key_path_str}.key")); let private_key_path = PathBuf::from(format!("{key_path_str}.private")); + if self.verbose { + if let Some(log) = log { + writeln!( + log, + "Loading private key from '{}'", + private_key_path.display() + ); + } + } + let private_data = std::fs::read_to_string(&private_key_path).map_err(|err| { format!( "Unable to load private key from file '{}': {}", @@ -791,6 +1014,16 @@ impl SignZone { ) })?; + if self.verbose { + if let Some(log) = log { + writeln!( + log, + "Loading public key from '{}'", + public_key_path.display() + ); + } + } + let public_data = std::fs::read_to_string(&public_key_path).map_err(|err| { format!( "Unable to load public key from file '{}': {}", @@ -1016,3 +1249,21 @@ impl<'a> From, ZoneRecordData>>> Self::FamiliesIter(iter) } } + +//------------ MultiThreadedSorter ------------------------------------------- + +/// A parallelized sort implementation for use with [`SortedRecords`]. +/// +/// TODO: Should we add a `-j` (jobs) command line argument to override the +/// default Rayon behaviour of using as many threads as their are CPU cores? +struct MultiThreadedSorter; + +impl domain::sign::records::Sorter for MultiThreadedSorter { + fn sort_by(records: &mut Vec>, compare: F) + where + Record: Send, + F: Fn(&Record, &Record) -> Ordering + Sync, + { + records.par_sort_by(compare); + } +} diff --git a/src/env/fake.rs b/src/env/fake.rs index 0715ee5..9caf4c5 100644 --- a/src/env/fake.rs +++ b/src/env/fake.rs @@ -51,11 +51,11 @@ impl Env for FakeEnv { } fn stdout(&self) -> Stream { - Stream(self.stdout.clone()) + Stream(self.stdout.clone(), false) } fn stderr(&self) -> Stream { - Stream(self.stderr.clone()) + Stream(self.stderr.clone(), false) } fn in_cwd<'a>(&self, path: &'a impl AsRef) -> Cow<'a, Path> { diff --git a/src/env/mod.rs b/src/env/mod.rs index d70336e..40703ee 100644 --- a/src/env/mod.rs +++ b/src/env/mod.rs @@ -44,7 +44,7 @@ pub trait Env { /// [`std::io::Write`]. Additionally, this `write_fmt` does not return a /// result. This means that we can use the [`write!`] and [`writeln`] macros /// without handling errors. -pub struct Stream(T); +pub struct Stream(T, bool); impl Stream { pub fn write_fmt(&mut self, args: fmt::Arguments<'_>) { @@ -60,6 +60,10 @@ impl Stream { // Same as with write_fmt... self.0.write_str(s).unwrap(); } + + pub fn is_terminal(&self) -> bool { + self.1 + } } impl Env for &E { diff --git a/src/env/real.rs b/src/env/real.rs index 26c01aa..0b71bab 100644 --- a/src/env/real.rs +++ b/src/env/real.rs @@ -1,6 +1,6 @@ use std::ffi::OsString; use std::fmt; -use std::io; +use std::io::{self, IsTerminal}; use std::path::Path; use super::Env; @@ -15,11 +15,15 @@ impl Env for RealEnv { } fn stdout(&self) -> Stream { - Stream(FmtWriter(io::stdout())) + let stdout = io::stdout(); + let is_term = stdout.is_terminal(); + Stream(FmtWriter(stdout), is_term) } fn stderr(&self) -> Stream { - Stream(FmtWriter(io::stderr())) + let stderr = io::stderr(); + let is_term = stderr.is_terminal(); + Stream(FmtWriter(stderr), is_term) } fn in_cwd<'a>(&self, path: &'a impl AsRef) -> std::borrow::Cow<'a, std::path::Path> {