Skip to content

Commit

Permalink
PYTHON-5101 Convert test.test_server_selection_in_window to async (#2119
Browse files Browse the repository at this point in the history
)

Co-authored-by: Noah Stapp <[email protected]>
  • Loading branch information
sleepyStick and NoahStapp authored Feb 11, 2025
1 parent 1a7239c commit 13fa361
Show file tree
Hide file tree
Showing 8 changed files with 515 additions and 92 deletions.
2 changes: 1 addition & 1 deletion test/asynchronous/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ def allowable_errors(self, op):
return errors


async def create_test(scenario_def, test, name):
def create_test(scenario_def, test, name):
@async_client_context.require_test_commands
async def run_scenario(self):
await self.run_scenario(scenario_def, test)
Expand Down
179 changes: 179 additions & 0 deletions test/asynchronous/test_server_selection_in_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test the topology module's Server Selection Spec implementation."""
from __future__ import annotations

import asyncio
import os
import threading
from pathlib import Path
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.asynchronous.helpers import ConcurrentRunner
from test.asynchronous.utils_selection_tests import create_topology
from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator
from test.utils import (
CMAPListener,
OvertCommandListener,
async_get_pool,
async_wait_until,
)

from pymongo.common import clean_node
from pymongo.monitoring import ConnectionReadyEvent
from pymongo.operations import _Op
from pymongo.read_preferences import ReadPreference

_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window")
else:
TEST_PATH = os.path.join(
Path(__file__).resolve().parent.parent, "server_selection", "in_window"
)


class TestAllScenarios(unittest.IsolatedAsyncioTestCase):
async def run_scenario(self, scenario_def):
topology = await create_topology(scenario_def)

# Update mock operation_count state:
for mock in scenario_def["mocked_topology_state"]:
address = clean_node(mock["address"])
server = topology.get_server_by_address(address)
server.pool.operation_count = mock["operation_count"]

pref = ReadPreference.NEAREST
counts = {address: 0 for address in topology._description.server_descriptions()}

# Number of times to repeat server selection
iterations = scenario_def["iterations"]
for _ in range(iterations):
server = await topology.select_server(pref, _Op.TEST, server_selection_timeout=0)
counts[server.description.address] += 1

# Verify expected_frequencies
outcome = scenario_def["outcome"]
tolerance = outcome["tolerance"]
expected_frequencies = outcome["expected_frequencies"]
for host_str, freq in expected_frequencies.items():
address = clean_node(host_str)
actual_freq = float(counts[address]) / iterations
if freq == 0:
# Should be exactly 0.
self.assertEqual(actual_freq, 0)
else:
# Should be within 'tolerance'.
self.assertAlmostEqual(actual_freq, freq, delta=tolerance)


def create_test(scenario_def, test, name):
async def run_scenario(self):
await self.run_scenario(scenario_def)

return run_scenario


class CustomSpecTestCreator(AsyncSpecTestCreator):
def tests(self, scenario_def):
"""Extract the tests from a spec file.
Server selection in_window tests do not have a 'tests' field.
The whole file represents a single test case.
"""
return [scenario_def]


CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests()


class FinderTask(ConcurrentRunner):
def __init__(self, collection, iterations):
super().__init__()
self.daemon = True
self.collection = collection
self.iterations = iterations
self.passed = False

async def run(self):
for _ in range(self.iterations):
await self.collection.find_one({})
self.passed = True


class TestProse(AsyncIntegrationTest):
async def frequencies(self, client, listener, n_finds=10):
coll = client.test.test
N_TASKS = 10
tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)]
for task in tasks:
await task.start()
for task in tasks:
await task.join()
for task in tasks:
self.assertTrue(task.passed)

events = listener.started_events
self.assertEqual(len(events), n_finds * N_TASKS)
nodes = client.nodes
self.assertEqual(len(nodes), 2)
freqs = {address: 0.0 for address in nodes}
for event in events:
freqs[event.connection_id] += 1
for address in freqs:
freqs[address] = freqs[address] / float(len(events))
return freqs

@async_client_context.require_failCommand_appName
@async_client_context.require_multiple_mongoses
async def test_load_balancing(self):
listener = OvertCommandListener()
cmap_listener = CMAPListener()
# PYTHON-2584: Use a large localThresholdMS to avoid the impact of
# varying RTTs.
client = await self.async_rs_client(
async_client_context.mongos_seeds(),
appName="loadBalancingTest",
event_listeners=[listener, cmap_listener],
localThresholdMS=30000,
minPoolSize=10,
)
await async_wait_until(lambda: len(client.nodes) == 2, "discover both nodes")
# Wait for both pools to be populated.
await cmap_listener.async_wait_for_event(ConnectionReadyEvent, 20)
# Delay find commands on only one mongos.
delay_finds = {
"configureFailPoint": "failCommand",
"mode": {"times": 10000},
"data": {
"failCommands": ["find"],
"blockConnection": True,
"blockTimeMS": 500,
"appName": "loadBalancingTest",
},
}
async with self.fail_point(delay_finds):
nodes = async_client_context.client.nodes
self.assertEqual(len(nodes), 1)
delayed_server = next(iter(nodes))
freqs = await self.frequencies(client, listener)
self.assertLessEqual(freqs[delayed_server], 0.25)
listener.reset()
freqs = await self.frequencies(client, listener, n_finds=150)
self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15)


if __name__ == "__main__":
unittest.main()
203 changes: 203 additions & 0 deletions test/asynchronous/utils_selection_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for testing Server Selection and Max Staleness."""
from __future__ import annotations

import datetime
import os
import sys
from test.asynchronous import AsyncPyMongoTestCase

sys.path[0:0] = [""]

from test import unittest
from test.pymongo_mocks import DummyMonitor
from test.utils import AsyncMockPool, parse_read_preference
from test.utils_selection_tests_shared import (
get_addresses,
get_topology_type_name,
make_server_description,
)

from bson import json_util
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology
from pymongo.common import HEARTBEAT_FREQUENCY
from pymongo.errors import AutoReconnect, ConfigurationError
from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector

_IS_SYNC = False


def get_topology_settings_dict(**kwargs):
settings = {
"monitor_class": DummyMonitor,
"heartbeat_frequency": HEARTBEAT_FREQUENCY,
"pool_class": AsyncMockPool,
}
settings.update(kwargs)
return settings


async def create_topology(scenario_def, **kwargs):
# Initialize topologies.
if "heartbeatFrequencyMS" in scenario_def:
frequency = int(scenario_def["heartbeatFrequencyMS"]) / 1000.0
else:
frequency = HEARTBEAT_FREQUENCY

seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"])

topology_type = get_topology_type_name(scenario_def)
if topology_type == "LoadBalanced":
kwargs.setdefault("load_balanced", True)
# Force topology description to ReplicaSet
elif topology_type in ["ReplicaSetNoPrimary", "ReplicaSetWithPrimary"]:
kwargs.setdefault("replica_set_name", "rs")
settings = get_topology_settings_dict(heartbeat_frequency=frequency, seeds=seeds, **kwargs)

# "Eligible servers" is defined in the server selection spec as
# the set of servers matching both the ReadPreference's mode
# and tag sets.
topology = Topology(TopologySettings(**settings))
await topology.open()

# Update topologies with server descriptions.
for server in scenario_def["topology_description"]["servers"]:
server_description = make_server_description(server, hosts)
await topology.on_change(server_description)

# Assert that descriptions match
assert (
scenario_def["topology_description"]["type"] == topology.description.topology_type_name
), topology.description.topology_type_name

return topology


def create_test(scenario_def):
async def run_scenario(self):
_, hosts = get_addresses(scenario_def["topology_description"]["servers"])
# "Eligible servers" is defined in the server selection spec as
# the set of servers matching both the ReadPreference's mode
# and tag sets.
top_latency = await create_topology(scenario_def)

# "In latency window" is defined in the server selection
# spec as the subset of suitable_servers that falls within the
# allowable latency window.
top_suitable = await create_topology(scenario_def, local_threshold_ms=1000000)

# Create server selector.
if scenario_def.get("operation") == "write":
pref = writable_server_selector
else:
# Make first letter lowercase to match read_pref's modes.
pref_def = scenario_def["read_preference"]
if scenario_def.get("error"):
with self.assertRaises((ConfigurationError, ValueError)):
# Error can be raised when making Read Pref or selecting.
pref = parse_read_preference(pref_def)
await top_latency.select_server(pref, _Op.TEST)
return

pref = parse_read_preference(pref_def)

# Select servers.
if not scenario_def.get("suitable_servers"):
with self.assertRaises(AutoReconnect):
await top_suitable.select_server(pref, _Op.TEST, server_selection_timeout=0)

return

if not scenario_def["in_latency_window"]:
with self.assertRaises(AutoReconnect):
await top_latency.select_server(pref, _Op.TEST, server_selection_timeout=0)

return

actual_suitable_s = await top_suitable.select_servers(
pref, _Op.TEST, server_selection_timeout=0
)
actual_latency_s = await top_latency.select_servers(
pref, _Op.TEST, server_selection_timeout=0
)

expected_suitable_servers = {}
for server in scenario_def["suitable_servers"]:
server_description = make_server_description(server, hosts)
expected_suitable_servers[server["address"]] = server_description

actual_suitable_servers = {}
for s in actual_suitable_s:
actual_suitable_servers[
"%s:%d" % (s.description.address[0], s.description.address[1])
] = s.description

self.assertEqual(len(actual_suitable_servers), len(expected_suitable_servers))
for k, actual in actual_suitable_servers.items():
expected = expected_suitable_servers[k]
self.assertEqual(expected.address, actual.address)
self.assertEqual(expected.server_type, actual.server_type)
self.assertEqual(expected.round_trip_time, actual.round_trip_time)
self.assertEqual(expected.tags, actual.tags)
self.assertEqual(expected.all_hosts, actual.all_hosts)

expected_latency_servers = {}
for server in scenario_def["in_latency_window"]:
server_description = make_server_description(server, hosts)
expected_latency_servers[server["address"]] = server_description

actual_latency_servers = {}
for s in actual_latency_s:
actual_latency_servers[
"%s:%d" % (s.description.address[0], s.description.address[1])
] = s.description

self.assertEqual(len(actual_latency_servers), len(expected_latency_servers))
for k, actual in actual_latency_servers.items():
expected = expected_latency_servers[k]
self.assertEqual(expected.address, actual.address)
self.assertEqual(expected.server_type, actual.server_type)
self.assertEqual(expected.round_trip_time, actual.round_trip_time)
self.assertEqual(expected.tags, actual.tags)
self.assertEqual(expected.all_hosts, actual.all_hosts)

return run_scenario


def create_selection_tests(test_dir):
class TestAllScenarios(AsyncPyMongoTestCase):
pass

for dirpath, _, filenames in os.walk(test_dir):
dirname = os.path.split(dirpath)
dirname = os.path.split(dirname[-2])[-1] + "_" + dirname[-1]

for filename in filenames:
if os.path.splitext(filename)[1] != ".json":
continue
with open(os.path.join(dirpath, filename)) as scenario_stream:
scenario_def = json_util.loads(scenario_stream.read())

# Construct test from scenario.
new_test = create_test(scenario_def)
test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}"

new_test.__name__ = test_name
setattr(TestAllScenarios, new_test.__name__, new_test)

return TestAllScenarios
Loading

0 comments on commit 13fa361

Please sign in to comment.