Skip to content

Commit

Permalink
Merge pull request #757 from soerenreichardt/configurable-logging
Browse files Browse the repository at this point in the history
Configure progress logging on GDS object
  • Loading branch information
soerenreichardt authored Sep 30, 2024
2 parents 3b5e45c + c944019 commit 1f8fdec
Show file tree
Hide file tree
Showing 11 changed files with 122 additions and 15 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* Add `ttl` parameter to `GdsSessions.get_or_create` to control if and when an unused session will be automatically deleted.
* Add concurrency control for remote write-back procedures using the `concurrency` parameter.
* Add progress logging for remote write-back when using GDS Sessions.
* Add a `show_progress` flag to the `GraphDataScience` and `AuraGraphDataScience` classes to disable displaying progress bars when running procedures.

## Bug fixes

Expand Down
17 changes: 16 additions & 1 deletion graphdatascience/graph_data_science.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
arrow_disable_server_verification: bool = True,
arrow_tls_root_certs: Optional[bytes] = None,
bookmarks: Optional[Any] = None,
show_progress: bool = True,
):
"""
Construct a new GraphDataScience object.
Expand Down Expand Up @@ -63,14 +64,16 @@ def __init__(
GDS Arrow Flight server.
bookmarks : Optional[Any], default None
The Neo4j bookmarks to require a certain state before the next query gets executed.
show_progress : bool, default True
A flag to indicate whether to show progress bars for running procedures.
"""
if aura_ds:
GraphDataScience._validate_endpoint(endpoint)

if isinstance(endpoint, QueryRunner):
self._query_runner = endpoint
else:
self._query_runner = Neo4jQueryRunner.create(endpoint, auth, aura_ds, database, bookmarks)
self._query_runner = Neo4jQueryRunner.create(endpoint, auth, aura_ds, database, bookmarks, show_progress)

self._server_version = self._query_runner.server_version()

Expand All @@ -86,6 +89,7 @@ def __init__(
None if arrow is True else arrow,
)

self._query_runner.set_show_progress(show_progress)
super().__init__(self._query_runner, namespace="gds", server_version=self._server_version)

@property
Expand Down Expand Up @@ -129,6 +133,17 @@ def set_bookmarks(self, bookmarks: Any) -> None:
"""
self._query_runner.set_bookmarks(bookmarks)

def set_show_progress(self, show_progress: bool) -> None:
    """
    Enable or disable progress display for procedures run through this object.

    The value is handed on to the underlying query runner.

    Parameters
    ----------
    show_progress : bool
        Whether to show progress for procedures.
    """
    runner = self._query_runner
    runner.set_show_progress(show_progress)

def database(self) -> Optional[str]:
"""
Get the database which queries are run against.
Expand Down
3 changes: 3 additions & 0 deletions graphdatascience/query_runner/arrow_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ def close(self) -> None:
def fallback_query_runner(self) -> QueryRunner:
return self._fallback_query_runner

def set_show_progress(self, show_progress: bool) -> None:
    """Pass the progress-display flag on to the fallback query runner."""
    fallback = self._fallback_query_runner
    fallback.set_show_progress(show_progress)

def create_graph_constructor(
self, graph_name: str, concurrency: int, undirected_relationship_types: Optional[List[str]]
) -> GraphConstructor:
Expand Down
15 changes: 13 additions & 2 deletions graphdatascience/query_runner/neo4j_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def create(
aura_ds: bool = False,
database: Optional[str] = None,
bookmarks: Optional[Any] = None,
show_progress: bool = True,
) -> Neo4jQueryRunner:
if isinstance(endpoint, str):
config: Dict[str, Any] = {"user_agent": f"neo4j-graphdatascience-v{__version__}"}
Expand All @@ -51,7 +52,9 @@ def create(
)

elif isinstance(endpoint, neo4j.Driver):
query_runner = Neo4jQueryRunner(endpoint, auto_close=False, bookmarks=bookmarks, database=database)
query_runner = Neo4jQueryRunner(
endpoint, auto_close=False, bookmarks=bookmarks, database=database, show_progress=show_progress
)

else:
raise ValueError(f"Invalid endpoint type: {type(endpoint)}")
Expand Down Expand Up @@ -80,6 +83,7 @@ def __init__(
database: Optional[str] = neo4j.DEFAULT_DATABASE,
auto_close: bool = False,
bookmarks: Optional[Any] = None,
show_progress: bool = True,
):
self._driver = driver
self._config = config
Expand All @@ -89,6 +93,7 @@ def __init__(
self._bookmarks = bookmarks
self._last_bookmarks: Optional[Any] = None
self._server_version = None
self._show_progress = show_progress
self._progress_logger = QueryProgressLogger(
self.__run_cypher_simplified_for_query_progress_logger, self.server_version
)
Expand Down Expand Up @@ -175,12 +180,15 @@ def call_procedure(
def run_cypher_query() -> DataFrame:
return self.run_cypher(query, params, database, custom_error)

if logging:
if self._resolve_show_progress(logging):
job_id = self._progress_logger.extract_or_create_job_id(params)
return self._progress_logger.run_with_progress_logging(run_cypher_query, job_id, database)
else:
return run_cypher_query()

def _resolve_show_progress(self, show_progress: bool) -> bool:
return self._show_progress and show_progress

def server_version(self) -> ServerVersion:
if self._server_version:
return self._server_version
Expand Down Expand Up @@ -256,6 +264,9 @@ def create_graph_constructor(
self, graph_name, concurrency, undirected_relationship_types, self.server_version()
)

def set_show_progress(self, show_progress: bool) -> None:
    """Store the runner-wide flag that ``_resolve_show_progress`` combines with each per-call request."""
    self._show_progress = show_progress

@staticmethod
def handle_driver_exception(session: neo4j.Session, e: Exception) -> None:
reg_gds_hit = re.search(
Expand Down
4 changes: 4 additions & 0 deletions graphdatascience/query_runner/query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,9 @@ def bookmarks(self) -> Optional[Any]:
def last_bookmarks(self) -> Optional[Any]:
pass

@abstractmethod
def set_show_progress(self, show_progress: bool) -> None:
    """Set whether progress should be shown; concrete query runners provide the behavior."""

def set_server_version(self, _: ServerVersion) -> None:
pass
15 changes: 12 additions & 3 deletions graphdatascience/query_runner/session_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,22 @@ class SessionQueryRunner(QueryRunner):

@staticmethod
def create(
gds_query_runner: QueryRunner, db_query_runner: QueryRunner, arrow_client: GdsArrowClient
gds_query_runner: QueryRunner, db_query_runner: QueryRunner, arrow_client: GdsArrowClient, show_progress: bool
) -> SessionQueryRunner:
return SessionQueryRunner(gds_query_runner, db_query_runner, arrow_client)
return SessionQueryRunner(gds_query_runner, db_query_runner, arrow_client, show_progress)

def __init__(
self,
gds_query_runner: QueryRunner,
db_query_runner: QueryRunner,
arrow_client: GdsArrowClient,
show_progress: bool,
):
self._gds_query_runner = gds_query_runner
self._db_query_runner = db_query_runner
self._gds_arrow_client = arrow_client
self._resolved_protocol_version = ProtocolVersionResolver(db_query_runner).resolve()
self._show_progress = show_progress
self._progress_logger = QueryProgressLogger(
lambda query, database: self._gds_query_runner.run_cypher(query=query, database=database),
self._gds_query_runner.server_version,
Expand Down Expand Up @@ -112,6 +114,10 @@ def create_graph_constructor(
) -> GraphConstructor:
return self._gds_query_runner.create_graph_constructor(graph_name, concurrency, undirected_relationship_types)

def set_show_progress(self, show_progress: bool) -> None:
    """
    Update the progress flag on this runner and on the wrapped GDS query runner.

    The local flag is stored first; the same value is then passed on to
    ``_gds_query_runner``.
    """
    self._show_progress = show_progress
    gds_runner = self._gds_query_runner
    gds_runner.set_show_progress(show_progress)

def close(self) -> None:
self._gds_arrow_client.close()
self._gds_query_runner.close()
Expand Down Expand Up @@ -184,7 +190,7 @@ def _remote_write_back(
def run_write_back():
return write_protocol.run_write_back(self._db_query_runner, write_back_params, yields)

if logging:
if self._resolve_show_progress(logging):
database_write_result = self._progress_logger.run_with_progress_logging(run_write_back, job_id, database)
else:
database_write_result = run_write_back()
Expand All @@ -203,6 +209,9 @@ def run_write_back():

return gds_write_result

def _resolve_show_progress(self, show_progress: bool) -> bool:
return self._show_progress and show_progress

def _inject_arrow_config(self, params: Dict[str, Any]) -> None:
host, port = self._gds_arrow_client.connection_info()
token = self._gds_arrow_client.request_token()
Expand Down
19 changes: 15 additions & 4 deletions graphdatascience/session/aura_graph_data_science.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def create(
arrow_disable_server_verification: bool = False,
arrow_tls_root_certs: Optional[bytes] = None,
bookmarks: Optional[Any] = None,
show_progress: bool = True,
):
# we need to explicitly set this as the default value is None
# database in the session is always neo4j
Expand All @@ -43,6 +44,7 @@ def create(
auth=gds_session_connection_info.auth(),
aura_ds=True,
database="neo4j",
show_progress=show_progress,
)

arrow_info = ArrowInfo.create(session_bolt_query_runner)
Expand All @@ -66,14 +68,12 @@ def create(
)

db_bolt_query_runner = Neo4jQueryRunner.create(
db_connection_info.uri,
db_connection_info.auth(),
aura_ds=True,
db_connection_info.uri, db_connection_info.auth(), aura_ds=True, show_progress=False
)
db_bolt_query_runner.set_bookmarks(bookmarks)

session_query_runner = SessionQueryRunner.create(
session_arrow_query_runner, db_bolt_query_runner, session_arrow_client
session_arrow_query_runner, db_bolt_query_runner, session_arrow_client, show_progress
)

gds_version = session_bolt_query_runner.server_version()
Expand Down Expand Up @@ -159,6 +159,17 @@ def set_bookmarks(self, bookmarks: Any) -> None:
"""
self._query_runner.set_bookmarks(bookmarks)

def set_show_progress(self, show_progress: bool) -> None:
    """
    Enable or disable progress display for procedures run through this object.

    The value is handed on to the underlying query runner.

    Parameters
    ----------
    show_progress : bool
        Whether to show progress for procedures.
    """
    runner = self._query_runner
    runner.set_show_progress(show_progress)

def database(self) -> Optional[str]:
"""
Get the database which cypher queries are run against.
Expand Down
48 changes: 48 additions & 0 deletions graphdatascience/tests/integration/test_progress_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from neo4j import Driver
from pandas import DataFrame

from graphdatascience import ServerVersion
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
from graphdatascience.query_runner.session_query_runner import SessionQueryRunner
from graphdatascience.tests.unit.conftest import CollectingQueryRunner
from graphdatascience.tests.unit.test_session_query_runner import FakeArrowClient


def test_disabled_progress_logging(neo4j_driver: Driver):
    """With progress disabled on the runner, no per-call value can turn it on."""
    runner = Neo4jQueryRunner.create(neo4j_driver, show_progress=False)
    for requested in (True, False):
        assert runner._resolve_show_progress(requested) is False


def test_enabled_progress_logging(neo4j_driver: Driver):
    """With progress enabled on the runner, the per-call value decides the outcome."""
    runner = Neo4jQueryRunner.create(neo4j_driver, show_progress=True)
    for requested in (True, False):
        assert runner._resolve_show_progress(requested) is requested


def test_disabled_progress_logging_session():
    """
    A session runner created with ``show_progress=False`` never resolves to showing progress.

    The runner is built from the unit-test doubles ``CollectingQueryRunner`` and
    ``FakeArrowClient`` and the body never touches a driver, so the
    ``neo4j_driver`` fixture is not requested.
    """
    version = ServerVersion(2, 7, 0)
    db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}]))
    gds_query_runner = CollectingQueryRunner(version)
    query_runner = SessionQueryRunner.create(
        gds_query_runner,
        db_query_runner,
        FakeArrowClient(),  # type: ignore
        show_progress=False,
    )
    assert query_runner._resolve_show_progress(True) is False
    assert query_runner._resolve_show_progress(False) is False


def test_enabled_progress_logging_session():
    """
    A session runner created with ``show_progress=True`` lets the per-call value decide.

    The runner is built from the unit-test doubles ``CollectingQueryRunner`` and
    ``FakeArrowClient`` and the body never touches a driver, so the
    ``neo4j_driver`` fixture is not requested.
    """
    version = ServerVersion(2, 7, 0)
    db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}]))
    gds_query_runner = CollectingQueryRunner(version)
    query_runner = SessionQueryRunner.create(
        gds_query_runner,
        db_query_runner,
        FakeArrowClient(),  # type: ignore
        show_progress=True,
    )
    assert query_runner._resolve_show_progress(True) is True
    assert query_runner._resolve_show_progress(False) is False
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_remote_write_back_node_similarity(gds_with_cloud_setup: AuraGraphDataSc
G, writeRelationshipType="SIMILAR", writeProperty="score", similarityCutoff=0
)

assert result["relationshipsWritten"] == 4
assert result["relationshipsWritten"] == 2


@pytest.mark.cloud_architecture
Expand Down
3 changes: 3 additions & 0 deletions graphdatascience/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def bookmarks(self) -> Optional[Any]:
def last_bookmarks(self) -> Optional[Any]:
return None

def set_show_progress(self, show_progress: bool) -> None:
    """Accept the progress flag and do nothing with it."""

def create_graph_constructor(
self, graph_name: str, concurrency: int, undirected_relationship_types: Optional[List[str]]
) -> GraphConstructor:
Expand Down
10 changes: 6 additions & 4 deletions graphdatascience/tests/unit/test_session_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_extracts_parameters_projection_v1() -> None:
db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}]))
gds_query_runner = CollectingQueryRunner(version)
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient()) # type: ignore
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient(), True) # type: ignore

qr.call_procedure(
endpoint="gds.arrow.project",
Expand Down Expand Up @@ -68,6 +68,7 @@ def test_extracts_parameters_projection_v2() -> None:
gds_query_runner,
db_query_runner,
FakeArrowClient(), # type: ignore
True,
)

qr.call_procedure(
Expand Down Expand Up @@ -112,7 +113,7 @@ def test_extracts_parameters_algo_write_v1() -> None:
db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}]))
gds_query_runner = CollectingQueryRunner(version)
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient()) # type: ignore
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient(), True) # type: ignore

qr.call_procedure(endpoint="gds.degree.write", params=CallParameters(graph_name="g", config={"jobId": "my-job"}))

Expand Down Expand Up @@ -141,6 +142,7 @@ def test_extracts_parameters_algo_write_v2() -> None:
gds_query_runner,
db_query_runner,
FakeArrowClient(), # type: ignore
True,
)

qr.call_procedure(
Expand Down Expand Up @@ -169,7 +171,7 @@ def test_arrow_and_write_configuration() -> None:
db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}]))
gds_query_runner = CollectingQueryRunner(version)
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient()) # type: ignore
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient(), True) # type: ignore

qr.call_procedure(
endpoint="gds.degree.write",
Expand Down Expand Up @@ -206,7 +208,7 @@ def test_arrow_and_write_configuration_graph_write() -> None:
db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}]))
gds_query_runner = CollectingQueryRunner(version)
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient()) # type: ignore
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient(), True) # type: ignore

qr.call_procedure(
endpoint="gds.graph.nodeProperties.write",
Expand Down

0 comments on commit 1f8fdec

Please sign in to comment.