diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index ad1bc70aba..5c18e034c3 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -21,11 +21,11 @@ import logging import time import weakref -from typing import TYPE_CHECKING, Any, Mapping, Optional, cast +from typing import TYPE_CHECKING, Any, Optional from pymongo import common, periodic_executor from pymongo._csot import MovingMinimum -from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled +from pymongo.errors import NetworkTimeout, _OperationCancelled from pymongo.hello import Hello from pymongo.lock import _async_create_lock from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage @@ -250,13 +250,7 @@ async def _check_server(self) -> ServerDescription: self._conn_id = None start = time.monotonic() try: - try: - return await self._check_once() - except (OperationFailure, NotPrimaryError) as exc: - # Update max cluster time even when hello fails. - details = cast(Mapping[str, Any], exc.details) - await self._topology.receive_cluster_time(details.get("$clusterTime")) - raise + return await self._check_once() except asyncio.CancelledError: raise except ReferenceError: @@ -355,7 +349,6 @@ async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float] Can raise ConnectionFailure or OperationFailure. """ - cluster_time = self._topology.max_cluster_time() start = time.monotonic() if conn.more_to_come: # Read the next streaming hello (MongoDB 4.4+). @@ -365,13 +358,13 @@ async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float] ): # Initiate streaming hello (MongoDB 4.4+). response = await conn._hello( - cluster_time, + None, self._server_description.topology_version, self._settings.heartbeat_frequency, ) else: # New connection handshake or polling hello (MongoDB <4.4). - response = await conn._hello(cluster_time, None, None) + response = await conn._hello(None, None, None) duration = _monotonic_duration(start) return response, duration diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index d17aead120..c7a5580eca 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -207,6 +207,10 @@ async def command( ) response_doc = unpacked_docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time if client: await client._process_response(response_doc, session) if check: diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index bf2f2b4946..46276ba4dc 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -310,6 +310,8 @@ def __init__( self.connect_rtt = 0.0 self._client_id = pool._client_id self.creation_time = time.monotonic() + # For gossiping $clusterTime from the connection handshake to the client. + self._cluster_time = None def set_conn_timeout(self, timeout: Optional[float]) -> None: """Cache last timeout to avoid duplicate calls to conn.settimeout.""" @@ -1314,6 +1316,9 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A conn.close_conn(ConnectionClosedReason.ERROR) raise + if handler: + await handler.client._topology.receive_cluster_time(conn._cluster_time) + return conn @contextlib.asynccontextmanager diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 6d67710a7e..b5bab5b8dd 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -493,7 +493,8 @@ async def _process_change( self._description = new_td await self._update_servers() - self._receive_cluster_time_no_lock(server_description.cluster_time) + # TODO: Verify that app errors update the $clusterTime. + # self._receive_cluster_time_no_lock(server_description.cluster_time) if self._publish_tp and not suppress_event: assert self._events is not None diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index df4130d4ab..f6e5c5d1e4 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -21,11 +21,11 @@ import logging import time import weakref -from typing import TYPE_CHECKING, Any, Mapping, Optional, cast +from typing import TYPE_CHECKING, Any, Optional from pymongo import common, periodic_executor from pymongo._csot import MovingMinimum -from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled +from pymongo.errors import NetworkTimeout, _OperationCancelled from pymongo.hello import Hello from pymongo.lock import _create_lock from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage @@ -250,13 +250,7 @@ def _check_server(self) -> ServerDescription: self._conn_id = None start = time.monotonic() try: - try: - return self._check_once() - except (OperationFailure, NotPrimaryError) as exc: - # Update max cluster time even when hello fails. - details = cast(Mapping[str, Any], exc.details) - self._topology.receive_cluster_time(details.get("$clusterTime")) - raise + return self._check_once() except asyncio.CancelledError: raise except ReferenceError: @@ -355,7 +349,6 @@ def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]: Can raise ConnectionFailure or OperationFailure. """ - cluster_time = self._topology.max_cluster_time() start = time.monotonic() if conn.more_to_come: # Read the next streaming hello (MongoDB 4.4+). @@ -365,13 +358,13 @@ def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]: ): # Initiate streaming hello (MongoDB 4.4+). response = conn._hello( - cluster_time, + None, self._server_description.topology_version, self._settings.heartbeat_frequency, ) else: # New connection handshake or polling hello (MongoDB <4.4). - response = conn._hello(cluster_time, None, None) + response = conn._hello(None, None, None) duration = _monotonic_duration(start) return response, duration diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index 7206dca735..543b069bfc 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -207,6 +207,10 @@ def command( ) response_doc = unpacked_docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time if client: client._process_response(response_doc, session) if check: diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 05f930d480..5154efecc0 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -310,6 +310,8 @@ def __init__( self.connect_rtt = 0.0 self._client_id = pool._client_id self.creation_time = time.monotonic() + # For gossiping $clusterTime from the connection handshake to the client. + self._cluster_time = None def set_conn_timeout(self, timeout: Optional[float]) -> None: """Cache last timeout to avoid duplicate calls to conn.settimeout.""" @@ -1308,6 +1310,9 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect conn.close_conn(ConnectionClosedReason.ERROR) raise + if handler: + handler.client._topology.receive_cluster_time(conn._cluster_time) + return conn @contextlib.contextmanager diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index b03269ae43..2f31ea2e8f 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -493,7 +493,8 @@ def _process_change( self._description = new_td self._update_servers() - self._receive_cluster_time_no_lock(server_description.cluster_time) + # TODO: Verify that app errors update the $clusterTime. + # self._receive_cluster_time_no_lock(server_description.cluster_time) if self._publish_tp and not suppress_event: assert self._events is not None diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 42bc253b56..8c2fe1ff07 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -33,9 +33,11 @@ async_client_context, unittest, ) +from test.asynchronous.helpers import client_knobs from test.utils import ( EventListener, ExceptionCatchingThread, + HeartbeatEventListener, OvertCommandListener, async_wait_until, ) @@ -1133,12 +1135,10 @@ async def asyncSetUp(self): if "$clusterTime" not in (await async_client_context.hello): raise SkipTest("$clusterTime not supported") + # Sessions prose test: 3) $clusterTime in commands async def test_cluster_time(self): listener = SessionTestListener() - # Prevent heartbeats from updating $clusterTime between operations. - client = await self.async_rs_or_single_client( - event_listeners=[listener], heartbeatFrequencyMS=999999 - ) + client = await self.async_rs_or_single_client(event_listeners=[listener]) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). await collection.insert_many([{} for _ in range(10)]) @@ -1217,6 +1217,40 @@ async def aggregate(): f"{f.__name__} sent wrong $clusterTime with {event.command_name}", ) + # Sessions prose test: 20) Drivers do not gossip `$clusterTime` on SDAM commands + async def test_cluster_time_not_used_by_sdam(self): + heartbeat_listener = HeartbeatEventListener() + cmd_listener = OvertCommandListener() + with client_knobs(min_heartbeat_interval=0.01): + c1 = await self.async_single_client( + event_listeners=[heartbeat_listener, cmd_listener], heartbeatFrequencyMS=10 + ) + cluster_time = (await c1.admin.command({"ping": 1}))["$clusterTime"] + self.assertEqual(c1._topology.max_cluster_time(), cluster_time) + + # Advance the server's $clusterTime by performing an insert via another client. + await self.db.test.insert_one({"advance": "$clusterTime"}) + # Wait until the client C1 processes the next pair of SDAM heartbeat started + succeeded events. + heartbeat_listener.reset() + + async def next_heartbeat(): + events = heartbeat_listener.events + for i in range(len(events) - 1): + if isinstance(events[i], monitoring.ServerHeartbeatStartedEvent): + if isinstance(events[i + 1], monitoring.ServerHeartbeatSucceededEvent): + return True + return False + + await async_wait_until( + next_heartbeat, "never found pair of heartbeat started + succeeded events" + ) + # Assert that C1's max $clusterTime is still the same and has not been updated by SDAM. + cmd_listener.reset() + await c1.admin.command({"ping": 1}) + started = cmd_listener.started_events[0] + self.assertEqual(started.command_name, "ping") + self.assertEqual(started.command["$clusterTime"], cluster_time) + if __name__ == "__main__": unittest.main() diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index ce7a52f1a0..70dcfc5b48 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -244,7 +244,7 @@ class TestClusterTimeComparison(unittest.TestCase): def test_cluster_time_comparison(self): t = create_mock_topology("mongodb://host") - def send_cluster_time(time, inc, should_update): + def send_cluster_time(time, inc): old = t.max_cluster_time() new = {"clusterTime": Timestamp(time, inc)} got_hello( @@ -259,16 +259,14 @@ def send_cluster_time(time, inc, should_update): ) actual = t.max_cluster_time() - if should_update: - self.assertEqual(actual, new) - else: - self.assertEqual(actual, old) - - send_cluster_time(0, 1, True) - send_cluster_time(2, 2, True) - send_cluster_time(2, 1, False) - send_cluster_time(1, 3, False) - send_cluster_time(2, 3, True) + # We never update $clusterTime from monitoring connections. + self.assertEqual(actual, old) + + send_cluster_time(0, 1) + send_cluster_time(2, 2) + send_cluster_time(2, 1) + send_cluster_time(1, 3) + send_cluster_time(2, 3) class TestIgnoreStaleErrors(IntegrationTest): diff --git a/test/test_session.py b/test/test_session.py index 634efa11c0..c978818835 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -33,9 +33,11 @@ client_context, unittest, ) +from test.helpers import client_knobs from test.utils import ( EventListener, ExceptionCatchingThread, + HeartbeatEventListener, OvertCommandListener, wait_until, ) @@ -1119,10 +1121,10 @@ def setUp(self): if "$clusterTime" not in (client_context.hello): raise SkipTest("$clusterTime not supported") + # Sessions prose test: 3) $clusterTime in commands def test_cluster_time(self): listener = SessionTestListener() - # Prevent heartbeats from updating $clusterTime between operations. - client = self.rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999) + client = self.rs_or_single_client(event_listeners=[listener]) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). collection.insert_many([{} for _ in range(10)]) @@ -1201,6 +1203,38 @@ def aggregate(): f"{f.__name__} sent wrong $clusterTime with {event.command_name}", ) + # Sessions prose test: 20) Drivers do not gossip `$clusterTime` on SDAM commands + def test_cluster_time_not_used_by_sdam(self): + heartbeat_listener = HeartbeatEventListener() + cmd_listener = OvertCommandListener() + with client_knobs(min_heartbeat_interval=0.01): + c1 = self.single_client( + event_listeners=[heartbeat_listener, cmd_listener], heartbeatFrequencyMS=10 + ) + cluster_time = (c1.admin.command({"ping": 1}))["$clusterTime"] + self.assertEqual(c1._topology.max_cluster_time(), cluster_time) + + # Advance the server's $clusterTime by performing an insert via another client. + self.db.test.insert_one({"advance": "$clusterTime"}) + # Wait until the client C1 processes the next pair of SDAM heartbeat started + succeeded events. + heartbeat_listener.reset() + + def next_heartbeat(): + events = heartbeat_listener.events + for i in range(len(events) - 1): + if isinstance(events[i], monitoring.ServerHeartbeatStartedEvent): + if isinstance(events[i + 1], monitoring.ServerHeartbeatSucceededEvent): + return True + return False + + wait_until(next_heartbeat, "never found pair of heartbeat started + succeeded events") + # Assert that C1's max $clusterTime is still the same and has not been updated by SDAM. + cmd_listener.reset() + c1.admin.command({"ping": 1}) + started = cmd_listener.started_events[0] + self.assertEqual(started.command_name, "ping") + self.assertEqual(started.command["$clusterTime"], cluster_time) + if __name__ == "__main__": unittest.main()