Skip to content

Commit

Permalink
refactor(hermes): make State hidden to force APIs (#1537)
Browse files Browse the repository at this point in the history
  • Loading branch information
Reisen authored May 14, 2024
1 parent b6f5bf1 commit f0adce0
Show file tree
Hide file tree
Showing 12 changed files with 137 additions and 81 deletions.
39 changes: 29 additions & 10 deletions apps/hermes/src/api.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use {
crate::{
config::RunOptions,
state::State,
state::{
Aggregates,
Benchmarks,
Cache,
Metrics,
},
},
anyhow::Result,
axum::{
Expand All @@ -24,7 +29,7 @@ mod rest;
pub mod types;
mod ws;

pub struct ApiState<S = State> {
pub struct ApiState<S> {
pub state: Arc<S>,
pub ws: Arc<ws::WsState>,
pub metrics: Arc<metrics_middleware::ApiMetrics>,
Expand All @@ -42,12 +47,12 @@ impl<S> Clone for ApiState<S> {
}
}

impl ApiState<State> {
pub fn new(
state: Arc<State>,
ws_whitelist: Vec<IpNet>,
requester_ip_header_name: String,
) -> Self {
impl<S> ApiState<S> {
pub fn new(state: Arc<S>, ws_whitelist: Vec<IpNet>, requester_ip_header_name: String) -> Self
where
S: Metrics,
S: Send + Sync + 'static,
{
Self {
metrics: Arc::new(metrics_middleware::ApiMetrics::new(state.clone())),
ws: Arc::new(ws::WsState::new(
Expand All @@ -61,7 +66,14 @@ impl ApiState<State> {
}

#[tracing::instrument(skip(opts, state))]
pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
pub async fn spawn<S>(opts: RunOptions, state: Arc<S>) -> Result<()>
where
S: Aggregates,
S: Benchmarks,
S: Cache,
S: Metrics,
S: Send + Sync + 'static,
{
let state = {
let opts = opts.clone();
ApiState::new(
Expand All @@ -79,7 +91,14 @@ pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
/// Currently this is based on Axum due to the simplicity and strong ecosystem support for the
/// packages they are based on (tokio & hyper).
#[tracing::instrument(skip(opts, state))]
pub async fn run(opts: RunOptions, state: ApiState) -> Result<()> {
pub async fn run<S>(opts: RunOptions, state: ApiState<S>) -> Result<()>
where
S: Aggregates,
S: Benchmarks,
S: Cache,
S: Metrics,
S: Send + Sync + 'static,
{
tracing::info!(endpoint = %opts.rpc.listen_addr, "Starting RPC Server.");

#[derive(OpenApi)]
Expand Down
8 changes: 3 additions & 5 deletions apps/hermes/src/api/metrics_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ impl ApiMetrics {
pub fn new<S>(state: Arc<S>) -> Self
where
S: Metrics,
S: Send,
S: Sync,
S: 'static,
S: Send + Sync + 'static,
{
let new = Self {
requests: Family::default(),
Expand Down Expand Up @@ -81,8 +79,8 @@ pub struct Labels {
pub status: u16,
}

pub async fn track_metrics<B>(
State(api_state): State<ApiState>,
pub async fn track_metrics<B, S>(
State(api_state): State<ApiState<S>>,
req: Request<B>,
next: Next<B>,
) -> impl IntoResponse {
Expand Down
2 changes: 1 addition & 1 deletion apps/hermes/src/api/rest/v2/price_feeds_metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use {
},
ApiState,
},
price_feeds_metadata::PriceFeedMeta,
state::price_feeds_metadata::PriceFeedMeta,
},
anyhow::Result,
axum::{
Expand Down
4 changes: 1 addition & 3 deletions apps/hermes/src/api/rest/v2/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,7 @@ pub async fn price_stream_sse_handler<S>(
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, RestError>
where
S: Aggregates,
S: Sync,
S: Send,
S: 'static,
S: Send + Sync + 'static,
{
let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(Into::into).collect();

Expand Down
28 changes: 20 additions & 8 deletions apps/hermes/src/api/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ use {
RequestTime,
},
metrics::Metrics,
State,
Benchmarks,
Cache,
PriceFeedMeta,
},
anyhow::{
anyhow,
Expand Down Expand Up @@ -124,9 +126,7 @@ impl WsMetrics {
pub fn new<S>(state: Arc<S>) -> Self
where
S: Metrics,
S: Send,
S: Sync,
S: 'static,
S: Send + Sync + 'static,
{
let new = Self {
interactions: Family::default(),
Expand Down Expand Up @@ -161,7 +161,11 @@ pub struct WsState {
}

impl WsState {
pub fn new(whitelist: Vec<IpNet>, requester_ip_header_name: String, state: Arc<State>) -> Self {
pub fn new<S>(whitelist: Vec<IpNet>, requester_ip_header_name: String, state: Arc<S>) -> Self
where
S: Metrics,
S: Send + Sync + 'static,
{
Self {
subscriber_counter: AtomicUsize::new(0),
rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
Expand Down Expand Up @@ -211,11 +215,18 @@ enum ServerResponseMessage {
Err { error: String },
}

pub async fn ws_route_handler(
pub async fn ws_route_handler<S>(
ws: WebSocketUpgrade,
AxumState(state): AxumState<super::ApiState>,
AxumState(state): AxumState<ApiState<S>>,
headers: HeaderMap,
) -> impl IntoResponse {
) -> impl IntoResponse
where
S: Aggregates,
S: Benchmarks,
S: Cache,
S: PriceFeedMeta,
S: Send + Sync + 'static,
{
let requester_ip = headers
.get(state.ws.requester_ip_header_name.as_str())
.and_then(|value| value.to_str().ok())
Expand All @@ -230,6 +241,7 @@ pub async fn ws_route_handler(
async fn websocket_handler<S>(stream: WebSocket, state: ApiState<S>, subscriber_ip: Option<IpAddr>)
where
S: Aggregates,
S: Send,
{
let ws_state = state.ws.clone();

Expand Down
4 changes: 1 addition & 3 deletions apps/hermes/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use {
},
futures::future::join_all,
lazy_static::lazy_static,
state::State,
std::io::IsTerminal,
tokio::{
spawn,
Expand All @@ -21,7 +20,6 @@ mod api;
mod config;
mod metrics_server;
mod network;
mod price_feeds_metadata;
mod serde;
mod state;

Expand Down Expand Up @@ -53,7 +51,7 @@ async fn init() -> Result<()> {
let (update_tx, _) = tokio::sync::broadcast::channel(1000);

// Initialize a cache store with a 1000 element circular buffer.
let state = State::new(update_tx.clone(), 1000, opts.benchmarks.endpoint.clone());
let state = state::new(update_tx.clone(), 1000, opts.benchmarks.endpoint.clone());

// Listen for Ctrl+C so we can set the exit flag and wait for a graceful shutdown.
spawn(async move {
Expand Down
16 changes: 10 additions & 6 deletions apps/hermes/src/metrics_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
use {
crate::{
config::RunOptions,
state::{
metrics::Metrics,
State as AppState,
},
state::metrics::Metrics,
},
anyhow::Result,
axum::{
Expand All @@ -23,7 +20,11 @@ use {


#[tracing::instrument(skip(opts, state))]
pub async fn run(opts: RunOptions, state: Arc<AppState>) -> Result<()> {
pub async fn run<S>(opts: RunOptions, state: Arc<S>) -> Result<()>
where
S: Metrics,
S: Send + Sync + 'static,
{
tracing::info!(endpoint = %opts.metrics.server_listen_addr, "Starting Metrics Server.");

let app = Router::new();
Expand All @@ -44,7 +45,10 @@ pub async fn run(opts: RunOptions, state: Arc<AppState>) -> Result<()> {
Ok(())
}

pub async fn metrics(State(state): State<Arc<AppState>>) -> impl IntoResponse {
pub async fn metrics<S>(State(state): State<Arc<S>>) -> impl IntoResponse
where
S: Metrics,
{
let buffer = Metrics::encode(&*state).await;
(
[(
Expand Down
23 changes: 16 additions & 7 deletions apps/hermes/src/network/pythnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,17 @@ use {
GuardianSet,
GuardianSetData,
},
price_feeds_metadata::{
PriceFeedMeta,
DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL,
},
state::{
aggregate::{
AccumulatorMessages,
Aggregates,
Update,
},
price_feeds_metadata::{
PriceFeedMeta,
DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL,
},
wormhole::Wormhole,
State,
},
},
anyhow::{
Expand Down Expand Up @@ -139,7 +138,12 @@ async fn fetch_bridge_data(
}
}

pub async fn run(store: Arc<State>, pythnet_ws_endpoint: String) -> Result<!> {
pub async fn run<S>(store: Arc<S>, pythnet_ws_endpoint: String) -> Result<!>
where
S: Aggregates,
S: Wormhole,
S: Send + Sync + 'static,
{
let client = PubsubClient::new(pythnet_ws_endpoint.as_ref()).await?;

let config = RpcProgramAccountsConfig {
Expand Down Expand Up @@ -222,6 +226,7 @@ async fn fetch_existing_guardian_sets<S>(
) -> Result<()>
where
S: Wormhole,
S: Send + Sync + 'static,
{
let client = RpcClient::new(pythnet_http_endpoint.to_string());
let bridge = fetch_bridge_data(&client, &wormhole_contract_addr).await?;
Expand Down Expand Up @@ -261,7 +266,11 @@ where
}

#[tracing::instrument(skip(opts, state))]
pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
pub async fn spawn<S>(opts: RunOptions, state: Arc<S>) -> Result<()>
where
S: Wormhole,
S: Send + Sync + 'static,
{
tracing::info!(endpoint = opts.pythnet.ws_addr, "Started Pythnet Listener.");

// Create RpcClient instance here
Expand Down
15 changes: 7 additions & 8 deletions apps/hermes/src/network/wormhole.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
use {
crate::{
config::RunOptions,
state::{
wormhole::Wormhole,
State,
},
state::wormhole::Wormhole,
},
anyhow::{
anyhow,
Expand Down Expand Up @@ -118,7 +115,11 @@ mod proto {

// Launches the Wormhole gRPC service.
#[tracing::instrument(skip(opts, state))]
pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
pub async fn spawn<S>(opts: RunOptions, state: Arc<S>) -> Result<()>
where
S: Wormhole,
S: Send + Sync + 'static,
{
let mut exit = crate::EXIT.subscribe();
loop {
let current_time = Instant::now();
Expand All @@ -142,9 +143,7 @@ pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
async fn run<S>(opts: RunOptions, state: Arc<S>) -> Result<!>
where
S: Wormhole,
S: Sync,
S: Send,
S: 'static,
S: Send + Sync + 'static,
{
let mut client = SpyRpcServiceClient::connect(opts.wormhole.spy_rpc_addr).await?;
let mut stream = client
Expand Down
Loading

0 comments on commit f0adce0

Please sign in to comment.