Skip to content

Commit

Permalink
refactor(agent): refactor all references to adapter to state
Browse files Browse the repository at this point in the history
  • Loading branch information
Reisen committed Jun 5, 2024
1 parent 0c1c066 commit 87a38be
Show file tree
Hide file tree
Showing 17 changed files with 122 additions and 133 deletions.
20 changes: 10 additions & 10 deletions src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,15 @@ impl Agent {
// job handles
let mut jhs = vec![];

// Create the Pythd Adapter.
let adapter =
Arc::new(state::State::new(self.config.pythd_adapter.clone(), logger.clone()).await);
// Create the Application State.
let state = Arc::new(state::State::new(self.config.state.clone(), logger.clone()).await);

// Spawn the primary network
jhs.extend(network::spawn_network(
self.config.primary_network.clone(),
network::Network::Primary,
logger.new(o!("primary" => true)),
adapter.clone(),
state.clone(),
)?);

// Spawn the secondary network, if needed
Expand All @@ -141,25 +140,25 @@ impl Agent {
config.clone(),
network::Network::Secondary,
logger.new(o!("primary" => false)),
adapter.clone(),
state.clone(),
)?);
}

// Create the Notifier task for the Pythd RPC.
jhs.push(tokio::spawn(notifier(logger.clone(), adapter.clone())));
jhs.push(tokio::spawn(notifier(logger.clone(), state.clone())));

// Spawn the Pythd API Server
jhs.push(tokio::spawn(rpc::run(
self.config.pythd_api_server.clone(),
logger.clone(),
adapter.clone(),
state.clone(),
)));

// Spawn the metrics server
jhs.push(tokio::spawn(metrics::MetricsServer::spawn(
self.config.metrics_server.bind_address,
logger.clone(),
adapter.clone(),
state.clone(),
)));

// Spawn the remote keypair loader endpoint for both networks
Expand All @@ -172,7 +171,7 @@ impl Agent {
.map(|c| c.rpc_url.clone()),
self.config.remote_keypair_loader.clone(),
logger,
adapter,
state,
)
.await,
);
Expand Down Expand Up @@ -210,7 +209,8 @@ pub mod config {
pub primary_network: network::Config,
pub secondary_network: Option<network::Config>,
#[serde(default)]
pub pythd_adapter: state::Config,
#[serde(rename = "pythd_adapter")]
pub state: state::Config,
#[serde(default)]
pub pythd_api_server: pythd::api::rpc::Config,
#[serde(default)]
Expand Down
6 changes: 3 additions & 3 deletions src/agent/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,16 @@ lazy_static! {
pub struct MetricsServer {
pub start_time: Instant,
pub logger: Logger,
pub adapter: Arc<State>,
pub state: Arc<State>,
}

impl MetricsServer {
/// Instantiate a metrics API.
pub async fn spawn(addr: impl Into<SocketAddr> + 'static, logger: Logger, adapter: Arc<State>) {
pub async fn spawn(addr: impl Into<SocketAddr> + 'static, logger: Logger, state: Arc<State>) {
let server = MetricsServer {
start_time: Instant::now(),
logger,
adapter,
state,
};

let shared_state = Arc::new(Mutex::new(server));
Expand Down
42 changes: 21 additions & 21 deletions src/agent/pythd/api/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ enum ConnectionError {

async fn handle_connection<S>(
ws_conn: WebSocket,
adapter: Arc<S>,
state: Arc<S>,
notify_price_tx_buffer: usize,
notify_price_sched_tx_buffer: usize,
logger: Logger,
Expand All @@ -131,7 +131,7 @@ async fn handle_connection<S>(
loop {
if let Err(err) = handle_next(
&logger,
&*adapter,
&*state,
&mut ws_tx,
&mut ws_rx,
&mut notify_price_tx,
Expand All @@ -156,7 +156,7 @@ async fn handle_connection<S>(

async fn handle_next<S>(
logger: &Logger,
adapter: &S,
state: &S,
ws_tx: &mut SplitSink<WebSocket, Message>,
ws_rx: &mut SplitStream<WebSocket>,
notify_price_tx: &mut mpsc::Sender<NotifyPrice>,
Expand All @@ -175,7 +175,7 @@ where
handle(
logger,
ws_tx,
adapter,
state,
notify_price_tx,
notify_price_sched_tx,
msg,
Expand All @@ -201,7 +201,7 @@ where
async fn handle<S>(
logger: &Logger,
ws_tx: &mut SplitSink<WebSocket, Message>,
adapter: &S,
state: &S,
notify_price_tx: &mpsc::Sender<NotifyPrice>,
notify_price_sched_tx: &mpsc::Sender<NotifyPriceSched>,
msg: Message,
Expand All @@ -224,7 +224,7 @@ where
for request in requests {
let response = dispatch_and_catch_error(
logger,
adapter,
state,
notify_price_tx,
notify_price_sched_tx,
&request,
Expand Down Expand Up @@ -287,7 +287,7 @@ async fn parse(msg: Message) -> Result<(Vec<Request<Method, Value>>, bool)> {

async fn dispatch_and_catch_error<S>(
logger: &Logger,
adapter: &S,
state: &S,
notify_price_tx: &mpsc::Sender<NotifyPrice>,
notify_price_sched_tx: &mpsc::Sender<NotifyPriceSched>,
request: &Request<Method, Value>,
Expand All @@ -302,13 +302,13 @@ where
);

let result = match request.method {
Method::GetProductList => get_product_list(adapter).await,
Method::GetProduct => get_product(adapter, request).await,
Method::GetAllProducts => get_all_products(adapter).await,
Method::UpdatePrice => update_price(adapter, request).await,
Method::SubscribePrice => subscribe_price(adapter, notify_price_tx, request).await,
Method::GetProductList => get_product_list(state).await,
Method::GetProduct => get_product(state, request).await,
Method::GetAllProducts => get_all_products(state).await,
Method::UpdatePrice => update_price(state, request).await,
Method::SubscribePrice => subscribe_price(state, notify_price_tx, request).await,
Method::SubscribePriceSched => {
subscribe_price_sched(adapter, notify_price_sched_tx, request).await
subscribe_price_sched(state, notify_price_sched_tx, request).await
}
Method::NotifyPrice | Method::NotifyPriceSched => {
Err(anyhow!("unsupported method: {:?}", request.method))
Expand Down Expand Up @@ -410,10 +410,10 @@ pub struct Config {
/// The address which the websocket API server will listen on.
pub listen_address: String,
/// Size of the buffer of each Server's channel on which `notify_price` events are
/// received from the Adapter.
/// received from the Price state.
pub notify_price_tx_buffer: usize,
/// Size of the buffer of each Server's channel on which `notify_price_sched` events are
/// received from the Adapter.
/// received from the Price state.
pub notify_price_sched_tx_buffer: usize,
}

Expand All @@ -427,20 +427,20 @@ impl Default for Config {
}
}

pub async fn run<S>(config: Config, logger: Logger, adapter: Arc<S>)
pub async fn run<S>(config: Config, logger: Logger, state: Arc<S>)
where
S: state::Prices,
S: Send,
S: Sync,
S: 'static,
{
if let Err(err) = serve(config, &logger, adapter).await {
if let Err(err) = serve(config, &logger, state).await {
error!(logger, "{}", err);
debug!(logger, "error context"; "context" => format!("{:?}", err));
}
}

async fn serve<S>(config: Config, logger: &Logger, adapter: Arc<S>) -> Result<()>
async fn serve<S>(config: Config, logger: &Logger, state: Arc<S>) -> Result<()>
where
S: state::Prices,
S: Send,
Expand All @@ -456,16 +456,16 @@ where
let config = config.clone();
warp::path::end()
.and(warp::ws())
.and(warp::any().map(move || adapter.clone()))
.and(warp::any().map(move || state.clone()))
.and(warp::any().map(move || with_logger.clone()))
.and(warp::any().map(move || config.clone()))
.map(
|ws: Ws, adapter: Arc<S>, with_logger: WithLogger, config: Config| {
|ws: Ws, state: Arc<S>, with_logger: WithLogger, config: Config| {
ws.on_upgrade(move |conn| async move {
info!(with_logger.logger, "websocket user connected");
handle_connection(
conn,
adapter,
state,
config.notify_price_tx_buffer,
config.notify_price_sched_tx_buffer,
with_logger.logger,
Expand Down
4 changes: 2 additions & 2 deletions src/agent/pythd/api/rpc/get_all_products.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use {
anyhow::Result,
};

pub async fn get_all_products<S>(adapter: &S) -> Result<serde_json::Value>
pub async fn get_all_products<S>(state: &S) -> Result<serde_json::Value>
where
S: state::Prices,
{
let products = adapter.get_all_products().await?;
let products = state.get_all_products().await?;
Ok(serde_json::to_value(products)?)
}
4 changes: 2 additions & 2 deletions src/agent/pythd/api/rpc/get_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use {
};

pub async fn get_product<S>(
adapter: &S,
state: &S,
request: &Request<Method, Value>,
) -> Result<serde_json::Value>
where
Expand All @@ -27,6 +27,6 @@ where
}?;

let account = params.account.parse::<solana_sdk::pubkey::Pubkey>()?;
let product = adapter.get_product(&account).await?;
let product = state.get_product(&account).await?;
Ok(serde_json::to_value(product)?)
}
4 changes: 2 additions & 2 deletions src/agent/pythd/api/rpc/get_product_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use {
anyhow::Result,
};

pub async fn get_product_list<S>(adapter: &S) -> Result<serde_json::Value>
pub async fn get_product_list<S>(state: &S) -> Result<serde_json::Value>
where
S: state::Prices,
{
let product_list = adapter.get_product_list().await?;
let product_list = state.get_product_list().await?;
Ok(serde_json::to_value(product_list)?)
}
4 changes: 2 additions & 2 deletions src/agent/pythd/api/rpc/subscribe_price.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use {
};

pub async fn subscribe_price<S>(
adapter: &S,
state: &S,
notify_price_tx: &mpsc::Sender<NotifyPrice>,
request: &Request<Method, Value>,
) -> Result<serde_json::Value>
Expand All @@ -33,7 +33,7 @@ where
)?;

let account = params.account.parse::<solana_sdk::pubkey::Pubkey>()?;
let subscription = adapter
let subscription = state
.subscribe_price(&account, notify_price_tx.clone())
.await;

Expand Down
4 changes: 2 additions & 2 deletions src/agent/pythd/api/rpc/subscribe_price_sched.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use {
};

pub async fn subscribe_price_sched<S>(
adapter: &S,
state: &S,
notify_price_sched_tx: &mpsc::Sender<NotifyPriceSched>,
request: &Request<Method, Value>,
) -> Result<serde_json::Value>
Expand All @@ -33,7 +33,7 @@ where
)?;

let account = params.account.parse::<solana_sdk::pubkey::Pubkey>()?;
let subscription = adapter
let subscription = state
.subscribe_price_sched(&account, notify_price_sched_tx.clone())
.await;

Expand Down
4 changes: 2 additions & 2 deletions src/agent/pythd/api/rpc/update_price.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use {
};

pub async fn update_price<S>(
adapter: &S,
state: &S,
request: &Request<Method, Value>,
) -> Result<serde_json::Value>
where
Expand All @@ -28,7 +28,7 @@ where
.ok_or_else(|| anyhow!("Missing request parameters"))?,
)?;

adapter
state
.update_local_price(
&params.account.parse::<solana_sdk::pubkey::Pubkey>()?,
params.price,
Expand Down
6 changes: 3 additions & 3 deletions src/agent/solana.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub mod network {
config: Config,
network: Network,
logger: Logger,
adapter: Arc<State>,
state: Arc<State>,
) -> Result<Vec<JoinHandle<()>>> {
// Publisher permissions updates between oracle and exporter
let (publisher_permissions_tx, publisher_permissions_rx) = watch::channel(<_>::default());
Expand All @@ -90,7 +90,7 @@ pub mod network {
publisher_permissions_tx,
KeyStore::new(config.key_store.clone(), &logger)?,
logger.clone(),
adapter.clone(),
state.clone(),
);

// Spawn the Exporter
Expand All @@ -102,7 +102,7 @@ pub mod network {
publisher_permissions_rx,
KeyStore::new(config.key_store.clone(), &logger)?,
logger,
adapter,
state,
)?;

jhs.extend(exporter_jhs);
Expand Down
Loading

0 comments on commit 87a38be

Please sign in to comment.