diff --git a/changelog.md b/changelog.md index 909ed96ef..07d82d016 100644 --- a/changelog.md +++ b/changelog.md @@ -13,6 +13,7 @@ * Add `ttl` parameter to `GdsSessions.get_or_create` to control if and when an unused session will be automatically deleted. * Add concurrency control for remote write-back procedures using the `concurrency` parameter. * Add progress logging for remote write-back when using GDS Sessions. +* Added a flag to GraphDataScience and AuraGraphDataScience classes to disable displaying progress bars when running procedures. ## Bug fixes diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index 6fb4edb16..aa0aeae4d 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -34,6 +34,7 @@ def __init__( arrow_disable_server_verification: bool = True, arrow_tls_root_certs: Optional[bytes] = None, bookmarks: Optional[Any] = None, + show_progress: bool = True, ): """ Construct a new GraphDataScience object. @@ -63,6 +64,8 @@ def __init__( GDS Arrow Flight server. bookmarks : Optional[Any], default None The Neo4j bookmarks to require a certain state before the next query gets executed. + show_progress : bool, default True + A flag to indicate whether to show progress bars for running procedures. """ if aura_ds: GraphDataScience._validate_endpoint(endpoint) @@ -70,7 +73,7 @@ def __init__( if isinstance(endpoint, QueryRunner): self._query_runner = endpoint else: - self._query_runner = Neo4jQueryRunner.create(endpoint, auth, aura_ds, database, bookmarks) + self._query_runner = Neo4jQueryRunner.create(endpoint, auth, aura_ds, database, bookmarks, show_progress) self._server_version = self._query_runner.server_version() @@ -86,6 +89,7 @@ def __init__( None if arrow is True else arrow, ) + self._query_runner.set_show_progress(show_progress) super().__init__(self._query_runner, namespace="gds", server_version=self._server_version) @property @@ -129,6 +133,17 @@ def set_bookmarks(self, bookmarks: Any) -> None: """ self._query_runner.set_bookmarks(bookmarks) + def set_show_progress(self, show_progress: bool) -> None: + """ + Set whether to show progress for running procedures. + + Parameters + ---------- + show_progress: bool + Whether to show progress for procedures. + """ + self._query_runner.set_show_progress(show_progress) + def database(self) -> Optional[str]: """ Get the database which queries are run against. diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py index 45e412f98..7c3edb7a6 100644 --- a/graphdatascience/query_runner/arrow_query_runner.py +++ b/graphdatascience/query_runner/arrow_query_runner.py @@ -236,6 +236,9 @@ def close(self) -> None: def fallback_query_runner(self) -> QueryRunner: return self._fallback_query_runner + def set_show_progress(self, show_progress: bool) -> None: + self._fallback_query_runner.set_show_progress(show_progress) + def create_graph_constructor( self, graph_name: str, concurrency: int, undirected_relationship_types: Optional[List[str]] ) -> GraphConstructor: diff --git a/graphdatascience/query_runner/neo4j_query_runner.py b/graphdatascience/query_runner/neo4j_query_runner.py index 57894a945..2a35debdc 100644 --- a/graphdatascience/query_runner/neo4j_query_runner.py +++ b/graphdatascience/query_runner/neo4j_query_runner.py @@ -33,6 +33,7 @@ def create( aura_ds: bool = False, database: Optional[str] = None, bookmarks: Optional[Any] = None, + show_progress: bool = True, ) -> Neo4jQueryRunner: if isinstance(endpoint, str): config: Dict[str, Any] = {"user_agent": f"neo4j-graphdatascience-v{__version__}"} @@ -51,7 +52,9 @@ def create( ) elif isinstance(endpoint, neo4j.Driver): - query_runner = Neo4jQueryRunner(endpoint, auto_close=False, bookmarks=bookmarks, database=database) + query_runner = Neo4jQueryRunner( + endpoint, auto_close=False, bookmarks=bookmarks, database=database, show_progress=show_progress + ) else: raise ValueError(f"Invalid endpoint type: {type(endpoint)}") @@ -80,6 +83,7 @@ def __init__( database: Optional[str] = neo4j.DEFAULT_DATABASE, auto_close: bool = False, bookmarks: Optional[Any] = None, + show_progress: bool = True, ): self._driver = driver self._config = config @@ -89,6 +93,7 @@ def __init__( self._bookmarks = bookmarks self._last_bookmarks: Optional[Any] = None self._server_version = None + self._show_progress = show_progress self._progress_logger = QueryProgressLogger( self.__run_cypher_simplified_for_query_progress_logger, self.server_version ) @@ -175,12 +180,15 @@ def call_procedure( def run_cypher_query() -> DataFrame: return self.run_cypher(query, params, database, custom_error) - if logging: + if self._resolve_show_progress(logging): job_id = self._progress_logger.extract_or_create_job_id(params) return self._progress_logger.run_with_progress_logging(run_cypher_query, job_id, database) else: return run_cypher_query() + def _resolve_show_progress(self, show_progress: bool) -> bool: + return self._show_progress and show_progress + def server_version(self) -> ServerVersion: if self._server_version: return self._server_version @@ -256,6 +264,9 @@ def create_graph_constructor( self, graph_name, concurrency, undirected_relationship_types, self.server_version() ) + def set_show_progress(self, show_progress: bool) -> None: + self._show_progress = show_progress + @staticmethod def handle_driver_exception(session: neo4j.Session, e: Exception) -> None: reg_gds_hit = re.search( diff --git a/graphdatascience/query_runner/query_runner.py b/graphdatascience/query_runner/query_runner.py index c3b6c105b..2ce25a6e2 100644 --- a/graphdatascience/query_runner/query_runner.py +++ b/graphdatascience/query_runner/query_runner.py @@ -76,5 +76,9 @@ def bookmarks(self) -> Optional[Any]: def last_bookmarks(self) -> Optional[Any]: pass + @abstractmethod + def set_show_progress(self, show_progress: bool) -> None: + pass + def set_server_version(self, _: ServerVersion) -> None: pass diff --git a/graphdatascience/query_runner/session_query_runner.py b/graphdatascience/query_runner/session_query_runner.py index b9bf31c23..2a7766583 100644 --- a/graphdatascience/query_runner/session_query_runner.py +++ b/graphdatascience/query_runner/session_query_runner.py @@ -24,20 +24,22 @@ class SessionQueryRunner(QueryRunner): @staticmethod def create( - gds_query_runner: QueryRunner, db_query_runner: QueryRunner, arrow_client: GdsArrowClient + gds_query_runner: QueryRunner, db_query_runner: QueryRunner, arrow_client: GdsArrowClient, show_progress: bool ) -> SessionQueryRunner: - return SessionQueryRunner(gds_query_runner, db_query_runner, arrow_client) + return SessionQueryRunner(gds_query_runner, db_query_runner, arrow_client, show_progress) def __init__( self, gds_query_runner: QueryRunner, db_query_runner: QueryRunner, arrow_client: GdsArrowClient, + show_progress: bool, ): self._gds_query_runner = gds_query_runner self._db_query_runner = db_query_runner self._gds_arrow_client = arrow_client self._resolved_protocol_version = ProtocolVersionResolver(db_query_runner).resolve() + self._show_progress = show_progress self._progress_logger = QueryProgressLogger( lambda query, database: self._gds_query_runner.run_cypher(query=query, database=database), self._gds_query_runner.server_version, @@ -112,6 +114,10 @@ def create_graph_constructor( ) -> GraphConstructor: return self._gds_query_runner.create_graph_constructor(graph_name, concurrency, undirected_relationship_types) + def set_show_progress(self, show_progress: bool) -> None: + self._show_progress = show_progress + self._gds_query_runner.set_show_progress(show_progress) + def close(self) -> None: self._gds_arrow_client.close() self._gds_query_runner.close() @@ -184,7 +190,7 @@ def _remote_write_back( def run_write_back(): return write_protocol.run_write_back(self._db_query_runner, write_back_params, yields) - if logging: + if self._resolve_show_progress(logging): database_write_result = self._progress_logger.run_with_progress_logging(run_write_back, job_id, database) else: database_write_result = run_write_back() @@ -203,6 +209,9 @@ def run_write_back(): return gds_write_result + def _resolve_show_progress(self, show_progress: bool) -> bool: + return self._show_progress and show_progress + def _inject_arrow_config(self, params: Dict[str, Any]) -> None: host, port = self._gds_arrow_client.connection_info() token = self._gds_arrow_client.request_token() diff --git a/graphdatascience/session/aura_graph_data_science.py b/graphdatascience/session/aura_graph_data_science.py index 26e8b674b..b75ecaf16 100644 --- a/graphdatascience/session/aura_graph_data_science.py +++ b/graphdatascience/session/aura_graph_data_science.py @@ -35,6 +35,7 @@ def create( arrow_disable_server_verification: bool = False, arrow_tls_root_certs: Optional[bytes] = None, bookmarks: Optional[Any] = None, + show_progress: bool = True, ): # we need to explicitly set this as the default value is None # database in the session is always neo4j @@ -43,6 +44,7 @@ def create( auth=gds_session_connection_info.auth(), aura_ds=True, database="neo4j", + show_progress=show_progress, ) arrow_info = ArrowInfo.create(session_bolt_query_runner) @@ -66,14 +68,12 @@ def create( ) db_bolt_query_runner = Neo4jQueryRunner.create( - db_connection_info.uri, - db_connection_info.auth(), - aura_ds=True, + db_connection_info.uri, db_connection_info.auth(), aura_ds=True, show_progress=False ) db_bolt_query_runner.set_bookmarks(bookmarks) session_query_runner = SessionQueryRunner.create( - session_arrow_query_runner, db_bolt_query_runner, session_arrow_client + session_arrow_query_runner, db_bolt_query_runner, session_arrow_client, show_progress ) gds_version = session_bolt_query_runner.server_version() @@ -159,6 +159,17 @@ def set_bookmarks(self, bookmarks: Any) -> None: """ self._query_runner.set_bookmarks(bookmarks) + def set_show_progress(self, show_progress: bool) -> None: + """ + Set whether to show progress for running procedures. + + Parameters + ---------- + show_progress: bool + Whether to show progress for procedures. + """ + self._query_runner.set_show_progress(show_progress) + def database(self) -> Optional[str]: """ Get the database which cypher queries are run against. diff --git a/graphdatascience/tests/integration/test_progress_logging.py b/graphdatascience/tests/integration/test_progress_logging.py new file mode 100644 index 000000000..4276da4a5 --- /dev/null +++ b/graphdatascience/tests/integration/test_progress_logging.py @@ -0,0 +1,48 @@ +from neo4j import Driver +from pandas import DataFrame + +from graphdatascience import ServerVersion +from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner +from graphdatascience.query_runner.session_query_runner import SessionQueryRunner +from graphdatascience.tests.unit.conftest import CollectingQueryRunner +from graphdatascience.tests.unit.test_session_query_runner import FakeArrowClient + + +def test_disabled_progress_logging(neo4j_driver: Driver): + query_runner = Neo4jQueryRunner.create(neo4j_driver, show_progress=False) + assert query_runner._resolve_show_progress(True) is False + assert query_runner._resolve_show_progress(False) is False + + +def test_enabled_progress_logging(neo4j_driver: Driver): + query_runner = Neo4jQueryRunner.create(neo4j_driver, show_progress=True) + assert query_runner._resolve_show_progress(True) is True + assert query_runner._resolve_show_progress(False) is False + + +def test_disabled_progress_logging_session(neo4j_driver: Driver): + version = ServerVersion(2, 7, 0) + db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}])) + gds_query_runner = CollectingQueryRunner(version) + query_runner = SessionQueryRunner.create( + gds_query_runner, + db_query_runner, + FakeArrowClient(), # type: ignore + show_progress=False, + ) + assert query_runner._resolve_show_progress(True) is False + assert query_runner._resolve_show_progress(False) is False + + +def test_enabled_progress_logging_session(neo4j_driver: Driver): + version = ServerVersion(2, 7, 0) + db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}])) + gds_query_runner = CollectingQueryRunner(version) + query_runner = SessionQueryRunner.create( + gds_query_runner, + db_query_runner, + FakeArrowClient(), # type: ignore + show_progress=True, + ) + assert query_runner._resolve_show_progress(True) is True + assert query_runner._resolve_show_progress(False) is False diff --git a/graphdatascience/tests/integration/test_remote_graph_ops.py b/graphdatascience/tests/integration/test_remote_graph_ops.py index c133121ec..d6984a30b 100644 --- a/graphdatascience/tests/integration/test_remote_graph_ops.py +++ b/graphdatascience/tests/integration/test_remote_graph_ops.py @@ -73,7 +73,7 @@ def test_remote_write_back_node_similarity(gds_with_cloud_setup: AuraGraphDataSc G, writeRelationshipType="SIMILAR", writeProperty="score", similarityCutoff=0 ) - assert result["relationshipsWritten"] == 4 + assert result["relationshipsWritten"] == 2 @pytest.mark.cloud_architecture diff --git a/graphdatascience/tests/unit/conftest.py b/graphdatascience/tests/unit/conftest.py index 784368e25..6009f5206 100644 --- a/graphdatascience/tests/unit/conftest.py +++ b/graphdatascience/tests/unit/conftest.py @@ -115,6 +115,9 @@ def bookmarks(self) -> Optional[Any]: def last_bookmarks(self) -> Optional[Any]: return None + def set_show_progress(self, show_progress: bool) -> None: + pass + def create_graph_constructor( self, graph_name: str, concurrency: int, undirected_relationship_types: Optional[List[str]] ) -> GraphConstructor: diff --git a/graphdatascience/tests/unit/test_session_query_runner.py b/graphdatascience/tests/unit/test_session_query_runner.py index 8e4148f0d..e3a542cd6 100644 --- a/graphdatascience/tests/unit/test_session_query_runner.py +++ b/graphdatascience/tests/unit/test_session_query_runner.py @@ -21,7 +21,7 @@ def test_extracts_parameters_projection_v1() -> None: db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}])) gds_query_runner = CollectingQueryRunner(version) gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}])) - qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient()) # type: ignore + qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient(), True) # type: ignore qr.call_procedure( endpoint="gds.arrow.project", @@ -68,6 +68,7 @@ def test_extracts_parameters_projection_v2() -> None: gds_query_runner, db_query_runner, FakeArrowClient(), # type: ignore + True, ) qr.call_procedure( @@ -112,7 +113,7 @@ def test_extracts_parameters_algo_write_v1() -> None: db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}])) gds_query_runner = CollectingQueryRunner(version) gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}])) - qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient()) # type: ignore + qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient(), True) # type: ignore qr.call_procedure(endpoint="gds.degree.write", params=CallParameters(graph_name="g", config={"jobId": "my-job"})) @@ -141,6 +142,7 @@ def test_extracts_parameters_algo_write_v2() -> None: gds_query_runner, db_query_runner, FakeArrowClient(), # type: ignore + True, ) qr.call_procedure( @@ -169,7 +171,7 @@ def test_arrow_and_write_configuration() -> None: db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}])) gds_query_runner = CollectingQueryRunner(version) gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}])) - qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient()) # type: ignore + qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient(), True) # type: ignore qr.call_procedure( endpoint="gds.degree.write", @@ -206,7 +208,7 @@ def test_arrow_and_write_configuration_graph_write() -> None: db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}])) gds_query_runner = CollectingQueryRunner(version) gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}])) - qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient()) # type: ignore + qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient(), True) # type: ignore qr.call_procedure( endpoint="gds.graph.nodeProperties.write",