Skip to content

Commit

Permalink
Merge pull request #172 from firstbatchxyz/erhant/serper-jina-checks
Browse files Browse the repository at this point in the history
feat: network specific Jina & Serper checks
  • Loading branch information
erhant authored Jan 22, 2025
2 parents 160c802 + cedb534 commit 4d9f2ab
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 2 deletions.
16 changes: 16 additions & 0 deletions compute/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ impl DriaComputeNodeConfig {

/// Asserts that the configured listen address is free.
/// Throws an error if the address is already in use.
#[inline]
pub fn assert_address_not_in_use(&self) -> Result<()> {
if address_in_use(&self.p2p_listen_addr) {
return Err(eyre!(
Expand All @@ -137,6 +138,21 @@ impl DriaComputeNodeConfig {

Ok(())
}

/// Checks the network specific configurations.
pub fn check_network_specific(&self) -> Result<()> {
// if network is `pro`, we require Jina and Serper to be present.
if self.network_type == DriaNetworkType::Pro {
if !self.workflows.jina.has_api_key() {
return Err(eyre!("Jina is required for the Pro network."));
}
if !self.workflows.serper.has_api_key() {
return Err(eyre!("Serper is required for the Pro network."));
}
}

Ok(())
}
}

#[cfg(test)]
Expand Down
4 changes: 3 additions & 1 deletion compute/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ async fn main() -> Result<()> {
tokio::time::sleep(tokio::time::Duration::from_secs(duration_secs)).await;

log::warn!("Exiting due to DKN_EXIT_TIMEOUT.");

cancellation_token.cancel();
} else if let Err(err) = wait_for_termination(cancellation_token.clone()).await {
// if there is no timeout, we wait for termination signals here
Expand Down Expand Up @@ -86,6 +85,9 @@ async fn main() -> Result<()> {
}?;
log::warn!("Using models: {:#?}", config.workflows.models);

// check network-specific configurations
config.check_network_specific()?;

// create the node
let batch_size = config.batch_size;
let (mut node, p2p, worker_batch, worker_single) = DriaComputeNode::new(config).await?;
Expand Down
2 changes: 1 addition & 1 deletion p2p/src/network.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use libp2p::{Multiaddr, PeerId};

/// Network type.
#[derive(Default, Debug, Clone, Copy)]
#[derive(Default, Debug, Clone, Copy, PartialEq)]
pub enum DriaNetworkType {
#[default]
Community,
Expand Down
2 changes: 2 additions & 0 deletions utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub fn split_csv_line(input: &str) -> Vec<String> {

/// Reads an environment variable and trims whitespace and `"` from both ends.
/// If the trimmed value is empty, returns `None`.
#[inline]
pub fn safe_read_env(var: Result<String, std::env::VarError>) -> Option<String> {
var.map(|s| s.trim_matches('"').trim().to_string())
.ok()
Expand All @@ -43,6 +44,7 @@ where
/// Returns the current time in nanoseconds since the Unix epoch.
///
/// If a `SystemTimeError` occurs, will return 0 just to keep things running.
#[inline]
pub fn get_current_time_nanos() -> u128 {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
Expand Down
6 changes: 6 additions & 0 deletions workflows/src/apis/jina.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ impl JinaConfig {
}
}

/// Checks if the API key is present.
#[inline]
pub fn has_api_key(&self) -> bool {
self.api_key.is_some()
}

/// Sets the API key for Jina.
pub fn with_api_key(mut self, api_key: String) -> Self {
self.api_key = Some(api_key);
Expand Down
6 changes: 6 additions & 0 deletions workflows/src/apis/serper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ impl SerperConfig {
}
}

/// Checks if the API key is present.
#[inline]
pub fn has_api_key(&self) -> bool {
self.api_key.is_some()
}

/// Sets the API key for Serper.
pub fn with_api_key(mut self, api_key: String) -> Self {
self.api_key = Some(api_key);
Expand Down

0 comments on commit 4d9f2ab

Please sign in to comment.