Skip to content

Commit

Permalink
fix GIL deadlock when setting up a custom embedding function in python (
Browse files Browse the repository at this point in the history
#1929)

* fix GIL deadlock when setting up a custom embedding function in python

* turn tracing back on for vector operations

* fix docstrings

* add comment with some clarification

* revert change on pometry-storage-private commit
  • Loading branch information
ricopinazo authored Jan 22, 2025
1 parent cd862c9 commit 1b27e27
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 6 deletions.
6 changes: 5 additions & 1 deletion raphtory-graphql/src/python/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,11 @@ impl PyGraphServer {
let mut server = PyRunningGraphServer::new(join_handle, sender, port)?;
if let Some(_server_handler) = &server.server_handler {
let url = format!("http://localhost:{port}");
match PyRunningGraphServer::wait_for_server_online(&url, timeout_ms) {
// We need to release the GIL here: otherwise the server deadlocks when it tries to call a
// Python function as the embedding function, and wait_for_server_online never returns.
let result =
py.allow_threads(|| PyRunningGraphServer::wait_for_server_online(&url, timeout_ms));
match result {
Ok(_) => return Ok(server),
Err(e) => {
PyRunningGraphServer::stop_server(&mut server, py)?;
Expand Down
10 changes: 6 additions & 4 deletions raphtory-graphql/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,11 @@ impl GraphServer {

/// Start the server on the port `port` and return a handle to it.
pub async fn start_with_port(self, port: u16) -> IoResult<RunningGraphServer> {
self.data.vectorise_all_graphs_that_are_not().await?;

let work_dir = self.data.work_dir.clone();
// set up opentelemetry first of all
let config = self.config.clone();
let filter = config.logging.get_log_env();
let tracer_name = config.tracing.otlp_tracing_service_name.clone();
let tp = config.tracing.tracer_provider();

// Create the base registry
let registry = Registry::default().with(filter).with(
fmt::layer().pretty().with_span_events(FmtSpan::NONE), //(FULL, NEW, ENTER, EXIT, CLOSE)
Expand All @@ -202,6 +199,10 @@ impl GraphServer {
registry.try_init().ok();
}
};

self.data.vectorise_all_graphs_that_are_not().await?;
let work_dir = self.data.work_dir.clone();

// it is important that this runs after algorithms have been pushed to PLUGIN_ALGOS static variable
let app: CorsEndpoint<CookieJarManagerEndpoint<Route>> = self
.generate_endpoint(tp.clone().map(|tp| tp.tracer(tracer_name)))
Expand Down Expand Up @@ -266,6 +267,7 @@ impl GraphServer {
}

/// A Raphtory server handler
#[derive(Debug)]
pub struct RunningGraphServer {
signal_sender: Sender<()>,
server_result: JoinHandle<IoResult<()>>,
Expand Down
12 changes: 12 additions & 0 deletions raphtory/src/python/packages/vectors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,18 @@ impl PyVectorisedGraph {
self.0.empty_selection()
}

/// Return every graph level document held by this vectorised graph
///
/// Returns:
///     list[Document]: the graph level documents, each converted to its Python representation
pub fn get_graph_documents(&self, py: Python) -> PyResult<Vec<PyDocument>> {
    // Fetch the documents from the wrapped vectorised graph first, then convert each one;
    // collecting into `PyResult` stops at the first conversion that fails.
    let graph_docs = self.0.get_graph_documents();
    graph_docs
        .into_iter()
        .map(|graph_doc| into_py_document(graph_doc, &self.0, py))
        .collect()
}

/// Search the top scoring documents according to `query` with no more than `limit` documents
///
/// Args:
Expand Down
1 change: 1 addition & 0 deletions raphtory/src/vectors/embedding_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::{

pub type CacheStore = HashMap<u64, Embedding>;

#[derive(Debug)]
pub struct EmbeddingCache {
cache: RwLock<CacheStore>, // TODO: double check that we really need a RwLock !!
path: PathBuf,
Expand Down
2 changes: 1 addition & 1 deletion raphtory/src/vectors/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ impl From<Prop> for Value {
}
}

#[derive(Clone, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct DocumentTemplate {
pub graph_template: Option<String>,
pub node_template: Option<String>,
Expand Down
10 changes: 10 additions & 0 deletions raphtory/src/vectors/vectorised_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use super::{
similarity_search_utils::score_document_groups_by_highest,
vector_selection::VectorSelection,
vectorisable::{vectorise_edge, vectorise_graph, vectorise_node},
Document,
};

pub struct VectorisedGraph<G: StaticGraphViewOps> {
Expand Down Expand Up @@ -151,6 +152,15 @@ impl<G: StaticGraphViewOps> VectorisedGraph<G> {
VectorSelection::new(self.clone())
}

/// Return all the graph level documents.
///
/// Each entry read from `graph_documents` is regenerated against `source_graph`
/// and `template` before being returned.
pub fn get_graph_documents(&self) -> Vec<Document> {
    // Keep the guard returned by `read()` in a local so the borrow is explicit.
    let stored = self.graph_documents.read();
    stored
        .iter()
        .map(|doc_ref| doc_ref.regenerate(&self.source_graph, &self.template))
        .collect()
}

/// Search the top scoring documents according to `query` with no more than `limit` documents
///
/// # Arguments
Expand Down

0 comments on commit 1b27e27

Please sign in to comment.