From 660b21685fb3824c01191df70960f668b5dec8ce Mon Sep 17 00:00:00 2001 From: Reisen Date: Thu, 6 Jun 2024 11:02:01 +0100 Subject: [PATCH] refactor(agent): various refactors (#126) * refactor(agent): convert keypair service to api * refactor(agent): use exit flag everywhere * refactor(agent): StateApi -> Prices and refactor module The StateApi is left over from the initial Adapter, but all functionality is for pricing/product accounts. This refactors that module and fixes the cyclic dependency between it and GlobalStore. The new logic performs updates within the Prices API (Which is where the state relevant to subscriptions already was, so is the better place for it). File rename left for a future commit to keep the diffs clean. * refactor(agent): refactor all references to adapter to state * refactor(agent): remove pythd module, raise pyth module * refactor(agent): remove store module * refactor(agent): convert to a tracing logger --- Cargo.lock | 324 +++++------------ Cargo.toml | 8 +- integration-tests/tests/test_integration.py | 2 +- src/agent.rs | 107 +++--- src/agent/metrics.rs | 84 ++--- src/agent/{pythd/api.rs => pyth.rs} | 4 +- src/agent/{pythd/api => pyth}/rpc.rs | 145 +++----- src/agent/pyth/rpc/get_all_products.rs | 12 + .../{pythd/api => pyth}/rpc/get_product.rs | 6 +- src/agent/pyth/rpc/get_product_list.rs | 12 + .../api => pyth}/rpc/subscribe_price.rs | 6 +- .../api => pyth}/rpc/subscribe_price_sched.rs | 6 +- .../{pythd/api => pyth}/rpc/update_price.rs | 8 +- src/agent/pythd.rs | 1 - src/agent/pythd/api/rpc/get_all_products.rs | 12 - src/agent/pythd/api/rpc/get_product_list.rs | 12 - src/agent/remote_keypair_loader.rs | 328 ------------------ src/agent/solana.rs | 37 +- src/agent/solana/exporter.rs | 171 ++++----- src/agent/solana/oracle.rs | 171 ++++----- src/agent/state.rs | 277 ++++++++------- src/agent/state/api.rs | 261 ++++++++------ src/agent/state/global.rs | 52 +-- src/agent/state/keypairs.rs | 282 +++++++++++++++ src/agent/state/local.rs | 43 +-- src/agent/store.rs | 1 - src/bin/agent.rs | 108 ++---- src/lib.rs | 7 - 28 files changed, 1090 insertions(+), 1397 deletions(-) rename src/agent/{pythd/api.rs => pyth.rs} (100%) rename src/agent/{pythd/api => pyth}/rpc.rs (77%) create mode 100644 src/agent/pyth/rpc/get_all_products.rs rename src/agent/{pythd/api => pyth}/rpc/get_product.rs (86%) create mode 100644 src/agent/pyth/rpc/get_product_list.rs rename src/agent/{pythd/api => pyth}/rpc/subscribe_price.rs (92%) rename src/agent/{pythd/api => pyth}/rpc/subscribe_price_sched.rs (92%) rename src/agent/{pythd/api => pyth}/rpc/update_price.rs (90%) delete mode 100644 src/agent/pythd.rs delete mode 100644 src/agent/pythd/api/rpc/get_all_products.rs delete mode 100644 src/agent/pythd/api/rpc/get_product_list.rs delete mode 100644 src/agent/remote_keypair_loader.rs create mode 100644 src/agent/state/keypairs.rs delete mode 100644 src/agent/store.rs diff --git a/Cargo.lock b/Cargo.lock index 641e24a2..4acfe921 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -195,12 +195,6 @@ version = "1.0.81" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" -[[package]] -name = "arc-swap" -version = "1.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" - [[package]] name = "ark-bn254" version = "0.4.0" @@ -348,7 +342,7 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff8eb72df928aafb99fe5d37b383f2fe25bd2a765e3e5f7c365916b6f2463a29" dependencies = [ - "term 0.5.2", + "term", ] [[package]] @@ -1395,28 +1389,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fd78930633bd1c6e35c4b42b1df7b0cbc6bc191146e512bb3bedf243fcc3901" dependencies = [ "libc", - "redox_users 0.3.5", - "winapi", -] - -[[package]] -name = "dirs-next" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" -dependencies = [ - "cfg-if", - "dirs-sys-next", -] - -[[package]] -name = "dirs-sys-next" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" -dependencies = [ - "libc", - "redox_users 0.4.4", + "redox_users", "winapi", ] @@ -1597,15 +1570,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" -[[package]] -name = "erased-serde" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c138974f9d5e7fe373eb04df7cae98833802ae4b11c24ac7039a21d5af4b26c" -dependencies = [ - "serde", -] - [[package]] name = "errno" version = "0.3.8" @@ -2000,17 +1964,6 @@ dependencies = [ "hmac 0.8.1", ] -[[package]] -name = "hostname" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867" -dependencies = [ - "libc", - "match_cfg", - "winapi", -] - [[package]] name = "htmlescape" version = "0.3.1" @@ -2214,17 +2167,6 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" -[[package]] -name = "is-terminal" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" -dependencies = [ - "hermit-abi 0.3.9", - "libc", - "windows-sys 0.52.0", -] - [[package]] name = "itertools" version = "0.8.2" @@ -2335,7 +2277,7 @@ dependencies = [ "serde_derive", "sha2 0.8.2", "string_cache", - "term 0.5.2", + "term", "unicode-xid 0.1.0", ] @@ -2369,17 +2311,6 @@ version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" -[[package]] -name = "libredox" -version = "0.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" -dependencies = [ - "bitflags 2.5.0", - "libc", - "redox_syscall 0.4.1", -] - [[package]] name = "libsecp256k1" version = "0.6.0" @@ -2469,10 +2400,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] -name = "match_cfg" +name = "matchers" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] [[package]] name = "memchr" @@ -2608,6 +2542,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num" version = "0.2.1" @@ -2776,15 +2720,6 @@ dependencies = [ "syn 2.0.55", ] -[[package]] -name = "num_threads" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c7398b9c8b70908f6371f47ed36737907c87c52af34c268fed0bf0ceb92ead9" -dependencies = [ - "libc", -] - [[package]] name = "number_prefix" version = "0.4.0" @@ -2855,6 +2790,12 @@ version = "6.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1" +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parking_lot" version = "0.12.1" @@ -3296,12 +3237,6 @@ dependencies = [ "serde", "serde-this-or-that", "serde_json", - "slog", - "slog-async", - "slog-bunyan", - "slog-envlogger", - "slog-extlog", - "slog-term", "soketto", "solana-account-decoder", "solana-client", @@ -3312,6 +3247,8 @@ dependencies = [ "tokio-stream", "tokio-util", "toml_edit 0.22.9", + "tracing", + "tracing-subscriber", "typed-html", "warp", "winnow 0.6.5", @@ -3693,17 +3630,6 @@ dependencies = [ "rust-argon2", ] -[[package]] -name = "redox_users" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a18479200779601e498ada4e8c1e1f50e3ee19deb0259c25825a98b5603b2cb4" -dependencies = [ - "getrandom 0.2.12", - "libredox", - "thiserror", -] - [[package]] name = "regex" version = "1.10.4" @@ -3712,10 +3638,19 @@ checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" dependencies = [ "aho-corasick", "memchr", - "regex-automata", + "regex-automata 0.4.6", "regex-syntax 0.8.3", ] +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", +] + [[package]] name = "regex-automata" version = "0.4.6" @@ -4253,6 +4188,15 @@ dependencies = [ "keccak", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shell-words" version = "1.1.0" @@ -4305,116 +4249,6 @@ dependencies = [ "autocfg 1.2.0", ] -[[package]] -name = "slog" -version = "2.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8347046d4ebd943127157b94d63abb990fcf729dc4e9978927fdf4ac3c998d06" -dependencies = [ - "erased-serde", -] - -[[package]] -name = "slog-async" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72c8038f898a2c79507940990f05386455b3a317d8f18d4caea7cbc3d5096b84" -dependencies = [ - "crossbeam-channel", - "slog", - "take_mut", - "thread_local", -] - -[[package]] -name = "slog-bunyan" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcaaf6e68789d3f0411f1e72bc443214ef252a1038b6e344836e50442541f190" -dependencies = [ - "hostname", - "slog", - "slog-json", - "time", -] - -[[package]] -name = "slog-envlogger" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "906a1a0bc43fed692df4b82a5e2fbfc3733db8dad8bb514ab27a4f23ad04f5c0" -dependencies = [ - "log", - "regex", - "slog", - "slog-async", - "slog-scope", - "slog-stdlog", - "slog-term", -] - -[[package]] -name = "slog-extlog" -version = "8.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c00caea52ddc6535e015114a7eb1d2483898f14d6f5110755c56c9f0d765fb71" -dependencies = [ - "erased-serde", - "iobuffer", - "serde", - "serde_json", - "slog", - "slog-json", -] - -[[package]] -name = "slog-json" -version = "2.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e1e53f61af1e3c8b852eef0a9dee29008f55d6dd63794f3f12cef786cf0f219" -dependencies = [ - "erased-serde", - "serde", - "serde_json", - "slog", - "time", -] - -[[package]] -name = "slog-scope" -version = "4.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f95a4b4c3274cd2869549da82b57ccc930859bdbf5bcea0424bc5f140b3c786" -dependencies = [ - "arc-swap", - "lazy_static", - "slog", -] - -[[package]] -name = "slog-stdlog" -version = "4.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6706b2ace5bbae7291d3f8d2473e2bfab073ccd7d03670946197aec98471fa3e" -dependencies = [ - "log", - "slog", - "slog-scope", -] - -[[package]] -name = "slog-term" -version = "2.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6e022d0b998abfe5c3782c1f03551a596269450ccd677ea51c56f8b214610e8" -dependencies = [ - "is-terminal", - "slog", - "term 0.7.0", - "thread_local", - "time", -] - [[package]] name = "smallvec" version = "1.13.2" @@ -5561,12 +5395,6 @@ dependencies = [ "libc", ] -[[package]] -name = "take_mut" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f764005d11ee5f36500a149ace24e00e3da98b0158b3e2d53a7495660d3f4d60" - [[package]] name = "tempfile" version = "3.10.1" @@ -5590,17 +5418,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "term" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c59df8ac95d96ff9bede18eb7300b0fda5e5d8d90960e76f8e14ae765eedbf1f" -dependencies = [ - "dirs-next", - "rustversion", - "winapi", -] - [[package]] name = "termcolor" version = "1.4.1" @@ -5663,9 +5480,7 @@ checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" dependencies = [ "deranged", "itoa", - "libc", "num-conv", - "num_threads", "powerfmt", "serde", "time-core", @@ -5927,6 +5742,49 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-serde" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1" +dependencies = [ + "serde", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex", + "serde", + "serde_json", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", + "tracing-serde", ] [[package]] @@ -6122,6 +5980,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "vec_map" version = "0.8.2" diff --git a/Cargo.toml b/Cargo.toml index 4612b389..f7fa1f9d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,15 +33,11 @@ solana-account-decoder = "1.18.8" solana-client = "1.18.8" solana-sdk = "1.18.8" bincode = "1.3.3" -slog = { version = "2.7.0", features = ["max_level_trace", "release_max_level_trace"] } -slog-term = "2.9.1" rand = "0.8.5" -slog-async = "2.8.0" config = "0.14.0" thiserror = "1.0.58" clap = { version = "4.5.4", features = ["derive"] } humantime-serde = "1.1.1" -slog-envlogger = "2.2.0" serde-this-or-that = "0.4.2" # The public typed-html 0.2.2 release is causing a recursion limit # error that cannot be fixed from outside the crate. @@ -52,9 +48,10 @@ humantime = "2.1.0" prometheus-client = "0.22.2" lazy_static = "1.4.0" toml_edit = "0.22.9" -slog-bunyan = "2.5.0" winnow = "0.6.5" proptest = "1.4.0" +tracing = { version = "0.1.40", features = ["log"] } +tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json"] } [dev-dependencies] tokio-util = { version = "0.7.10", features = ["full"] } @@ -62,7 +59,6 @@ soketto = "0.8.0" portpicker = "0.1.1" rand = "0.8.5" tokio-retry = "0.3.0" -slog-extlog = "8.1.0" iobuffer = "0.2.0" [profile.release] diff --git a/integration-tests/tests/test_integration.py b/integration-tests/tests/test_integration.py index fae2e87d..d155a903 100644 --- a/integration-tests/tests/test_integration.py +++ b/integration-tests/tests/test_integration.py @@ -695,7 +695,7 @@ async def test_update_price_discards_unpermissioned(self, client: PythAgentClien lines_found += 1 expected_unperm_pubkey = final_price_account_unperm["account"] # Must point at the expected account as all other attempts must be valid - assert f"price_account: {expected_unperm_pubkey}" in line + assert f'"unpermissioned_price_account":"{expected_unperm_pubkey}"' in line # Must find at least one log discarding the account assert lines_found > 0 diff --git a/src/agent.rs b/src/agent.rs index c3090757..d0f66206 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -61,32 +61,40 @@ Metrics Server: Note that there is an Oracle and Exporter for each network, but only one Local Store and Global Store. ################################################################################################################################## */ - -pub mod legacy_schedule; -pub mod market_schedule; -pub mod metrics; -pub mod pythd; -pub mod remote_keypair_loader; -pub mod solana; -pub mod state; -pub mod store; use { self::{ config::Config, - pythd::api::rpc, + pyth::rpc, solana::network, state::notifier, }, anyhow::Result, futures_util::future::join_all, - slog::Logger, + lazy_static::lazy_static, std::sync::Arc, - tokio::sync::{ - broadcast, - mpsc, - }, + tokio::sync::watch, }; +pub mod legacy_schedule; +pub mod market_schedule; +pub mod metrics; +pub mod pyth; +pub mod solana; +pub mod state; + +lazy_static! { + /// A static exit flag to indicate to running threads that we're shutting down. This is used to + /// gracefully shut down the application. + /// + /// We make this global based on the fact the: + /// - The `Sender` side does not rely on any async runtime. + /// - Exit logic doesn't really require carefully threading this value through the app. + /// - The `Receiver` side of a watch channel performs the detection based on if the change + /// happened after the subscribe, so it means all listeners should always be notified + /// correctly. + pub static ref EXIT: watch::Sender = watch::channel(false).0; +} + pub struct Agent { config: Config, } @@ -96,40 +104,34 @@ impl Agent { Agent { config } } - pub async fn start(&self, logger: Logger) { - info!(logger, "Starting {}", env!("CARGO_PKG_NAME"); - "config" => format!("{:?}", &self.config), - "version" => env!("CARGO_PKG_VERSION"), - "cwd" => std::env::current_dir().map(|p| format!("{}", p.display())).unwrap_or("".to_owned()) + pub async fn start(&self) { + tracing::info!( + config = format!("{:?}", &self.config), + version = env!("CARGO_PKG_VERSION"), + cwd = std::env::current_dir() + .map(|p| format!("{}", p.display())) + .unwrap_or("".to_owned()), + "Starting {}", + env!("CARGO_PKG_NAME"), ); - if let Err(err) = self.spawn(logger.clone()).await { - error!(logger, "{}", err); - debug!(logger, "error context"; "context" => format!("{:?}", err)); + if let Err(err) = self.spawn().await { + tracing::error!(err = ?err, "Agent spawn failed."); }; } - async fn spawn(&self, logger: Logger) -> Result<()> { + async fn spawn(&self) -> Result<()> { // job handles let mut jhs = vec![]; - // Create the channels - // TODO: make all components listen to shutdown signal - let (shutdown_tx, _) = broadcast::channel(self.config.channel_capacities.shutdown); - let (primary_keypair_loader_tx, primary_keypair_loader_rx) = mpsc::channel(10); - let (secondary_keypair_loader_tx, secondary_keypair_loader_rx) = mpsc::channel(10); - - // Create the Pythd Adapter. - let adapter = - Arc::new(state::State::new(self.config.pythd_adapter.clone(), logger.clone()).await); + // Create the Application State. + let state = Arc::new(state::State::new(self.config.state.clone()).await); // Spawn the primary network jhs.extend(network::spawn_network( self.config.primary_network.clone(), network::Network::Primary, - primary_keypair_loader_tx, - logger.new(o!("primary" => true)), - adapter.clone(), + state.clone(), )?); // Spawn the secondary network, if needed @@ -137,45 +139,34 @@ impl Agent { jhs.extend(network::spawn_network( config.clone(), network::Network::Secondary, - secondary_keypair_loader_tx, - logger.new(o!("primary" => false)), - adapter.clone(), + state.clone(), )?); } // Create the Notifier task for the Pythd RPC. - jhs.push(tokio::spawn(notifier( - adapter.clone(), - shutdown_tx.subscribe(), - ))); + jhs.push(tokio::spawn(notifier(state.clone()))); // Spawn the Pythd API Server jhs.push(tokio::spawn(rpc::run( self.config.pythd_api_server.clone(), - logger.clone(), - adapter.clone(), - shutdown_tx.subscribe(), + state.clone(), ))); // Spawn the metrics server - jhs.push(tokio::spawn(metrics::MetricsServer::spawn( + jhs.push(tokio::spawn(metrics::spawn( self.config.metrics_server.bind_address, - logger.clone(), - adapter, ))); // Spawn the remote keypair loader endpoint for both networks jhs.append( - &mut remote_keypair_loader::RemoteKeypairLoader::spawn( - primary_keypair_loader_rx, - secondary_keypair_loader_rx, + &mut state::keypairs::spawn( self.config.primary_network.rpc_url.clone(), self.config .secondary_network .as_ref() .map(|c| c.rpc_url.clone()), self.config.remote_keypair_loader.clone(), - logger, + state, ) .await, ); @@ -191,8 +182,7 @@ pub mod config { use { super::{ metrics, - pythd, - remote_keypair_loader, + pyth, solana::network, state, }, @@ -214,13 +204,14 @@ pub mod config { pub primary_network: network::Config, pub secondary_network: Option, #[serde(default)] - pub pythd_adapter: state::Config, + #[serde(rename = "pythd_adapter")] + pub state: state::Config, #[serde(default)] - pub pythd_api_server: pythd::api::rpc::Config, + pub pythd_api_server: pyth::rpc::Config, #[serde(default)] pub metrics_server: metrics::Config, #[serde(default)] - pub remote_keypair_loader: remote_keypair_loader::Config, + pub remote_keypair_loader: state::keypairs::Config, } impl Config { diff --git a/src/agent/metrics.rs b/src/agent/metrics.rs index b8068402..4c830ce1 100644 --- a/src/agent/metrics.rs +++ b/src/agent/metrics.rs @@ -1,12 +1,6 @@ use { - super::state::{ - local::PriceInfo, - State, - }, - crate::agent::{ - solana::oracle::PriceEntry, - store::PriceIdentifier, - }, + super::state::local::PriceInfo, + crate::agent::solana::oracle::PriceEntry, lazy_static::lazy_static, prometheus_client::{ encoding::{ @@ -21,7 +15,6 @@ use { registry::Registry, }, serde::Deserialize, - slog::Logger, solana_sdk::pubkey::Pubkey, std::{ net::SocketAddr, @@ -29,7 +22,6 @@ use { atomic::AtomicU64, Arc, }, - time::Instant, }, tokio::sync::Mutex, warp::{ @@ -64,49 +56,33 @@ lazy_static! { Arc::new(Mutex::new(::default())); } -/// Internal metrics server state, holds state needed for serving -/// metrics. -pub struct MetricsServer { - pub start_time: Instant, - pub logger: Logger, - pub adapter: Arc, -} - -impl MetricsServer { - /// Instantiate a metrics API. - pub async fn spawn(addr: impl Into + 'static, logger: Logger, adapter: Arc) { - let server = MetricsServer { - start_time: Instant::now(), - logger, - adapter, - }; - - let shared_state = Arc::new(Mutex::new(server)); - let shared_state4metrics = shared_state.clone(); - let metrics_route = warp::path("metrics") - .and(warp::path::end()) - .and_then(move || { - let shared_state = shared_state4metrics.clone(); - async move { - let locked_state = shared_state.lock().await; - let mut buf = String::new(); - let response = encode(&mut buf, &&PROMETHEUS_REGISTRY.lock().await) - .map_err(|e| -> Box { - e.into() - }) - .and_then(|_| -> Result<_, Box> { - Ok(Box::new(reply::with_status(buf, StatusCode::OK))) - }).unwrap_or_else(|e| { - error!(locked_state.logger, "Metrics: Could not gather metrics from registry"; "error" => e.to_string()); - Box::new(reply::with_status("Could not gather metrics. See logs for details".to_string(), StatusCode::INTERNAL_SERVER_ERROR)) - }); - - Result::, Rejection>::Ok(response) - } - }); - - warp::serve(metrics_route).bind(addr).await; - } +/// Instantiate a metrics API. +pub async fn spawn(addr: impl Into + 'static) { + let metrics_route = warp::path("metrics") + .and(warp::path::end()) + .and_then(move || async move { + let mut buf = String::new(); + let response = encode(&mut buf, &&PROMETHEUS_REGISTRY.lock().await) + .map_err(|e| -> Box { e.into() }) + .and_then(|_| -> Result<_, Box> { + Ok(Box::new(reply::with_status(buf, StatusCode::OK))) + }) + .unwrap_or_else(|e| { + tracing::error!(err = ?e, "Metrics: Could not gather metrics from registry"); + Box::new(reply::with_status( + "Could not gather metrics. See logs for details".to_string(), + StatusCode::INTERNAL_SERVER_ERROR, + )) + }); + + Result::, Rejection>::Ok(response) + }); + + let (_, serve) = warp::serve(metrics_route).bind_with_graceful_shutdown(addr, async { + let _ = crate::agent::EXIT.subscribe().changed().await; + }); + + serve.await } #[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] @@ -362,7 +338,7 @@ impl PriceLocalMetrics { metrics } - pub fn update(&self, price_id: &PriceIdentifier, price_info: &PriceInfo) { + pub fn update(&self, price_id: &pyth_sdk::Identifier, price_info: &PriceInfo) { #[deny(unused_variables)] let Self { price, diff --git a/src/agent/pythd/api.rs b/src/agent/pyth.rs similarity index 100% rename from src/agent/pythd/api.rs rename to src/agent/pyth.rs index 3d8a27c5..0ab98a59 100644 --- a/src/agent/pythd/api.rs +++ b/src/agent/pyth.rs @@ -6,6 +6,8 @@ use { std::collections::BTreeMap, }; +pub mod rpc; + pub type Pubkey = String; pub type Attrs = BTreeMap; @@ -83,5 +85,3 @@ pub struct PriceUpdate { pub valid_slot: Slot, pub pub_slot: Slot, } - -pub mod rpc; diff --git a/src/agent/pythd/api/rpc.rs b/src/agent/pyth/rpc.rs similarity index 77% rename from src/agent/pythd/api/rpc.rs rename to src/agent/pyth/rpc.rs index 9b3b38ac..8d705a44 100644 --- a/src/agent/pythd/api/rpc.rs +++ b/src/agent/pyth/rpc.rs @@ -44,16 +44,12 @@ use { as_i64, as_u64, }, - slog::Logger, std::{ fmt::Debug, net::SocketAddr, sync::Arc, }, - tokio::sync::{ - broadcast, - mpsc, - }, + tokio::sync::mpsc, warp::{ ws::{ Message, @@ -115,12 +111,11 @@ enum ConnectionError { async fn handle_connection( ws_conn: WebSocket, - adapter: Arc, + state: Arc, notify_price_tx_buffer: usize, notify_price_sched_tx_buffer: usize, - logger: Logger, ) where - S: state::StateApi, + S: state::Prices, S: Send, S: Sync, S: 'static, @@ -133,8 +128,7 @@ async fn handle_connection( loop { if let Err(err) = handle_next( - &logger, - &*adapter, + &*state, &mut ws_tx, &mut ws_rx, &mut notify_price_tx, @@ -147,19 +141,17 @@ async fn handle_connection( if let Some(ConnectionError::WebsocketConnectionClosed) = err.downcast_ref::() { - info!(logger, "websocket connection closed"); + tracing::info!("Websocket connection closed."); return; } - error!(logger, "{}", err); - debug!(logger, "error context"; "context" => format!("{:?}", err)); + tracing::error!(err = ?err, "RPC failed to handle WebSocket message."); } } } async fn handle_next( - logger: &Logger, - adapter: &S, + state: &S, ws_tx: &mut SplitSink, ws_rx: &mut SplitStream, notify_price_tx: &mut mpsc::Sender, @@ -168,7 +160,7 @@ async fn handle_next( notify_price_sched_rx: &mut mpsc::Receiver, ) -> Result<()> where - S: state::StateApi, + S: state::Prices, { tokio::select! { msg = ws_rx.next() => { @@ -176,9 +168,8 @@ where Some(body) => match body { Ok(msg) => { handle( - logger, ws_tx, - adapter, + state, notify_price_tx, notify_price_sched_tx, msg, @@ -202,19 +193,18 @@ where } async fn handle( - logger: &Logger, ws_tx: &mut SplitSink, - adapter: &S, + state: &S, notify_price_tx: &mpsc::Sender, notify_price_sched_tx: &mpsc::Sender, msg: Message, ) -> Result<()> where - S: state::StateApi, + S: state::Prices, { // Ignore control and binary messages if !msg.is_text() { - debug!(logger, "JSON RPC API: skipped non-text message"); + tracing::debug!("JSON RPC API: skipped non-text message"); return Ok(()); } @@ -226,8 +216,7 @@ where // Perform requests in sequence and gather responses for request in requests { let response = dispatch_and_catch_error( - logger, - adapter, + state, notify_price_tx, notify_price_sched_tx, &request, @@ -289,29 +278,27 @@ async fn parse(msg: Message) -> Result<(Vec>, bool)> { } async fn dispatch_and_catch_error( - logger: &Logger, - adapter: &S, + state: &S, notify_price_tx: &mpsc::Sender, notify_price_sched_tx: &mpsc::Sender, request: &Request, ) -> Response where - S: state::StateApi, + S: state::Prices, { - debug!( - logger, - "JSON RPC API: handling request"; - "method" => format!("{:?}", request.method), + tracing::debug!( + method = ?request.method, + "JSON RPC API: handling request", ); let result = match request.method { - Method::GetProductList => get_product_list(adapter).await, - Method::GetProduct => get_product(adapter, request).await, - Method::GetAllProducts => get_all_products(adapter).await, - Method::UpdatePrice => update_price(adapter, request).await, - Method::SubscribePrice => subscribe_price(adapter, notify_price_tx, request).await, + Method::GetProductList => get_product_list(state).await, + Method::GetProduct => get_product(state, request).await, + Method::GetAllProducts => get_all_products(state).await, + Method::UpdatePrice => update_price(state, request).await, + Method::SubscribePrice => subscribe_price(state, notify_price_tx, request).await, Method::SubscribePriceSched => { - subscribe_price_sched(adapter, notify_price_sched_tx, request).await + subscribe_price_sched(state, notify_price_sched_tx, request).await } Method::NotifyPrice | Method::NotifyPriceSched => { Err(anyhow!("unsupported method: {:?}", request.method)) @@ -324,11 +311,10 @@ where Response::success(request.id.clone().to_id().unwrap_or(Id::from(0)), payload) } Err(e) => { - warn!( - logger, - "Error handling JSON RPC request"; - "request" => format!("{:?}", request), - "error" => format!("{}", e.to_string()), + tracing::warn!( + request = ?request, + error = e.to_string(), + "Error handling JSON RPC request", ); Response::error( @@ -402,21 +388,16 @@ async fn send_text(ws_tx: &mut SplitSink, msg: &str) -> Resu .map_err(|e| e.into()) } -#[derive(Clone)] -struct WithLogger { - logger: Logger, -} - #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(default)] pub struct Config { /// The address which the websocket API server will listen on. pub listen_address: String, /// Size of the buffer of each Server's channel on which `notify_price` events are - /// received from the Adapter. + /// received from the Price state. pub notify_price_tx_buffer: usize, /// Size of the buffer of each Server's channel on which `notify_price_sched` events are - /// received from the Adapter. + /// received from the Price state. pub notify_price_sched_tx_buffer: usize, } @@ -430,72 +411,58 @@ impl Default for Config { } } -pub async fn run( - config: Config, - logger: Logger, - adapter: Arc, - shutdown_rx: broadcast::Receiver<()>, -) where - S: state::StateApi, +pub async fn run(config: Config, state: Arc) +where + S: state::Prices, S: Send, S: Sync, S: 'static, { - if let Err(err) = serve(config, &logger, adapter, shutdown_rx).await { - error!(logger, "{}", err); - debug!(logger, "error context"; "context" => format!("{:?}", err)); + if let Err(err) = serve(config, state).await { + tracing::error!(err = ?err, "RPC server failed."); } } -async fn serve( - config: Config, - logger: &Logger, - adapter: Arc, - mut shutdown_rx: broadcast::Receiver<()>, -) -> Result<()> +async fn serve(config: Config, state: Arc) -> Result<()> where - S: state::StateApi, + S: state::Prices, S: Send, S: Sync, S: 'static, { let config = config.clone(); - let with_logger = WithLogger { - logger: logger.clone(), - }; let index = { let config = config.clone(); warp::path::end() .and(warp::ws()) - .and(warp::any().map(move || adapter.clone())) - .and(warp::any().map(move || with_logger.clone())) + .and(warp::any().map(move || state.clone())) .and(warp::any().map(move || config.clone())) - .map( - |ws: Ws, adapter: Arc, with_logger: WithLogger, config: Config| { - ws.on_upgrade(move |conn| async move { - info!(with_logger.logger, "websocket user connected"); - handle_connection( - conn, - adapter, - config.notify_price_tx_buffer, - config.notify_price_sched_tx_buffer, - with_logger.logger, - ) - .await - }) - }, - ) + .map(|ws: Ws, state: Arc, config: Config| { + ws.on_upgrade(move |conn| async move { + tracing::info!("Websocket user connected."); + handle_connection( + conn, + state, + config.notify_price_tx_buffer, + config.notify_price_sched_tx_buffer, + ) + .await + }) + }) }; let (_, serve) = warp::serve(index).bind_with_graceful_shutdown( config.listen_address.as_str().parse::()?, - async move { - let _ = shutdown_rx.recv().await; + async { + let _ = crate::agent::EXIT.subscribe().changed().await; }, ); - info!(logger, "starting api server"; "listen address" => config.listen_address.clone()); + tracing::info!( + listen_address = config.listen_address.clone(), + "Starting api server.", + ); tokio::task::spawn(serve).await.map_err(|e| e.into()) } diff --git a/src/agent/pyth/rpc/get_all_products.rs b/src/agent/pyth/rpc/get_all_products.rs new file mode 100644 index 00000000..7a342789 --- /dev/null +++ b/src/agent/pyth/rpc/get_all_products.rs @@ -0,0 +1,12 @@ +use { + crate::agent::state, + anyhow::Result, +}; + +pub async fn get_all_products(state: &S) -> Result +where + S: state::Prices, +{ + let products = state.get_all_products().await?; + Ok(serde_json::to_value(products)?) +} diff --git a/src/agent/pythd/api/rpc/get_product.rs b/src/agent/pyth/rpc/get_product.rs similarity index 86% rename from src/agent/pythd/api/rpc/get_product.rs rename to src/agent/pyth/rpc/get_product.rs index 8ff49dc8..19c69a5e 100644 --- a/src/agent/pythd/api/rpc/get_product.rs +++ b/src/agent/pyth/rpc/get_product.rs @@ -15,11 +15,11 @@ use { }; pub async fn get_product( - adapter: &S, + state: &S, request: &Request, ) -> Result where - S: state::StateApi, + S: state::Prices, { let params: GetProductParams = { let value = request.params.clone(); @@ -27,6 +27,6 @@ where }?; let account = params.account.parse::()?; - let product = adapter.get_product(&account).await?; + let product = state.get_product(&account).await?; Ok(serde_json::to_value(product)?) } diff --git a/src/agent/pyth/rpc/get_product_list.rs b/src/agent/pyth/rpc/get_product_list.rs new file mode 100644 index 00000000..30cde6e1 --- /dev/null +++ b/src/agent/pyth/rpc/get_product_list.rs @@ -0,0 +1,12 @@ +use { + crate::agent::state, + anyhow::Result, +}; + +pub async fn get_product_list(state: &S) -> Result +where + S: state::Prices, +{ + let product_list = state.get_product_list().await?; + Ok(serde_json::to_value(product_list)?) +} diff --git a/src/agent/pythd/api/rpc/subscribe_price.rs b/src/agent/pyth/rpc/subscribe_price.rs similarity index 92% rename from src/agent/pythd/api/rpc/subscribe_price.rs rename to src/agent/pyth/rpc/subscribe_price.rs index f2319b1c..59365051 100644 --- a/src/agent/pythd/api/rpc/subscribe_price.rs +++ b/src/agent/pyth/rpc/subscribe_price.rs @@ -18,12 +18,12 @@ use { }; pub async fn subscribe_price( - adapter: &S, + state: &S, notify_price_tx: &mpsc::Sender, request: &Request, ) -> Result where - S: state::StateApi, + S: state::Prices, { let params: SubscribePriceParams = serde_json::from_value( request @@ -33,7 +33,7 @@ where )?; let account = params.account.parse::()?; - let subscription = adapter + let subscription = state .subscribe_price(&account, notify_price_tx.clone()) .await; diff --git a/src/agent/pythd/api/rpc/subscribe_price_sched.rs b/src/agent/pyth/rpc/subscribe_price_sched.rs similarity index 92% rename from src/agent/pythd/api/rpc/subscribe_price_sched.rs rename to src/agent/pyth/rpc/subscribe_price_sched.rs index c11ffa8d..608a489d 100644 --- a/src/agent/pythd/api/rpc/subscribe_price_sched.rs +++ b/src/agent/pyth/rpc/subscribe_price_sched.rs @@ -18,12 +18,12 @@ use { }; pub async fn subscribe_price_sched( - adapter: &S, + state: &S, notify_price_sched_tx: &mpsc::Sender, request: &Request, ) -> Result where - S: state::StateApi, + S: state::Prices, { let params: SubscribePriceSchedParams = serde_json::from_value( request @@ -33,7 +33,7 @@ where )?; let account = params.account.parse::()?; - let subscription = adapter + let subscription = state .subscribe_price_sched(&account, notify_price_sched_tx.clone()) .await; diff --git a/src/agent/pythd/api/rpc/update_price.rs b/src/agent/pyth/rpc/update_price.rs similarity index 90% rename from src/agent/pythd/api/rpc/update_price.rs rename to src/agent/pyth/rpc/update_price.rs index 0eb532ea..c5748af1 100644 --- a/src/agent/pythd/api/rpc/update_price.rs +++ b/src/agent/pyth/rpc/update_price.rs @@ -15,11 +15,11 @@ use { }; pub async fn update_price( - adapter: &S, + state: &S, request: &Request, ) -> Result where - S: state::StateApi, + S: state::Prices, { let params: UpdatePriceParams = serde_json::from_value( request @@ -28,8 +28,8 @@ where .ok_or_else(|| anyhow!("Missing request parameters"))?, )?; - adapter - .update_price( + state + .update_local_price( ¶ms.account.parse::()?, params.price, params.conf, diff --git a/src/agent/pythd.rs b/src/agent/pythd.rs deleted file mode 100644 index e5fdf85e..00000000 --- a/src/agent/pythd.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod api; diff --git a/src/agent/pythd/api/rpc/get_all_products.rs b/src/agent/pythd/api/rpc/get_all_products.rs deleted file mode 100644 index be7b39bf..00000000 --- a/src/agent/pythd/api/rpc/get_all_products.rs +++ /dev/null @@ -1,12 +0,0 @@ -use { - crate::agent::state, - anyhow::Result, -}; - -pub async fn get_all_products(adapter: &S) -> Result -where - S: state::StateApi, -{ - let products = adapter.get_all_products().await?; - Ok(serde_json::to_value(products)?) -} diff --git a/src/agent/pythd/api/rpc/get_product_list.rs b/src/agent/pythd/api/rpc/get_product_list.rs deleted file mode 100644 index 833b2688..00000000 --- a/src/agent/pythd/api/rpc/get_product_list.rs +++ /dev/null @@ -1,12 +0,0 @@ -use { - crate::agent::state, - anyhow::Result, -}; - -pub async fn get_product_list(adapter: &S) -> Result -where - S: state::StateApi, -{ - let product_list = adapter.get_product_list().await?; - Ok(serde_json::to_value(product_list)?) -} diff --git a/src/agent/remote_keypair_loader.rs b/src/agent/remote_keypair_loader.rs deleted file mode 100644 index f8f3a670..00000000 --- a/src/agent/remote_keypair_loader.rs +++ /dev/null @@ -1,328 +0,0 @@ -//! Remote keypair loading endpoint. Lets you hotload a keypair in -//! runtime for publishing to the given network. -//! -use { - anyhow::{ - Context, - Result, - }, - serde::Deserialize, - slog::Logger, - solana_client::nonblocking::rpc_client::RpcClient, - solana_sdk::{ - commitment_config::CommitmentConfig, - signature::Keypair, - signer::Signer, - }, - std::{ - net::SocketAddr, - sync::Arc, - time::Duration, - }, - tokio::{ - sync::{ - mpsc, - oneshot, - Mutex, - }, - task::JoinHandle, - }, - warp::{ - hyper::StatusCode, - reply::{ - self, - WithStatus, - }, - Filter, - Rejection, - }, -}; - -pub fn default_min_keypair_balance_sol() -> u64 { - 1 -} - -pub fn default_bind_address() -> SocketAddr { - "127.0.0.1:9001" - .parse() - .expect("INTERNAL: Could not build default remote keypair loader bind address") -} - -#[derive(Clone, Debug, Deserialize)] -#[serde(default)] -pub struct Config { - primary_min_keypair_balance_sol: u64, - secondary_min_keypair_balance_sol: u64, - bind_address: SocketAddr, -} - -impl Default for Config { - fn default() -> Self { - Self { - primary_min_keypair_balance_sol: default_min_keypair_balance_sol(), - secondary_min_keypair_balance_sol: default_min_keypair_balance_sol(), - bind_address: default_bind_address(), - } - } -} - -#[derive(Debug)] -pub struct KeypairRequest { - /// Where to send the key back - response_tx: oneshot::Sender, -} - -pub struct RemoteKeypairLoader { - primary_current_keypair: Option, - secondary_current_keypair: Option, - primary_rpc_url: String, - secondary_rpc_url: Option, - config: Config, -} - -impl RemoteKeypairLoader { - pub async fn spawn( - primary_requests_rx: mpsc::Receiver, - secondary_requests_rx: mpsc::Receiver, - primary_rpc_url: String, - secondary_rpc_url: Option, - config: Config, - logger: Logger, - ) -> Vec> { - let bind_address = config.bind_address; - - let ip = bind_address.ip(); - - if !ip.is_loopback() { - warn!(logger, "Remote key loader: bind address is not localhost. Make sure the access on the selected address is secure."; "bind_address" => bind_address,); - } - - let shared_state = Arc::new(Mutex::new(Self { - primary_current_keypair: None, - secondary_current_keypair: None, - primary_rpc_url, - secondary_rpc_url, - config, - })); - - let request_handler_jh = tokio::spawn(handle_key_requests( - primary_requests_rx, - secondary_requests_rx, - shared_state.clone(), - logger.clone(), - )); - - let logger4primary = logger.clone(); - let shared_state4primary = shared_state.clone(); - let primary_upload_route = warp::path!("primary" / "load_keypair") - .and(warp::post()) - .and(warp::body::content_length_limit(1024)) - .and(warp::body::json()) - .and(warp::path::end()) - .and_then(move |kp: Vec| { - let shared_state = shared_state4primary.clone(); - let logger = logger4primary.clone(); - async move { - let mut locked_state = shared_state.lock().await; - - let min_balance = locked_state.config.primary_min_keypair_balance_sol; - let rpc_url = locked_state.primary_rpc_url.clone(); - - let response = Self::handle_new_keypair( - &mut (locked_state.primary_current_keypair), - kp, - min_balance, - rpc_url, - "primary", - logger, - ) - .await; - - Result::, Rejection>::Ok(response) - } - }); - - let secondary_upload_route = warp::path!("secondary" / "load_keypair") - .and(warp::post()) - .and(warp::body::content_length_limit(1024)) - .and(warp::body::json()) - .and(warp::path::end()) - .and_then(move |kp: Vec| { - let shared_state = shared_state.clone(); - let logger = logger.clone(); - async move { - let mut locked_state = shared_state.lock().await; - - if let Some(rpc_url) = locked_state.secondary_rpc_url.clone() { - let min_balance = locked_state.config.secondary_min_keypair_balance_sol; - - let response = Self::handle_new_keypair( - &mut (locked_state.secondary_current_keypair), - kp, - min_balance, - rpc_url, - "secondary", - logger, - ) - .await; - - Result::, Rejection>::Ok(response) - } else { - Result::, Rejection>::Ok(reply::with_status( - "Secondary network is not active", - StatusCode::SERVICE_UNAVAILABLE, - )) - } - } - }); - - let http_api_jh = tokio::spawn( - warp::serve(primary_upload_route.or(secondary_upload_route)).bind(bind_address), - ); - - // WARNING: All jobs spawned here must report their join handles in this vec - vec![request_handler_jh, http_api_jh] - } - - /// Validate and apply a keypair to the specified mut reference, - /// hiding errors in logs. - /// - /// Returns the appropriate HTTP response depending on checks success. - /// - /// NOTE(2023-03-22): Lifetime bounds are currently necessary - /// because of https://github.com/rust-lang/rust/issues/63033 - async fn handle_new_keypair<'a, 'b: 'a>( - keypair_slot: &'a mut Option, - new_keypair_bytes: Vec, - min_keypair_balance_sol: u64, - rpc_url: String, - network_name: &'b str, - logger: Logger, - ) -> WithStatus<&'static str> { - let mut upload_ok = true; - - match Keypair::from_bytes(&new_keypair_bytes) { - Ok(kp) => { - match Self::validate_keypair(&kp, min_keypair_balance_sol, rpc_url.clone()).await { - Ok(()) => { - *keypair_slot = Some(kp); - } - Err(e) => { - warn!(logger, "Remote keypair loader: Keypair failed validation"; - "network" => network_name, - "error" => e.to_string(), - ); - upload_ok = false; - } - } - } - Err(e) => { - warn!(logger, "Remote keypair loader: Keypair failed validation"; - "network" => network_name, - "error" => e.to_string(), - ); - upload_ok = false; - } - } - - if upload_ok { - reply::with_status("keypair upload OK", StatusCode::OK) - } else { - reply::with_status( - "Could not upload keypair. See logs for details.", - StatusCode::BAD_REQUEST, - ) - } - } - - /// Validate keypair balance before using it in transactions. - pub async fn validate_keypair( - kp: &Keypair, - min_keypair_balance_sol: u64, - rpc_url: String, - ) -> Result<()> { - let c = RpcClient::new_with_commitment(rpc_url, CommitmentConfig::confirmed()); - - let balance_lamports = c - .get_balance(&kp.pubkey()) - .await - .context("Could not check keypair's balance")?; - - let lamports_in_sol = 1_000_000_000; - - if balance_lamports > min_keypair_balance_sol * lamports_in_sol { - Ok(()) - } else { - Err(anyhow::anyhow!(format!( - "Keypair {} balance of {} SOL below threshold of {} SOL", - kp.pubkey(), - balance_lamports as f64 / lamports_in_sol as f64, - min_keypair_balance_sol - ))) - } - } - - /// Get a keypair using the specified request - /// sender. primary/secondary is decided by the channel the tx - /// that request_tx comes from. - pub async fn request_keypair(request_tx: &mpsc::Sender) -> Result { - let (tx, rx) = oneshot::channel(); - - request_tx.send(KeypairRequest { response_tx: tx }).await?; - - Ok(rx.await?) - } -} - -/// Query channel receivers indefinitely, sending back the requested -/// keypair if available. -async fn handle_key_requests( - mut primary_rx: mpsc::Receiver, - mut secondary_rx: mpsc::Receiver, - shared_state: Arc>, - logger: Logger, -) { - loop { - let locked_state = shared_state.lock().await; - - // Only handle requests for defined keypairs. The possibility - // of missing keypair is the reason we are not - // tokio::select!()-ing on the two channel receivers. - - if let Some(primary_keypair) = locked_state.primary_current_keypair.as_ref() { - // Drain all primary keypair requests - while let Ok(KeypairRequest { response_tx }) = primary_rx.try_recv() { - let copied_keypair = Keypair::from_bytes(&primary_keypair.to_bytes()) - .expect("INTERNAL: could not convert Keypair to bytes and back"); - - match response_tx.send(copied_keypair) { - Ok(()) => {} - Err(_e) => { - warn!(logger, "remote_keypair_loader: Could not send back primary keypair to channel"; - ); - } - } - } - } - - if let Some(secondary_keypair) = locked_state.secondary_current_keypair.as_ref() { - // Drain all secondary keypair requests - while let Ok(KeypairRequest { response_tx }) = secondary_rx.try_recv() { - let copied_keypair = Keypair::from_bytes(&secondary_keypair.to_bytes()) - .expect("INTERNAL: could not convert Keypair to bytes and back"); - - match response_tx.send(copied_keypair) { - Ok(()) => {} - Err(_e) => { - warn!(logger, "remote_keypair_loader: Could not send back secondary keypair to channel"); - } - } - } - } - - // Free the state for others while we sleep - drop(locked_state); - - tokio::time::sleep(Duration::from_millis(500)).await; - } -} diff --git a/src/agent/solana.rs b/src/agent/solana.rs index 6a58fdc4..5a32ff40 100644 --- a/src/agent/solana.rs +++ b/src/agent/solana.rs @@ -14,25 +14,18 @@ pub mod network { }, oracle, }, - crate::agent::{ - remote_keypair_loader::KeypairRequest, - state::State, - }, + crate::agent::state::State, anyhow::Result, serde::{ Deserialize, Serialize, }, - slog::Logger, std::{ sync::Arc, time::Duration, }, tokio::{ - sync::{ - mpsc::Sender, - watch, - }, + sync::watch, task::JoinHandle, }, }; @@ -80,9 +73,7 @@ pub mod network { pub fn spawn_network( config: Config, network: Network, - keypair_request_tx: Sender, - logger: Logger, - adapter: Arc, + state: Arc, ) -> Result>> { // Publisher permissions updates between oracle and exporter let (publisher_permissions_tx, publisher_permissions_rx) = watch::channel(<_>::default()); @@ -95,9 +86,8 @@ pub mod network { &config.wss_url, config.rpc_timeout, publisher_permissions_tx, - KeyStore::new(config.key_store.clone(), &logger)?, - logger.clone(), - adapter.clone(), + KeyStore::new(config.key_store.clone())?, + state.clone(), ); // Spawn the Exporter @@ -107,10 +97,8 @@ pub mod network { &config.rpc_url, config.rpc_timeout, publisher_permissions_rx, - KeyStore::new(config.key_store.clone(), &logger)?, - keypair_request_tx, - logger, - adapter, + KeyStore::new(config.key_store.clone())?, + state, )?; jhs.extend(exporter_jhs); @@ -130,7 +118,6 @@ mod key_store { Serialize, Serializer, }, - slog::Logger, solana_sdk::{ pubkey::Pubkey, signature::Keypair, @@ -184,13 +171,15 @@ mod key_store { } impl KeyStore { - pub fn new(config: Config, logger: &Logger) -> Result { + pub fn new(config: Config) -> Result { let publish_keypair = match keypair::read_keypair_file(&config.publish_keypair_path) { Ok(k) => Some(k), Err(e) => { - warn!(logger, - "Reading publish keypair returned an error. Waiting for a remote-loaded key before publishing."; - "publish_keypair_path" => config.publish_keypair_path.display(), "error" => e.to_string()); + tracing::warn!( + error = ?e, + publish_keypair_path = config.publish_keypair_path.display().to_string(), + "Reading publish keypair returned an error. Waiting for a remote-loaded key before publishing.", + ); None } }; diff --git a/src/agent/solana/exporter.rs b/src/agent/solana/exporter.rs index 6818a4e6..870fa87f 100644 --- a/src/agent/solana/exporter.rs +++ b/src/agent/solana/exporter.rs @@ -1,24 +1,18 @@ use { self::transaction_monitor::TransactionMonitor, super::{ - super::store::PriceIdentifier, key_store, network::Network, oracle::PricePublishingMetadata, }, - crate::agent::{ - remote_keypair_loader::{ - KeypairRequest, - RemoteKeypairLoader, - }, - state::{ - global::GlobalStore, - local::{ - LocalStore, - PriceInfo, - }, - State, + crate::agent::state::{ + global::GlobalStore, + keypairs::Keypairs, + local::{ + LocalStore, + PriceInfo, }, + State, }, anyhow::{ anyhow, @@ -38,7 +32,6 @@ use { Deserialize, Serialize, }, - slog::Logger, solana_client::{ nonblocking::rpc_client::RpcClient, rpc_config::RpcSendTransactionConfig, @@ -180,9 +173,7 @@ pub fn spawn_exporter( HashMap>, >, key_store: KeyStore, - keypair_request_tx: mpsc::Sender, - logger: Logger, - adapter: Arc, + state: Arc, ) -> Result>> { // Create and spawn the network state querier let (network_state_tx, network_state_rx) = watch::channel(Default::default()); @@ -191,7 +182,6 @@ pub fn spawn_exporter( rpc_timeout, time::interval(config.refresh_network_state_interval_duration), network_state_tx, - logger.clone(), ); let network_state_querier_jh = tokio::spawn(async move { network_state_querier.run().await }); @@ -203,7 +193,6 @@ pub fn spawn_exporter( rpc_url, rpc_timeout, transactions_rx, - logger.clone(), ); let transaction_monitor_jh = tokio::spawn(async move { transaction_monitor.run().await }); @@ -217,9 +206,7 @@ pub fn spawn_exporter( network_state_rx, transactions_tx, publisher_permissions_rx, - keypair_request_tx, - logger, - adapter, + state, ); let exporter_jh = tokio::spawn(async move { exporter.run().await }); @@ -249,7 +236,7 @@ pub struct Exporter { /// The last state published for each price identifier. Used to /// rule out stale data and prevent repetitive publishing of /// unchanged prices. - last_published_state: HashMap, + last_published_state: HashMap, /// Watch receiver channel to access the current network state network_state_rx: watch::Receiver, @@ -270,11 +257,7 @@ pub struct Exporter { /// Recent compute unit price in micro lamports (set if dynamic compute unit pricing is enabled) recent_compute_unit_price_micro_lamports: Option, - keypair_request_tx: Sender, - - logger: Logger, - - adapter: Arc, + state: Arc, } impl Exporter { @@ -289,9 +272,7 @@ impl Exporter { publisher_permissions_rx: watch::Receiver< HashMap>, >, - keypair_request_tx: mpsc::Sender, - logger: Logger, - adapter: Arc, + state: Arc, ) -> Self { let publish_interval = time::interval(config.publish_interval_duration); Exporter { @@ -312,9 +293,7 @@ impl Exporter { time::Duration::from_secs(1), ), recent_compute_unit_price_micro_lamports: None, - keypair_request_tx, - logger, - adapter, + state, } } @@ -323,15 +302,13 @@ impl Exporter { tokio::select! { _ = self.publish_interval.tick() => { if let Err(err) = self.publish_updates().await { - error!(self.logger, "{}", err); - debug!(self.logger, "error context"; "context" => format!("{:?}", err)); + tracing::error!(err = ?err, "Exporter failed to publish."); } } _ = self.dynamic_compute_unit_price_update_interval.tick() => { if self.config.dynamic_compute_unit_pricing_enabled { if let Err(err) = self.update_recent_compute_unit_price().await { - error!(self.logger, "{}", err); - debug!(self.logger, "error context"; "context" => format!("{:?}", err)); + tracing::error!(err = ?err, "Exporter failed to compute unit price."); } } } @@ -418,17 +395,14 @@ impl Exporter { // keypairs it does not have. Currently expressed in // handle_key_requests() in remote_keypair_loader.rs - debug!( - self.logger, - "Exporter: Publish keypair is None, requesting remote loaded key" - ); - let kp = RemoteKeypairLoader::request_keypair(&self.keypair_request_tx).await?; - debug!(self.logger, "Exporter: Keypair received"); + tracing::debug!("Exporter: Publish keypair is None, requesting remote loaded key"); + let kp = Keypairs::request_keypair(&*self.state, self.network).await?; + tracing::debug!("Exporter: Keypair received"); Ok(kp) } } - async fn get_permissioned_updates(&mut self) -> Result> { + async fn get_permissioned_updates(&mut self) -> Result> { let local_store_contents = self.fetch_local_store_contents().await?; let publish_keypair = self.get_publish_keypair().await?; @@ -436,9 +410,10 @@ impl Exporter { let now = Utc::now().naive_utc(); - debug!(self.logger, "Exporter: filtering prices permissioned to us"; - "our_prices" => format!("{:?}", self.our_prices.keys()), - "publish_pubkey" => publish_keypair.pubkey().to_string(), + tracing::debug!( + our_prices = ?self.our_prices.keys(), + publish_pubkey = publish_keypair.pubkey().to_string(), + "Exporter: filtering prices permissioned to us", ); // Filter the contents to only include information we haven't already sent, @@ -471,10 +446,11 @@ impl Exporter { let ret = publisher_permission.schedule.can_publish_at(&now_utc); if !ret { - debug!(self.logger, "Exporter: Attempted to publish price outside market hours"; - "price_account" => key_from_id.to_string(), - "schedule" => format!("{:?}", publisher_permission.schedule), - "utc_time" => now_utc.format("%c").to_string(), + tracing::debug!( + price_account = key_from_id.to_string(), + schedule = ?publisher_permission.schedule, + utc_time = now_utc.format("%c").to_string(), + "Exporter: Attempted to publish price outside market hours", ); } @@ -483,11 +459,10 @@ impl Exporter { // Note: This message is not an error. Some // publishers have different permissions on // primary/secondary networks - debug!( - self.logger, - "Exporter: Attempted to publish a price without permission, skipping"; - "unpermissioned_price_account" => key_from_id.to_string(), - "permissioned_accounts" => format!("{:?}", self.our_prices) + tracing::debug!( + unpermissioned_price_account = key_from_id.to_string(), + permissioned_accounts = ?self.our_prices, + "Exporter: Attempted to publish a price without permission, skipping", ); false } @@ -581,11 +556,10 @@ impl Exporter { Ok(true) => {} Ok(false) => return, Err(other) => { - warn!( - self.logger, - "Exporter: Updating permissioned price accounts failed unexpectedly, using cached value"; - "cached_value" => format!("{:?}", self.our_prices), - "error" => other.to_string(), + tracing::warn!( + cached_value = ?self.our_prices, + error = other.to_string(), + "Exporter: Updating permissioned price accounts failed unexpectedly, using cached value", ); return; } @@ -597,17 +571,16 @@ impl Exporter { .get(publish_pubkey) .cloned() .unwrap_or_else(|| { - warn!( - self.logger, - "Exporter: No permissioned prices were found for the publishing keypair on-chain. This is expected only on startup."; - "publish_pubkey" => publish_pubkey.to_string(), + tracing::warn!( + publish_pubkey = publish_pubkey.to_string(), + "Exporter: No permissioned prices were found for the publishing keypair on-chain. This is expected only on startup.", ); HashMap::new() }); } - async fn fetch_local_store_contents(&self) -> Result> { - Ok(LocalStore::get_all_price_infos(&*self.adapter).await) + async fn fetch_local_store_contents(&self) -> Result> { + Ok(LocalStore::get_all_price_infos(&*self.state).await) } async fn publish_batch(&self, batch: &[(Identifier, PriceInfo)]) -> Result<()> { @@ -711,7 +684,7 @@ impl Exporter { // in this batch. This will use the maximum total compute unit fee if the publisher // hasn't updated for >= MAXIMUM_SLOT_GAP_FOR_DYNAMIC_COMPUTE_UNIT_PRICE slots. let result = GlobalStore::price_accounts( - &*self.adapter, + &*self.state, self.network, price_accounts.clone().into_iter().collect(), ) @@ -762,7 +735,10 @@ impl Exporter { compute_unit_price_micro_lamports = compute_unit_price_micro_lamports .min(self.config.maximum_compute_unit_price_micro_lamports); - debug!(self.logger, "setting compute unit price"; "unit_price" => compute_unit_price_micro_lamports); + tracing::debug!( + unit_price = compute_unit_price_micro_lamports, + "setting compute unit price", + ); instructions.push(ComputeBudgetInstruction::set_compute_unit_price( compute_unit_price_micro_lamports, )); @@ -776,7 +752,6 @@ impl Exporter { ); let tx = self.inflight_transactions_tx.clone(); - let logger = self.logger.clone(); let rpc_client = self.rpc_client.clone(); // Fire this off in a separate task so we don't block the main thread of the exporter @@ -793,17 +768,20 @@ impl Exporter { { Ok(signature) => signature, Err(err) => { - error!(logger, "{}", err); - debug!(logger, "error context"; "context" => format!("{:?}", err)); + tracing::error!(err = ?err, "Exporter: failed to send transaction."); return; } }; - debug!(logger, "sent upd_price transaction"; "signature" => signature.to_string(), "instructions" => instructions.len(), "price_accounts" => format!("{:?}", price_accounts)); + tracing::debug!( + signature = signature.to_string(), + instructions = instructions.len(), + price_accounts = ?price_accounts, + "Sent upd_price transaction.", + ); if let Err(err) = tx.send(signature).await { - error!(logger, "{}", err); - debug!(logger, "error context"; "context" => format!("{:?}", err)); + tracing::error!(err = ?err, "Exporter failed to send signature to transaction monitor"); } }); @@ -958,9 +936,6 @@ struct NetworkStateQuerier { /// Channel the current network state is sent on network_state_tx: watch::Sender, - - /// Logger - logger: Logger, } impl NetworkStateQuerier { @@ -969,13 +944,11 @@ impl NetworkStateQuerier { rpc_timeout: Duration, query_interval: Interval, network_state_tx: watch::Sender, - logger: Logger, ) -> Self { NetworkStateQuerier { rpc_client: RpcClient::new_with_timeout(rpc_endpoint.to_string(), rpc_timeout), query_interval, network_state_tx, - logger, } } @@ -983,8 +956,7 @@ impl NetworkStateQuerier { loop { self.query_interval.tick().await; if let Err(err) = self.query_network_state().await { - error!(self.logger, "{}", err); - debug!(self.logger, "error context"; "context" => format!("{:?}", err)); + tracing::error!(err = ?err, "Network state query failed"); } } } @@ -1016,7 +988,6 @@ mod transaction_monitor { Deserialize, Serialize, }, - slog::Logger, solana_client::nonblocking::rpc_client::RpcClient, solana_sdk::{ commitment_config::CommitmentConfig, @@ -1073,8 +1044,6 @@ mod transaction_monitor { /// Interval with which to poll the status of transactions poll_interval: Interval, - - logger: Logger, } impl TransactionMonitor { @@ -1083,7 +1052,6 @@ mod transaction_monitor { rpc_url: &str, rpc_timeout: Duration, transactions_rx: mpsc::Receiver, - logger: Logger, ) -> Self { let poll_interval = time::interval(config.poll_interval_duration); let rpc_client = RpcClient::new_with_timeout(rpc_url.to_string(), rpc_timeout); @@ -1093,15 +1061,13 @@ mod transaction_monitor { sent_transactions: VecDeque::new(), transactions_rx, poll_interval, - logger, } } pub async fn run(&mut self) { loop { if let Err(err) = self.handle_next().await { - error!(self.logger, "{}", err); - debug!(self.logger, "error context"; "context" => format!("{:?}", err)); + tracing::error!(err = ?err, "Transaction monitor failed."); } } } @@ -1119,7 +1085,10 @@ mod transaction_monitor { } fn add_transaction(&mut self, signature: Signature) { - debug!(self.logger, "monitoring new transaction"; "signature" => signature.to_string()); + tracing::debug!( + signature = signature.to_string(), + "Monitoring new transaction.", + ); // Add the new transaction to the list self.sent_transactions.push_back(signature); @@ -1144,7 +1113,10 @@ mod transaction_monitor { .await? .value; - debug!(self.logger, "Processing Signature Statuses"; "statuses" => format!("{:?}", statuses)); + tracing::debug!( + statuses = ?statuses, + "Processing Signature Statuses", + ); // Determine the percentage of the recently sent transactions that have successfully been committed // TODO: expose as metric @@ -1155,10 +1127,11 @@ mod transaction_monitor { .flatten() .filter(|(status, sig)| { if let Some(err) = status.err.as_ref() { - warn!(self.logger, "TX status has err value"; - "error" => err.to_string(), - "tx_signature" => sig.to_string(), - ) + tracing::warn!( + error = err.to_string(), + tx_signature = sig.to_string(), + "TX status has err value", + ); } status.satisfies_commitment(CommitmentConfig::confirmed()) @@ -1166,7 +1139,11 @@ mod transaction_monitor { .count(); let percentage_confirmed = ((confirmed as f64) / (self.sent_transactions.len() as f64)) * 100.0; - info!(self.logger, "monitoring transaction hit rate"; "percentage confirmed" => format!("{:.}", percentage_confirmed)); + + tracing::info!( + percentage_confirmed = format!("{:.}", percentage_confirmed), + "monitoring transaction hit rate", + ); Ok(()) } diff --git a/src/agent/solana/oracle.rs b/src/agent/solana/oracle.rs index 81408079..ff7c1f42 100644 --- a/src/agent/solana/oracle.rs +++ b/src/agent/solana/oracle.rs @@ -10,10 +10,8 @@ use { legacy_schedule::LegacySchedule, market_schedule::MarketSchedule, state::{ - global::{ - GlobalStore, - Update, - }, + global::Update, + Prices, State, }, }, @@ -34,7 +32,6 @@ use { Deserialize, Serialize, }, - slog::Logger, solana_client::nonblocking::rpc_client::RpcClient, solana_sdk::{ account::Account, @@ -178,9 +175,7 @@ pub struct Oracle { network: Network, - logger: Logger, - - adapter: Arc, + state: Arc, } #[derive(Clone, Serialize, Deserialize, Debug)] @@ -229,8 +224,7 @@ pub fn spawn_oracle( HashMap>, >, key_store: KeyStore, - logger: Logger, - adapter: Arc, + state: Arc, ) -> Vec> { let mut jhs = vec![]; @@ -242,7 +236,6 @@ pub fn spawn_oracle( config.commitment, key_store.program_key, updates_tx, - logger.clone(), ); jhs.push(tokio::spawn(async move { subscriber.run().await })); } @@ -258,12 +251,11 @@ pub fn spawn_oracle( config.poll_interval_duration, config.max_lookup_batch_size, key_store.mapping_key, - logger.clone(), ); jhs.push(tokio::spawn(async move { poller.run().await })); // Create and spawn the Oracle - let mut oracle = Oracle::new(data_rx, updates_rx, network, logger, adapter); + let mut oracle = Oracle::new(data_rx, updates_rx, network, state); jhs.push(tokio::spawn(async move { oracle.run().await })); jhs @@ -274,24 +266,21 @@ impl Oracle { data_rx: mpsc::Receiver, updates_rx: mpsc::Receiver<(Pubkey, solana_sdk::account::Account)>, network: Network, - logger: Logger, - adapter: Arc, + state: Arc, ) -> Self { Oracle { data: Default::default(), data_rx, updates_rx, network, - logger, - adapter, + state, } } pub async fn run(&mut self) { loop { if let Err(err) = self.handle_next().await { - error!(self.logger, "{}", err); - debug!(self.logger, "error context"; "context" => format!("{:?}", err)); + tracing::error!(err = ?err, "Oracle failed to handle next update."); } } } @@ -316,33 +305,45 @@ impl Oracle { .keys() .cloned() .collect::>(); - info!(self.logger, "fetched mapping accounts"; "new" => format!("{:?}", data + tracing::info!( + new = ?data .mapping_accounts .keys() .cloned() - .collect::>().difference(&previous_mapping_accounts)), "total" => data.mapping_accounts.len()); + .collect::>().difference(&previous_mapping_accounts), + total = data.mapping_accounts.len(), + "Fetched mapping accounts." + ); let previous_product_accounts = self .data .product_accounts .keys() .cloned() .collect::>(); - info!(self.logger, "fetched product accounts"; "new" => format!("{:?}", data + tracing::info!( + new = ?data .product_accounts .keys() .cloned() - .collect::>().difference(&previous_product_accounts)), "total" => data.product_accounts.len()); + .collect::>().difference(&previous_product_accounts), + total = data.product_accounts.len(), + "Fetched product accounts.", + ); let previous_price_accounts = self .data .price_accounts .keys() .cloned() .collect::>(); - info!(self.logger, "fetched price accounts"; "new" => format!("{:?}", data + tracing::info!( + new = ?data .price_accounts .keys() .cloned() - .collect::>().difference(&previous_price_accounts)), "total" => data.price_accounts.len()); + .collect::>().difference(&previous_price_accounts), + total = data.price_accounts.len(), + "Fetched price accounts.", + ); let previous_publishers = self .data @@ -350,11 +351,10 @@ impl Oracle { .keys() .collect::>(); let new_publishers = data.publisher_permissions.keys().collect::>(); - info!( - self.logger, - "updated publisher permissions"; - "new_publishers" => format!("{:?}", new_publishers.difference(&previous_publishers).collect::>()), - "total_publishers" => new_publishers.len(), + tracing::info!( + new_publishers = ?new_publishers.difference(&previous_publishers).collect::>(), + total_publishers = new_publishers.len(), + "Updated publisher permissions.", ); // Update the data with the new data structs @@ -366,7 +366,7 @@ impl Oracle { account_key: &Pubkey, account: &Account, ) -> Result<()> { - debug!(self.logger, "handling account update"); + tracing::debug!("Handling account update."); // We are only interested in price account updates, all other types of updates // will be fetched using polling. @@ -385,7 +385,13 @@ impl Oracle { let price_entry = PriceEntry::load_from_account(&account.data) .with_context(|| format!("load price account {}", account_key))?; - debug!(self.logger, "observed on-chain price account update"; "pubkey" => account_key.to_string(), "price" => price_entry.agg.price, "conf" => price_entry.agg.conf, "status" => format!("{:?}", price_entry.agg.status)); + tracing::debug!( + pubkey = account_key.to_string(), + price = price_entry.agg.price, + conf = price_entry.agg.conf, + status = ?price_entry.agg.status, + "Observed on-chain price account update.", + ); self.data .price_accounts @@ -416,8 +422,8 @@ impl Oracle { account_key: &Pubkey, account: &ProductEntry, ) -> Result<()> { - GlobalStore::update( - &*self.adapter, + Prices::update_global_price( + &*self.state, self.network, &Update::ProductAccountUpdate { account_key: *account_key, @@ -433,8 +439,8 @@ impl Oracle { account_key: &Pubkey, account: &PriceEntry, ) -> Result<()> { - GlobalStore::update( - &*self.adapter, + Prices::update_global_price( + &*self.state, self.network, &Update::PriceAccountUpdate { account_key: *account_key, @@ -464,9 +470,6 @@ struct Poller { max_lookup_batch_size: usize, mapping_key: Pubkey, - - /// Logger - logger: Logger, } impl Poller { @@ -481,7 +484,6 @@ impl Poller { poll_interval_duration: Duration, max_lookup_batch_size: usize, mapping_key: Pubkey, - logger: Logger, ) -> Self { let rpc_client = RpcClient::new_with_timeout_and_commitment( rpc_url.to_string(), @@ -497,17 +499,15 @@ impl Poller { poll_interval, max_lookup_batch_size, mapping_key, - logger, } } pub async fn run(&mut self) { loop { self.poll_interval.tick().await; - info!(self.logger, "fetching all pyth account data"); + tracing::info!("Fetching all pyth account data."); if let Err(err) = self.poll_and_send().await { - error!(self.logger, "{}", err); - debug!(self.logger, "error context"; "context" => format!("{:?}", err)); + tracing::error!(err = ?err, "Oracle Poll/Send Failed."); } } } @@ -552,9 +552,10 @@ impl Poller { publish_interval: prod_entry.publish_interval.clone(), } } else { - warn!(&self.logger, "Oracle: INTERNAL: could not find product from price `prod` field, market hours falling back to 24/7."; - "price" => price_key.to_string(), - "missing_product" => price_entry.prod.to_string(), + tracing::warn!( + price = price_key.to_string(), + missing_product = price_entry.prod.to_string(), + "Oracle: INTERNAL: could not find product from price `prod` field, market hours falling back to 24/7.", ); Default::default() }; @@ -651,13 +652,12 @@ impl Poller { product.iter().find(|(k, _v)| *k == "weekly_schedule") { wsched_val.parse().unwrap_or_else(|err| { - warn!( - self.logger, - "Oracle: Product has weekly_schedule defined but it could not be parsed. Falling back to 24/7 publishing."; - "product_key" => product_key.to_string(), - "weekly_schedule" => wsched_val, + tracing::warn!( + product_key = product_key.to_string(), + weekly_schedule = wsched_val, + "Oracle: Product has weekly_schedule defined but it could not be parsed. Falling back to 24/7 publishing.", ); - debug!(self.logger, "parsing error context"; "context" => format!("{:?}", err)); + tracing::debug!(err = ?err, "Parsing error context."); Default::default() }) } else { @@ -673,13 +673,12 @@ impl Poller { match msched_val.parse::() { Ok(schedule) => Some(schedule), Err(err) => { - warn!( - self.logger, - "Oracle: Product has schedule defined but it could not be parsed. Falling back to legacy schedule."; - "product_key" => product_key.to_string(), - "schedule" => msched_val, + tracing::warn!( + product_key = product_key.to_string(), + schedule = msched_val, + "Oracle: Product has schedule defined but it could not be parsed. Falling back to legacy schedule.", ); - debug!(self.logger, "parsing error context"; "context" => format!("{:?}", err)); + tracing::debug!(err = ?err, "Parsing error context."); None } } @@ -696,13 +695,12 @@ impl Poller { match publish_interval_val.parse::() { Ok(interval) => Some(Duration::from_secs_f64(interval)), Err(err) => { - warn!( - self.logger, - "Oracle: Product has publish_interval defined but it could not be parsed. Falling back to None."; - "product_key" => product_key.to_string(), - "publish_interval" => publish_interval_val, + tracing::warn!( + product_key = product_key.to_string(), + publish_interval = publish_interval_val, + "Oracle: Product has publish_interval defined but it could not be parsed. Falling back to None.", ); - debug!(self.logger, "parsing error context"; "context" => format!("{:?}", err)); + tracing::debug!(err = ?err, "parsing error context"); None } } @@ -720,8 +718,10 @@ impl Poller { }, ); } else { - warn!(self.logger, "Oracle: Could not find product on chain, skipping"; - "product_key" => product_key.to_string(),); + tracing::warn!( + product_key = product_key.to_string(), + "Oracle: Could not find product on chain, skipping", + ); } } @@ -757,9 +757,10 @@ impl Poller { prod.price_accounts.push(*price_key); price_entries.insert(*price_key, price); } else { - warn!(self.logger, "Could not find product entry for price, listed in its prod field, skipping"; - "missing_product" => price.prod.to_string(), - "price_key" => price_key.to_string(), + tracing::warn!( + missing_product = price.prod.to_string(), + price_key = price_key.to_string(), + "Could not find product entry for price, listed in its prod field, skipping", ); continue; @@ -769,7 +770,10 @@ impl Poller { next_todo.push(next_price); } } else { - warn!(self.logger, "Could not look up price account on chain, skipping"; "price_key" => price_key.to_string(),); + tracing::warn!( + price_key = price_key.to_string(), + "Could not look up price account on chain, skipping", + ); continue; } } @@ -786,7 +790,6 @@ mod subscriber { anyhow, Result, }, - slog::Logger, solana_account_decoder::UiAccountEncoding, solana_client::{ nonblocking::pubsub_client::PubsubClient, @@ -825,8 +828,6 @@ mod subscriber { /// Channel on which updates are sent updates_tx: mpsc::Sender<(Pubkey, solana_sdk::account::Account)>, - - logger: Logger, } impl Subscriber { @@ -835,14 +836,12 @@ mod subscriber { commitment: CommitmentLevel, program_key: Pubkey, updates_tx: mpsc::Sender<(Pubkey, solana_sdk::account::Account)>, - logger: Logger, ) -> Self { Subscriber { wss_url, commitment, program_key, updates_tx, - logger, } } @@ -850,13 +849,9 @@ mod subscriber { loop { let current_time = Instant::now(); if let Err(ref err) = self.start().await { - error!(self.logger, "{}", err); - debug!(self.logger, "error context"; "context" => format!("{:?}", err)); + tracing::error!(err = ?err, "Oracle exited unexpectedly."); if current_time.elapsed() < Duration::from_secs(30) { - warn!( - self.logger, - "Subscriber restarting too quickly. Sleeping for 1 second." - ); + tracing::warn!("Subscriber restarting too quickly. Sleeping for 1 second."); tokio::time::sleep(Duration::from_secs(1)).await; } } @@ -882,7 +877,10 @@ mod subscriber { .program_subscribe(&self.program_key, Some(config)) .await?; - debug!(self.logger, "subscribed to program account updates"; "program_key" => self.program_key.to_string()); + tracing::debug!( + program_key = self.program_key.to_string(), + "subscribed to program account updates", + ); loop { match tokio_stream::StreamExt::next(&mut notif).await { @@ -890,7 +888,10 @@ mod subscriber { let account: Account = match update.value.account.decode() { Some(account) => account, None => { - error!(self.logger, "Failed to decode account from update."; "update" => format!("{:?}", update)); + tracing::error!( + update = ?update, + "Failed to decode account from update.", + ); continue; } }; @@ -901,7 +902,7 @@ mod subscriber { .map_err(|_| anyhow!("failed to send update to oracle"))?; } None => { - debug!(self.logger, "subscriber closed connection"); + tracing::debug!("subscriber closed connection"); return Ok(()); } } diff --git a/src/agent/state.rs b/src/agent/state.rs index 50c0f2eb..99f7b6c5 100644 --- a/src/agent/state.rs +++ b/src/agent/state.rs @@ -1,35 +1,27 @@ use { - super::{ - pythd::api::{ + crate::agent::{ + metrics::PROMETHEUS_REGISTRY, + pyth::{ NotifyPrice, NotifyPriceSched, SubscriptionID, }, - store::PriceIdentifier, }, - crate::agent::metrics::PROMETHEUS_REGISTRY, serde::{ Deserialize, Serialize, }, - slog::Logger, - std::{ - collections::HashMap, - sync::atomic::AtomicI64, - time::Duration, - }, - tokio::sync::{ - mpsc, - RwLock, - }, + std::time::Duration, + tokio::sync::mpsc, }; pub mod api; pub mod global; +pub mod keypairs; pub mod local; pub use api::{ notifier, - StateApi, + Prices, }; #[derive(Clone, Serialize, Deserialize, Debug)] @@ -49,31 +41,19 @@ impl Default for Config { } } -/// Adapter is the adapter between the pythd websocket API, and the stores. -/// It is responsible for implementing the business logic for responding to -/// the pythd websocket API calls. +/// State contains all relevant shared application state. pub struct State { - /// Subscription ID sequencer. - subscription_id_seq: AtomicI64, - - /// Notify Price Sched subscriptions - notify_price_sched_subscriptions: - RwLock>>, - - // Notify Price Subscriptions - notify_price_subscriptions: RwLock>>, - - /// The fixed interval at which Notify Price Sched notifications are sent - notify_price_sched_interval_duration: Duration, - - /// The logger - logger: Logger, - /// Global store for managing the unified state of Pyth-on-Solana networks. global_store: global::Store, /// Local store for managing the unpushed state. local_store: local::Store, + + /// State for managing state of runtime keypairs. + keypairs: keypairs::KeypairState, + + /// State for Price related functionality. + prices: api::PricesState, } /// Represents a single Notify Price Sched subscription @@ -93,16 +73,13 @@ struct NotifyPriceSubscription { } impl State { - pub async fn new(config: Config, logger: Logger) -> Self { + pub async fn new(config: Config) -> Self { let registry = &mut *PROMETHEUS_REGISTRY.lock().await; State { - global_store: global::Store::new(logger.clone(), registry), - local_store: local::Store::new(logger.clone(), registry), - subscription_id_seq: 1.into(), - notify_price_sched_subscriptions: RwLock::new(HashMap::new()), - notify_price_subscriptions: RwLock::new(HashMap::new()), - notify_price_sched_interval_duration: config.notify_price_sched_interval_duration, - logger, + global_store: global::Store::new(registry), + local_store: local::Store::new(registry), + keypairs: keypairs::KeypairState::default(), + prices: api::PricesState::new(config), } } } @@ -117,11 +94,11 @@ mod tests { }, notifier, Config, + Prices, State, - StateApi, }, crate::agent::{ - pythd::api::{ + pyth::{ self, NotifyPrice, NotifyPriceSched, @@ -131,10 +108,16 @@ mod tests { ProductAccountMetadata, PublisherAccount, }, - solana, - state::local::LocalStore, + solana::{ + self, + network::Network, + oracle::PriceEntry, + }, + state::{ + global::Update, + local::LocalStore, + }, }, - iobuffer::IoBuffer, pyth_sdk::Identifier, pyth_sdk_solana::state::{ PriceComp, @@ -144,7 +127,6 @@ mod tests { Rational, SolanaPriceAccount, }, - slog_extlog::slog_test, std::{ collections::{ BTreeMap, @@ -163,27 +145,25 @@ mod tests { }, }; - struct TestAdapter { - adapter: Arc, + struct TestState { + state: Arc, shutdown_tx: broadcast::Sender<()>, jh: JoinHandle<()>, } - async fn setup() -> TestAdapter { - // Create and spawn an adapter + async fn setup() -> TestState { let notify_price_sched_interval_duration = Duration::from_nanos(10); - let logger = slog_test::new_test_logger(IoBuffer::new()); let config = Config { notify_price_sched_interval_duration, }; - let adapter = Arc::new(State::new(config, logger).await); + let state = Arc::new(State::new(config).await); let (shutdown_tx, _) = broadcast::channel(1); // Spawn Price Notifier - let jh = tokio::spawn(notifier(adapter.clone(), shutdown_tx.subscribe())); + let jh = tokio::spawn(notifier(state.clone())); - TestAdapter { - adapter, + TestState { + state, shutdown_tx, jh, } @@ -191,15 +171,15 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_subscribe_price_sched() { - let test_adapter = setup().await; + let state = setup().await; // Send a Subscribe Price Sched message let account = "2wrWGm63xWubz7ue4iYR3qvBbaUJhZVi4eSpNuU8k8iF" .parse::() .unwrap(); let (notify_price_sched_tx, mut notify_price_sched_rx) = mpsc::channel(1000); - let subscription_id = test_adapter - .adapter + let subscription_id = state + .state .subscribe_price_sched(&account, notify_price_sched_tx) .await; @@ -213,8 +193,8 @@ mod tests { ) } - let _ = test_adapter.shutdown_tx.send(()); - test_adapter.jh.abort(); + let _ = state.shutdown_tx.send(()); + state.jh.abort(); } fn get_test_all_accounts_metadata() -> global::AllAccountsMetadata { @@ -336,17 +316,16 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_get_product_list() { - // Start the test adapter - let test_adapter = setup().await; + let state = setup().await; let accounts_metadata = get_test_all_accounts_metadata(); - test_adapter - .adapter + state + .state .global_store ._account_metadata(accounts_metadata) .await; // Send a Get Product List message - let mut product_list = test_adapter.adapter.get_product_list().await.unwrap(); + let mut product_list = state.state.get_product_list().await.unwrap(); // Check that the result is what we expected let expected = vec![ @@ -416,8 +395,8 @@ mod tests { product_list.sort(); assert_eq!(product_list, expected); - let _ = test_adapter.shutdown_tx.send(()); - test_adapter.jh.abort(); + let _ = state.shutdown_tx.send(()); + state.jh.abort(); } fn pad_price_comps(mut inputs: Vec) -> [PriceComp; 32] { @@ -1043,21 +1022,20 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_get_all_products() { - // Start the test adapter - let test_adapter = setup().await; + let state = setup().await; let accounts_data = get_all_accounts_data(); - test_adapter - .adapter + state + .state .global_store ._account_data_primary(accounts_data) .await; // Send a Get All Products message - let mut all_products = test_adapter.adapter.get_all_products().await.unwrap(); + let mut all_products = state.state.get_all_products().await.unwrap(); // Check that the result of the conversion to the Pythd API format is what we expected let expected = vec![ - api::ProductAccount { + pyth::ProductAccount { account: "BjHoZWRxo9dgbR1NQhPyTiUs6xFiX6mGS4TMYvy3b2yc".to_string(), attr_dict: BTreeMap::from( [ @@ -1071,7 +1049,7 @@ mod tests { .map(|(k, v)| (k.to_string(), v.to_string())), ), price_accounts: vec![ - api::PriceAccount { + pyth::PriceAccount { account: "GG3FTE7xhc9Diy7dn9P6BWzoCrAEE4D3p5NBYrDAm5DD" .to_string(), price_type: "price".to_string(), @@ -1103,7 +1081,7 @@ mod tests { }, ], }, - api::PriceAccount { + pyth::PriceAccount { account: "fTNjSfj5uW9e4CAMHzUcm65ftRNBxCN1gG5GS1mYfid" .to_string(), price_type: "price".to_string(), @@ -1135,7 +1113,7 @@ mod tests { }, ], }, - api::PriceAccount { + pyth::PriceAccount { account: "GKNcUmNacSJo4S2Kq3DuYRYRGw3sNUfJ4tyqd198t6vQ" .to_string(), price_type: "price".to_string(), @@ -1160,7 +1138,7 @@ mod tests { }, ], }, - api::ProductAccount { + pyth::ProductAccount { account: "CkMrDWtmFJZcmAUC11qNaWymbXQKvnRx4cq1QudLav7t".to_string(), attr_dict: BTreeMap::from( [ @@ -1174,7 +1152,7 @@ mod tests { .map(|(k, v)| (k.to_string(), v.to_string())), ), price_accounts: vec![ - api::PriceAccount { + pyth::PriceAccount { account: "GVXRSBjFk6e6J3NbVPXohDJetcTjaeeuykUpbQF8UoMU" .to_string(), price_type: "price".to_string(), @@ -1191,7 +1169,7 @@ mod tests { prev_conf: 398674, publisher_accounts: vec![], }, - api::PriceAccount { + pyth::PriceAccount { account: "3VQwtcntVQN1mj1MybQw8qK7Li3KNrrgNskSQwZAPGNr" .to_string(), price_type: "price".to_string(), @@ -1214,7 +1192,7 @@ mod tests { slot: 14765, }], }, - api::PriceAccount { + pyth::PriceAccount { account: "2V7t5NaKY7aGkwytCWQgvUYZfEr9XMwNChhJEakTExk6" .to_string(), price_type: "price".to_string(), @@ -1252,17 +1230,16 @@ mod tests { all_products.sort(); assert_eq!(all_products, expected); - let _ = test_adapter.shutdown_tx.send(()); - test_adapter.jh.abort(); + let _ = state.shutdown_tx.send(()); + state.jh.abort(); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_get_product() { - // Start the test adapter - let test_adapter = setup().await; + let state = setup().await; let accounts_data = get_all_accounts_data(); - test_adapter - .adapter + state + .state .global_store ._account_data_primary(accounts_data) .await; @@ -1271,13 +1248,13 @@ mod tests { let account = "CkMrDWtmFJZcmAUC11qNaWymbXQKvnRx4cq1QudLav7t" .parse::() .unwrap(); - let product = test_adapter.adapter.get_product(&account).await.unwrap(); + let product = state.state.get_product(&account).await.unwrap(); // Check that the result of the conversion to the Pythd API format is what we expected let expected = ProductAccount { account: account.to_string(), price_accounts: vec![ - api::PriceAccount { + pyth::PriceAccount { account: "GVXRSBjFk6e6J3NbVPXohDJetcTjaeeuykUpbQF8UoMU".to_string(), price_type: "price".to_string(), price_exponent: -8, @@ -1293,7 +1270,7 @@ mod tests { prev_conf: 398674, publisher_accounts: vec![], }, - api::PriceAccount { + pyth::PriceAccount { account: "3VQwtcntVQN1mj1MybQw8qK7Li3KNrrgNskSQwZAPGNr".to_string(), price_type: "price".to_string(), price_exponent: -10, @@ -1315,7 +1292,7 @@ mod tests { slot: 14765, }], }, - api::PriceAccount { + pyth::PriceAccount { account: "2V7t5NaKY7aGkwytCWQgvUYZfEr9XMwNChhJEakTExk6".to_string(), price_type: "price".to_string(), price_exponent: -6, @@ -1361,14 +1338,13 @@ mod tests { }; assert_eq!(product, expected); - let _ = test_adapter.shutdown_tx.send(()); - test_adapter.jh.abort(); + let _ = state.shutdown_tx.send(()); + state.jh.abort(); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_update_price() { - // Start the test adapter - let test_adapter = setup().await; + let state = setup().await; // Send an Update Price message let account = "CkMrDWtmFJZcmAUC11qNaWymbXQKvnRx4cq1QudLav7t" @@ -1376,14 +1352,14 @@ mod tests { .unwrap(); let price = 2365; let conf = 98754; - let _ = test_adapter - .adapter - .update_price(&account, price, conf, "trading".to_string()) + let _ = state + .state + .update_local_price(&account, price, conf, "trading".to_string()) .await .unwrap(); // Check that the local store indeed received the correct update - let price_infos = LocalStore::get_all_price_infos(&*test_adapter.adapter).await; + let price_infos = LocalStore::get_all_price_infos(&*state.state).await; let price_info = price_infos .get(&Identifier::new(account.to_bytes())) .unwrap(); @@ -1392,60 +1368,101 @@ mod tests { assert_eq!(price_info.conf, conf); assert_eq!(price_info.status, PriceStatus::Trading); - let _ = test_adapter.shutdown_tx.send(()); - test_adapter.jh.abort(); + let _ = state.shutdown_tx.send(()); + state.jh.abort(); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_subscribe_notify_price() { - // Start the test adapter - let test_adapter = setup().await; + let state = setup().await; // Send a Subscribe Price message let account = "2wrWGm63xWubz7ue4iYR3qvBbaUJhZVi4eSpNuU8k8iF" .parse::() .unwrap(); let (notify_price_tx, mut notify_price_rx) = mpsc::channel(1000); - let subscription_id = test_adapter - .adapter - .subscribe_price(&account, notify_price_tx) - .await; + let subscription_id = state.state.subscribe_price(&account, notify_price_tx).await; + + // Send an update via the global store. + let test_price: PriceEntry = SolanaPriceAccount { + magic: 0xa1b2c3d4, + ver: 7, + atype: 9, + size: 300, + ptype: PriceType::Price, + expo: -8, + num: 8794, + num_qt: 32, + last_slot: 172761888, + valid_slot: 310, + ema_price: Rational { + val: 5882210200, + numer: 921349408, + denom: 1566332030, + }, + ema_conf: Rational { + val: 1422289, + numer: 2227777916, + denom: 1566332030, + }, + timestamp: 1667333704, + min_pub: 23, + drv2: 0xde, + drv3: 0xdeed, + drv4: 0xdeeed, + prod: solana_sdk::pubkey::Pubkey::from_str( + "CkMrDWtmFJZcmAUC11qNaWymbXQKvnRx4cq1QudLav7t", + ) + .unwrap(), + next: solana_sdk::pubkey::Pubkey::from_str( + "3VQwtcntVQN1mj1MybQw8qK7Li3KNrrgNskSQwZAPGNr", + ) + .unwrap(), + prev_slot: 172761778, + prev_price: 22691000, + prev_conf: 398674, + prev_timestamp: 1667333702, + agg: PriceInfo { + price: 736382, + conf: 85623946, + status: pyth_sdk_solana::state::PriceStatus::Trading, + corp_act: pyth_sdk_solana::state::CorpAction::NoCorpAct, + pub_slot: 7262746, + }, + comp: [PriceComp::default(); 32], + extended: (), + } + .into(); - // Send an update from the global store to the adapter - let price = 52162; - let conf = 1646; - let valid_slot = 75684; - let pub_slot = 32565; - let _ = test_adapter - .adapter - .global_store_update( - Identifier::new(account.to_bytes()), - price, - conf, - PriceStatus::Trading, - valid_slot, - pub_slot, + let _ = state + .state + .update_global_price( + Network::Primary, + &Update::PriceAccountUpdate { + account_key: account, + account: test_price.into(), + }, ) .await .unwrap(); - // Check that the adapter sends a notify price message with the corresponding subscription id + // Check that the application sends a notify price message with the corresponding subscription id // to the expected channel. assert_eq!( notify_price_rx.recv().await.unwrap(), NotifyPrice { subscription: subscription_id, result: PriceUpdate { - price, - conf, - status: "trading".to_string(), - valid_slot, - pub_slot + price: test_price.agg.price, + conf: test_price.agg.conf, + status: "trading".to_string(), + valid_slot: test_price.valid_slot, + pub_slot: test_price.agg.pub_slot, }, } ); - let _ = test_adapter.shutdown_tx.send(()); - test_adapter.jh.abort(); + let _ = state.shutdown_tx.send(()); + state.jh.abort(); } } diff --git a/src/agent/state/api.rs b/src/agent/state/api.rs index a117b7ab..dfa6e5ae 100644 --- a/src/agent/state/api.rs +++ b/src/agent/state/api.rs @@ -1,39 +1,39 @@ use { super::{ + super::{ + pyth::{ + Conf, + NotifyPrice, + NotifyPriceSched, + Price, + PriceAccount, + PriceAccountMetadata, + PriceUpdate, + ProductAccount, + ProductAccountMetadata, + PublisherAccount, + SubscriptionID, + }, + solana::{ + self, + network::Network, + oracle::PriceEntry, + }, + }, global::{ AllAccountsData, - AllAccountsMetadata, GlobalStore, + Update, }, local::{ self, LocalStore, }, + Config, NotifyPriceSchedSubscription, NotifyPriceSubscription, State, }, - crate::agent::{ - pythd::api::{ - Conf, - NotifyPrice, - NotifyPriceSched, - Price, - PriceAccount, - PriceAccountMetadata, - PriceUpdate, - ProductAccount, - ProductAccountMetadata, - PublisherAccount, - SubscriptionID, - }, - solana::{ - self, - network::Network, - oracle::PriceEntry, - }, - store::PriceIdentifier, - }, anyhow::{ anyhow, Result, @@ -44,10 +44,17 @@ use { PriceComp, PriceStatus, }, - std::sync::Arc, + std::{ + collections::HashMap, + sync::{ + atomic::AtomicI64, + Arc, + }, + time::Duration, + }, tokio::sync::{ - broadcast, mpsc, + RwLock, }, }; @@ -129,12 +136,32 @@ fn solana_price_account_to_pythd_api_price_account( } } +type PriceSubscriptions = HashMap>; +type PriceSchedSubscribtions = HashMap>; + +#[derive(Default)] +pub struct PricesState { + subscription_id_seq: AtomicI64, + notify_price_sched_interval_duration: Duration, + notify_price_subscriptions: RwLock, + notify_price_sched_subscriptions: RwLock, +} + +impl PricesState { + pub fn new(config: Config) -> Self { + Self { + subscription_id_seq: 1.into(), + notify_price_sched_interval_duration: config.notify_price_sched_interval_duration, + notify_price_subscriptions: Default::default(), + notify_price_sched_subscriptions: Default::default(), + } + } +} + #[async_trait::async_trait] -pub trait StateApi { +pub trait Prices { async fn get_product_list(&self) -> Result>; - async fn lookup_all_accounts_metadata(&self) -> Result; async fn get_all_products(&self) -> Result>; - async fn lookup_all_accounts_data(&self) -> Result; async fn get_product( &self, product_account_key: &solana_sdk::pubkey::Pubkey, @@ -152,50 +179,35 @@ pub trait StateApi { ) -> SubscriptionID; async fn send_notify_price_sched(&self) -> Result<()>; async fn drop_closed_subscriptions(&self); - async fn update_price( + async fn update_local_price( &self, account: &solana_sdk::pubkey::Pubkey, price: Price, conf: Conf, status: String, ) -> Result<()>; + async fn update_global_price(&self, network: Network, update: &Update) -> Result<()>; // TODO: implement FromStr method on PriceStatus fn map_status(status: &str) -> Result; - async fn global_store_update( - &self, - price_identifier: PriceIdentifier, - price: i64, - conf: u64, - status: PriceStatus, - valid_slot: u64, - pub_slot: u64, - ) -> Result<()>; } -pub async fn notifier(adapter: Arc, mut shutdown_rx: broadcast::Receiver<()>) { - let mut interval = tokio::time::interval(adapter.notify_price_sched_interval_duration); - loop { - adapter.drop_closed_subscriptions().await; - tokio::select! { - _ = shutdown_rx.recv() => { - info!(adapter.logger, "shutdown signal received"); - return; - } - _ = interval.tick() => { - if let Err(err) = adapter.send_notify_price_sched().await { - error!(adapter.logger, "{}", err); - debug!(adapter.logger, "error context"; "context" => format!("{:?}", err)); - } - } - } +// Allow downcasting State into Keypairs for functions that depend on the `Keypairs` service. +impl<'a> From<&'a State> for &'a PricesState { + fn from(state: &'a State) -> &'a PricesState { + &state.prices } } #[async_trait::async_trait] -impl StateApi for State { +impl Prices for T +where + for<'a> &'a T: Into<&'a PricesState>, + T: GlobalStore, + T: LocalStore, + T: Sync, +{ async fn get_product_list(&self) -> Result> { - let all_accounts_metadata = self.lookup_all_accounts_metadata().await?; - + let all_accounts_metadata = GlobalStore::accounts_metadata(self).await?; let mut result = Vec::new(); for (product_account_key, product_account) in all_accounts_metadata.product_accounts_metadata @@ -229,14 +241,8 @@ impl StateApi for State { Ok(result) } - // Fetches the Solana-specific Oracle data from the global store - async fn lookup_all_accounts_metadata(&self) -> Result { - GlobalStore::accounts_metadata(self).await - } - async fn get_all_products(&self) -> Result> { - let solana_data = self.lookup_all_accounts_data().await?; - + let solana_data = GlobalStore::accounts_data(self, Network::Primary).await?; let mut result = Vec::new(); for (product_account_key, product_account) in &solana_data.product_accounts { let product_account_api = solana_product_account_to_pythd_api_product_account( @@ -251,15 +257,11 @@ impl StateApi for State { Ok(result) } - async fn lookup_all_accounts_data(&self) -> Result { - GlobalStore::accounts_data(self, Network::Primary).await - } - async fn get_product( &self, product_account_key: &solana_sdk::pubkey::Pubkey, ) -> Result { - let all_accounts_data = self.lookup_all_accounts_data().await?; + let all_accounts_data = GlobalStore::accounts_data(self, Network::Primary).await?; // Look up the product account let product_account = all_accounts_data @@ -280,7 +282,8 @@ impl StateApi for State { notify_price_sched_tx: mpsc::Sender, ) -> SubscriptionID { let subscription_id = self.next_subscription_id(); - self.notify_price_sched_subscriptions + self.into() + .notify_price_sched_subscriptions .write() .await .entry(Identifier::new(account_pubkey.to_bytes())) @@ -293,7 +296,8 @@ impl StateApi for State { } fn next_subscription_id(&self) -> SubscriptionID { - self.subscription_id_seq + self.into() + .subscription_id_seq .fetch_add(1, std::sync::atomic::Ordering::SeqCst) } @@ -303,7 +307,8 @@ impl StateApi for State { notify_price_tx: mpsc::Sender, ) -> SubscriptionID { let subscription_id = self.next_subscription_id(); - self.notify_price_subscriptions + self.into() + .notify_price_subscriptions .write() .await .entry(Identifier::new(account.to_bytes())) @@ -317,6 +322,7 @@ impl StateApi for State { async fn send_notify_price_sched(&self) -> Result<()> { for subscription in self + .into() .notify_price_sched_subscriptions .read() .await @@ -325,7 +331,7 @@ impl StateApi for State { { // Send the notify price sched update without awaiting. This results in raising errors // if the channel is full which normally should not happen. This is because we do not - // want to block the adapter if the channel is full. + // want to block the API if the channel is full. subscription .notify_price_sched_tx .try_send(NotifyPriceSched { @@ -337,11 +343,18 @@ impl StateApi for State { } async fn drop_closed_subscriptions(&self) { - for subscriptions in self.notify_price_subscriptions.write().await.values_mut() { + for subscriptions in self + .into() + .notify_price_subscriptions + .write() + .await + .values_mut() + { subscriptions.retain(|subscription| !subscription.notify_price_tx.is_closed()) } for subscriptions in self + .into() .notify_price_sched_subscriptions .write() .await @@ -351,7 +364,7 @@ impl StateApi for State { } } - async fn update_price( + async fn update_local_price( &self, account: &solana_sdk::pubkey::Pubkey, price: Price, @@ -372,6 +385,54 @@ impl StateApi for State { .map_err(|_| anyhow!("failed to send update to local store")) } + async fn update_global_price(&self, network: Network, update: &Update) -> Result<()> { + GlobalStore::update(self, network, update) + .await + .map_err(|_| anyhow!("failed to send update to global store"))?; + + // Additionally, if the update is for a PriceAccount, we can notify our + // subscribers that the account has changed. We only notify when this is + // an update by the primary network as the account data might differ on + // the secondary network. + match (network, update) { + ( + Network::Primary, + Update::PriceAccountUpdate { + account_key, + account, + }, + ) => { + // Look up any subcriptions associated with the price identifier + let empty = Vec::new(); + let subscriptions = self.into().notify_price_subscriptions.read().await; + let subscriptions = subscriptions + .get(&Identifier::new(account_key.to_bytes())) + .unwrap_or(&empty); + + // Send the Notify Price update to each subscription + for subscription in subscriptions { + // Send the notify price update without awaiting. This results in raising errors if the + // channel is full which normally should not happen. This is because we do not want to + // block the APIO if the channel is full. + subscription.notify_price_tx.try_send(NotifyPrice { + subscription: subscription.subscription_id, + result: PriceUpdate { + price: account.agg.price, + conf: account.agg.conf, + status: price_status_to_str(account.agg.status), + valid_slot: account.valid_slot, + pub_slot: account.agg.pub_slot, + }, + })?; + } + + Ok(()) + } + + _ => Ok(()), + } + } + // TODO: implement FromStr method on PriceStatus fn map_status(status: &str) -> Result { match status { @@ -383,38 +444,28 @@ impl StateApi for State { _ => Err(anyhow!("invalid price status: {:#?}", status)), } } +} - async fn global_store_update( - &self, - price_identifier: PriceIdentifier, - price: i64, - conf: u64, - status: PriceStatus, - valid_slot: u64, - pub_slot: u64, - ) -> Result<()> { - // Look up any subcriptions associated with the price identifier - let empty = Vec::new(); - let subscriptions = self.notify_price_subscriptions.read().await; - let subscriptions = subscriptions.get(&price_identifier).unwrap_or(&empty); - - // Send the Notify Price update to each subscription - for subscription in subscriptions { - // Send the notify price update without awaiting. This results in raising errors if the - // channel is full which normally should not happen. This is because we do not want to - // block the adapter if the channel is full. - subscription.notify_price_tx.try_send(NotifyPrice { - subscription: subscription.subscription_id, - result: PriceUpdate { - price, - conf, - status: price_status_to_str(status), - valid_slot, - pub_slot, - }, - })?; +pub async fn notifier(state: Arc) +where + for<'a> &'a S: Into<&'a PricesState>, + S: Prices, +{ + let prices: &PricesState = (&*state).into(); + let mut interval = tokio::time::interval(prices.notify_price_sched_interval_duration); + let mut exit = crate::agent::EXIT.subscribe(); + loop { + Prices::drop_closed_subscriptions(&*state).await; + tokio::select! { + _ = exit.changed() => { + tracing::info!("Shutdown signal received."); + return; + } + _ = interval.tick() => { + if let Err(err) = state.send_notify_price_sched().await { + tracing::error!(err = ?err, "Notifier: failed to send notify price sched."); + } + } } - - Ok(()) } } diff --git a/src/agent/state/global.rs b/src/agent/state/global.rs index 97c444d8..36033a76 100644 --- a/src/agent/state/global.rs +++ b/src/agent/state/global.rs @@ -16,15 +16,12 @@ use { ProductEntry, }, }, - state::StateApi, }, anyhow::{ anyhow, Result, }, prometheus_client::registry::Registry, - pyth_sdk::Identifier, - slog::Logger, solana_sdk::pubkey::Pubkey, std::collections::{ BTreeMap, @@ -118,20 +115,16 @@ pub struct Store { /// Prometheus metrics for prices price_metrics: PriceGlobalMetrics, - - /// Shared logger configuration. - logger: Logger, } impl Store { - pub fn new(logger: Logger, registry: &mut Registry) -> Self { + pub fn new(registry: &mut Registry) -> Self { Store { - account_data_primary: Default::default(), + account_data_primary: Default::default(), account_data_secondary: Default::default(), - account_metadata: Default::default(), - product_metrics: ProductGlobalMetrics::new(registry), - price_metrics: PriceGlobalMetrics::new(registry), - logger, + account_metadata: Default::default(), + product_metrics: ProductGlobalMetrics::new(registry), + price_metrics: PriceGlobalMetrics::new(registry), } } } @@ -164,10 +157,10 @@ pub trait GlobalStore { ) -> Result>; } -// Allow downcasting Adapter into GlobalStore for functions that depend on the `GlobalStore` service. +// Allow downcasting State into GlobalStore for functions that depend on the `GlobalStore` service. impl<'a> From<&'a State> for &'a Store { - fn from(adapter: &'a State) -> &'a Store { - &adapter.global_store + fn from(state: &'a State) -> &'a Store { + &state.global_store } } @@ -175,7 +168,6 @@ impl<'a> From<&'a State> for &'a Store { impl GlobalStore for T where for<'a> &'a T: Into<&'a Store>, - T: StateApi, T: Sync, { async fn update(&self, network: Network, update: &Update) -> Result<()> { @@ -223,7 +215,6 @@ where async fn update_data(state: &S, network: Network, update: &Update) -> Result<()> where - S: StateApi, for<'a> &'a S: Into<&'a Store>, { let store: &Store = state.into(); @@ -261,10 +252,11 @@ where // This message is not an error. It is common // for primary and secondary network to have // slight difference in their timestamps. - debug!(store.logger, "Global store: ignoring stale update of an existing newer price"; - "price_key" => account_key.to_string(), - "existing_timestamp" => existing_price.timestamp, - "new_timestamp" => account.timestamp, + tracing::debug!( + price_key = account_key.to_string(), + existing_timestamp = existing_price.timestamp, + new_timestamp = account.timestamp, + "Global store: ignoring stale update of an existing newer price" ); return Ok(()); } @@ -279,23 +271,6 @@ where .await .price_accounts .insert(*account_key, *account); - - // Notify the Pythd API adapter that this account has changed. - // As the account data might differ between the two networks - // we only notify the adapter of the primary network updates. - if let Network::Primary = network { - StateApi::global_store_update( - state, - Identifier::new(account_key.to_bytes()), - account.agg.price, - account.agg.conf, - account.agg.status, - account.valid_slot, - account.agg.pub_slot, - ) - .await - .map_err(|_| anyhow!("failed to notify pythd adapter of account update"))?; - } } } @@ -304,7 +279,6 @@ where async fn update_metadata(state: &S, update: &Update) -> Result<()> where - S: StateApi, for<'a> &'a S: Into<&'a Store>, { let store: &Store = state.into(); diff --git a/src/agent/state/keypairs.rs b/src/agent/state/keypairs.rs new file mode 100644 index 00000000..a93eca5c --- /dev/null +++ b/src/agent/state/keypairs.rs @@ -0,0 +1,282 @@ +//! Keypair Management API +//! +//! The Keypair Manager allows hotloading keypairs via a HTTP request. + +use { + super::State, + crate::agent::solana::network::Network, + anyhow::{ + Context, + Result, + }, + serde::Deserialize, + solana_client::nonblocking::rpc_client::RpcClient, + solana_sdk::{ + commitment_config::CommitmentConfig, + signature::Keypair, + signer::Signer, + }, + std::{ + net::SocketAddr, + sync::Arc, + }, + tokio::{ + sync::RwLock, + task::JoinHandle, + }, + warp::{ + hyper::StatusCode, + reply::{ + self, + WithStatus, + }, + Filter, + Rejection, + }, +}; + +pub fn default_min_keypair_balance_sol() -> u64 { + 1 +} + +pub fn default_bind_address() -> SocketAddr { + "127.0.0.1:9001" + .parse() + .expect("INTERNAL: Could not build default remote keypair loader bind address") +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(default)] +pub struct Config { + primary_min_keypair_balance_sol: u64, + secondary_min_keypair_balance_sol: u64, + bind_address: SocketAddr, +} + +impl Default for Config { + fn default() -> Self { + Self { + primary_min_keypair_balance_sol: default_min_keypair_balance_sol(), + secondary_min_keypair_balance_sol: default_min_keypair_balance_sol(), + bind_address: default_bind_address(), + } + } +} + +#[derive(Default)] +pub struct KeypairState { + primary_current_keypair: RwLock>, + secondary_current_keypair: RwLock>, +} + +#[async_trait::async_trait] +pub trait Keypairs { + async fn request_keypair(&self, network: Network) -> Result; + async fn update_keypair(&self, network: Network, new_keypair: Keypair); +} + +// Allow downcasting State into Keypairs for functions that depend on the `Keypairs` service. +impl<'a> From<&'a State> for &'a KeypairState { + fn from(state: &'a State) -> &'a KeypairState { + &state.keypairs + } +} + +#[async_trait::async_trait] +impl Keypairs for T +where + for<'a> &'a T: Into<&'a KeypairState>, + T: Sync, +{ + async fn request_keypair(&self, network: Network) -> Result { + let keypair = match network { + Network::Primary => &self.into().primary_current_keypair, + Network::Secondary => &self.into().secondary_current_keypair, + } + .read() + .await; + + Ok(Keypair::from_bytes( + &keypair + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Keypair not available"))? + .to_bytes(), + )?) + } + + async fn update_keypair(&self, network: Network, new_keypair: Keypair) { + *match network { + Network::Primary => self.into().primary_current_keypair.write().await, + Network::Secondary => self.into().secondary_current_keypair.write().await, + } = Some(new_keypair); + } +} + +pub async fn spawn( + primary_rpc_url: String, + secondary_rpc_url: Option, + config: Config, + state: Arc, +) -> Vec> +where + S: Keypairs, + S: Send + Sync + 'static, + for<'a> &'a S: Into<&'a KeypairState>, +{ + let ip = config.bind_address.ip(); + + if !ip.is_loopback() { + tracing::warn!( + bind_address = ?config.bind_address, + "Remote key loader: bind address is not localhost. Make sure the access on the selected address is secure.", + ); + } + + let primary_upload_route = { + let state = state.clone(); + let rpc_url = primary_rpc_url.clone(); + let min_balance = config.primary_min_keypair_balance_sol; + warp::path!("primary" / "load_keypair") + .and(warp::post()) + .and(warp::body::content_length_limit(1024)) + .and(warp::body::json()) + .and(warp::path::end()) + .and_then(move |kp: Vec| { + let state = state.clone(); + let rpc_url = rpc_url.clone(); + async move { + let response = handle_new_keypair( + state, + Network::Primary, + kp, + min_balance, + rpc_url, + "primary", + ) + .await; + Result::, Rejection>::Ok(response) + } + }) + }; + + let secondary_upload_route = warp::path!("secondary" / "load_keypair") + .and(warp::post()) + .and(warp::body::content_length_limit(1024)) + .and(warp::body::json()) + .and(warp::path::end()) + .and_then(move |kp: Vec| { + let state = state.clone(); + let rpc_url = secondary_rpc_url.clone(); + async move { + if let Some(rpc_url) = rpc_url { + let min_balance = config.secondary_min_keypair_balance_sol; + let response = handle_new_keypair( + state, + Network::Secondary, + kp, + min_balance, + rpc_url, + "secondary", + ) + .await; + Result::, Rejection>::Ok(response) + } else { + Result::, Rejection>::Ok(reply::with_status( + "Secondary network is not active", + StatusCode::SERVICE_UNAVAILABLE, + )) + } + } + }); + + let http_api_jh = { + let (_, serve) = warp::serve(primary_upload_route.or(secondary_upload_route)) + .bind_with_graceful_shutdown(config.bind_address, async { + let _ = crate::agent::EXIT.subscribe().changed().await; + }); + tokio::spawn(serve) + }; + + // WARNING: All jobs spawned here must report their join handles in this vec + vec![http_api_jh] +} + +/// Validate and apply a keypair to the specified mut reference, +/// hiding errors in logs. +/// +/// Returns the appropriate HTTP response depending on checks success. +/// +/// NOTE(2023-03-22): Lifetime bounds are currently necessary +/// because of https://github.com/rust-lang/rust/issues/63033 +async fn handle_new_keypair<'a, 'b: 'a, S>( + state: Arc, + network: Network, + new_keypair_bytes: Vec, + min_keypair_balance_sol: u64, + rpc_url: String, + network_name: &'b str, +) -> WithStatus<&'static str> +where + S: Keypairs, +{ + let mut upload_ok = true; + match Keypair::from_bytes(&new_keypair_bytes) { + Ok(kp) => match validate_keypair(&kp, min_keypair_balance_sol, rpc_url.clone()).await { + Ok(()) => { + Keypairs::update_keypair(&*state, network, kp).await; + } + Err(e) => { + tracing::warn!( + network = network_name, + error = e.to_string(), + "Remote keypair loader: Keypair failed validation", + ); + upload_ok = false; + } + }, + Err(e) => { + tracing::warn!( + network = network_name, + error = e.to_string(), + "Remote keypair loader: Keypair failed validation", + ); + upload_ok = false; + } + } + + if upload_ok { + reply::with_status("keypair upload OK", StatusCode::OK) + } else { + reply::with_status( + "Could not upload keypair. See logs for details.", + StatusCode::BAD_REQUEST, + ) + } +} + +/// Validate keypair balance before using it in transactions. +pub async fn validate_keypair( + kp: &Keypair, + min_keypair_balance_sol: u64, + rpc_url: String, +) -> Result<()> { + let c = RpcClient::new_with_commitment(rpc_url, CommitmentConfig::confirmed()); + + let balance_lamports = c + .get_balance(&kp.pubkey()) + .await + .context("Could not check keypair's balance")?; + + let lamports_in_sol = 1_000_000_000; + + if balance_lamports > min_keypair_balance_sol * lamports_in_sol { + Ok(()) + } else { + Err(anyhow::anyhow!(format!( + "Keypair {} balance of {} SOL below threshold of {} SOL", + kp.pubkey(), + balance_lamports as f64 / lamports_in_sol as f64, + min_keypair_balance_sol + ))) + } +} diff --git a/src/agent/state/local.rs b/src/agent/state/local.rs index 8d05654e..5e228298 100644 --- a/src/agent/state/local.rs +++ b/src/agent/state/local.rs @@ -2,11 +2,7 @@ // is contributing to the network. The Exporters will then take this data and publish // it to the networks. use { - super::{ - PriceIdentifier, - State, - StateApi, - }, + super::State, crate::agent::metrics::PriceLocalMetrics, anyhow::{ anyhow, @@ -15,7 +11,6 @@ use { chrono::NaiveDateTime, prometheus_client::registry::Registry, pyth_sdk_solana::state::PriceStatus, - slog::Logger, solana_sdk::bs58, std::collections::HashMap, tokio::sync::RwLock, @@ -46,31 +41,33 @@ impl PriceInfo { } pub struct Store { - prices: RwLock>, + prices: RwLock>, metrics: PriceLocalMetrics, - logger: Logger, } impl Store { - pub fn new(logger: Logger, registry: &mut Registry) -> Self { + pub fn new(registry: &mut Registry) -> Self { Store { - prices: RwLock::new(HashMap::new()), + prices: RwLock::new(HashMap::new()), metrics: PriceLocalMetrics::new(registry), - logger, } } } #[async_trait::async_trait] pub trait LocalStore { - async fn update(&self, price_identifier: PriceIdentifier, price_info: PriceInfo) -> Result<()>; - async fn get_all_price_infos(&self) -> HashMap; + async fn update( + &self, + price_identifier: pyth_sdk::Identifier, + price_info: PriceInfo, + ) -> Result<()>; + async fn get_all_price_infos(&self) -> HashMap; } -// Allow downcasting Adapter into GlobalStore for functions that depend on the `GlobalStore` service. +// Allow downcasting State into GlobalStore for functions that depend on the `GlobalStore` service. impl<'a> From<&'a State> for &'a Store { - fn from(adapter: &'a State) -> &'a Store { - &adapter.local_store + fn from(state: &'a State) -> &'a Store { + &state.local_store } } @@ -78,11 +75,17 @@ impl<'a> From<&'a State> for &'a Store { impl LocalStore for T where for<'a> &'a T: Into<&'a Store>, - T: StateApi, T: Sync, { - async fn update(&self, price_identifier: PriceIdentifier, price_info: PriceInfo) -> Result<()> { - debug!(self.into().logger, "local store received price update"; "identifier" => bs58::encode(price_identifier.to_bytes()).into_string()); + async fn update( + &self, + price_identifier: pyth_sdk::Identifier, + price_info: PriceInfo, + ) -> Result<()> { + tracing::debug!( + identifier = bs58::encode(price_identifier.to_bytes()).into_string(), + "Local store received price update." + ); // Drop the update if it is older than the current one stored for the price if let Some(current_price_info) = self.into().prices.read().await.get(&price_identifier) { @@ -104,7 +107,7 @@ where Ok(()) } - async fn get_all_price_infos(&self) -> HashMap { + async fn get_all_price_infos(&self) -> HashMap { self.into().prices.read().await.clone() } } diff --git a/src/agent/store.rs b/src/agent/store.rs deleted file mode 100644 index 6f0bb24d..00000000 --- a/src/agent/store.rs +++ /dev/null @@ -1 +0,0 @@ -pub type PriceIdentifier = pyth_sdk::Identifier; diff --git a/src/bin/agent.rs b/src/bin/agent.rs index 98e5194a..3a515f7c 100644 --- a/src/bin/agent.rs +++ b/src/bin/agent.rs @@ -4,27 +4,13 @@ use { Context, Result, }, - clap::{ - Parser, - ValueEnum, - }, + clap::Parser, pyth_agent::agent::{ config::Config, Agent, }, - slog::{ - debug, - error, - o, - Drain, - Logger, - PushFnValue, - Record, - }, - slog_async::Async, - slog_envlogger::LogBuilder, std::{ - env, + io::IsTerminal, path::PathBuf, }, }; @@ -35,26 +21,30 @@ use { struct Arguments { #[clap(short, long, default_value = "config/config.toml")] /// Path to configuration file - config: PathBuf, - #[clap(short, long, default_value = "plain", value_enum)] - /// Log flavor to use - log_flavor: LogFlavor, + config: PathBuf, #[clap(short = 'L', long)] /// Whether to print file:line info for each log statement log_locations: bool, } -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] -enum LogFlavor { - /// Standard human-readable output - Plain, - /// Structured JSON output - Json, -} - #[tokio::main] async fn main() -> Result<()> { + // Initialize a Tracing Subscriber + let fmt_builder = tracing_subscriber::fmt() + .with_file(false) + .with_line_number(true) + .with_thread_ids(true) + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_ansi(std::io::stderr().is_terminal()); + + // Use the compact formatter if we're in a terminal, otherwise use the JSON formatter. + if std::io::stderr().is_terminal() { + tracing::subscriber::set_global_default(fmt_builder.compact().finish())?; + } else { + tracing::subscriber::set_global_default(fmt_builder.json().finish())?; + } + let args = Arguments::parse(); if !args.config.as_path().exists() { @@ -66,66 +56,18 @@ async fn main() -> Result<()> { // Parse config early for logging channel capacity let config = Config::new(args.config).context("Could not parse config")?; - let log_level = env::var("RUST_LOG").unwrap_or("info".to_string()); - - // Build an async drain with a different inner drain depending on - // log flavor choice in CLI - let async_drain = match args.log_flavor { - LogFlavor::Json => { - // JSON output using slog-bunyan - let inner_drain = LogBuilder::new( - slog_bunyan::with_name(env!("CARGO_PKG_NAME"), std::io::stdout()) - .build() - .fuse(), - ) - .parse(&log_level) - .build(); - - Async::new(inner_drain) - .chan_size(config.channel_capacities.logger_buffer) - .build() - .fuse() - } - LogFlavor::Plain => { - // Plain, colored output usind slog-term - let inner_drain = LogBuilder::new( - slog_term::FullFormat::new(slog_term::TermDecorator::new().stdout().build()) - .build() - .fuse(), - ) - .parse(&log_level) - .build(); - - Async::new(inner_drain) - .chan_size(config.channel_capacities.logger_buffer) - .build() - .fuse() - } - }; - - let mut logger = slog::Logger::root(async_drain, o!()); - - // Add location information to each log statement if enabled - if args.log_locations { - logger = logger.new(o!( - "loc" => PushFnValue( - move |r: &Record, ser| { - ser.emit(format!("{}:{}", r.file(), r.line())) - } - ), - )); - } - - if let Err(err) = start(config, logger.clone()).await { - error!(logger, "{}", err); - debug!(logger, "error context"; "context" => format!("{:?}", err)); + // Launch the application. If it fails, print the full backtrace and exit. RUST_BACKTRACE + // should be set to 1 for this otherwise it will only print the top-level error. + if let Err(err) = start(config).await { + eprintln!("{}", err.backtrace()); + err.chain().for_each(|cause| eprintln!("{cause}")); return Err(err); } Ok(()) } -async fn start(config: Config, logger: Logger) -> Result<()> { - Agent::new(config).start(logger).await; +async fn start(config: Config) -> Result<()> { + Agent::new(config).start().await; Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 8792531a..f17bc55d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1 @@ -// The typed-html crate does pretty deep macro calls. Bump if -// recursion limit compilation errors return for html!() calls. -#![recursion_limit = "256"] -#[macro_use] -extern crate slog; -extern crate slog_term; - pub mod agent;