diff --git a/mobly/base_suite.py b/mobly/base_suite.py index 188a9e0d..32442082 100644 --- a/mobly/base_suite.py +++ b/mobly/base_suite.py @@ -14,6 +14,8 @@ import abc +import logging + class BaseSuite(abc.ABC): """Class used to define a Mobly suite. @@ -34,11 +36,20 @@ class BaseSuite(abc.ABC): def __init__(self, runner, config): self._runner = runner self._config = config.copy() + self._test_selector = None @property def user_params(self): return self._config.user_params + def set_test_selector(self, test_selector): + """Sets test selector. + + Don't override or call this method. This should only be used by the Mobly + framework. + """ + self._test_selector = test_selector + def get_suite_info(self): """User defined extra suite information to be recorded in test summary. @@ -61,10 +72,20 @@ def add_test_class(self, clazz, config=None, tests=None, name_suffix=None): of test cases; all matched test cases will be executed; an error is raised if no match is found. If not specified, all tests in the class are executed. + CLI argument `tests` takes precedence over this argument. name_suffix: string, suffix to append to the class name for reporting. This is used for differentiating the same class executed with different parameters in a suite. """ + if self._test_selector: + cls_name = clazz.__name__ + if cls_name not in self._test_selector: + logging.info( + 'Skipping test class %s due to CLI argument `tests`.', cls_name + ) + return + tests = self._test_selector[cls_name] + if not config: config = self._config self._runner.add_test_class(config, clazz, tests, name_suffix) diff --git a/mobly/suite_runner.py b/mobly/suite_runner.py index 9a1f0e2c..8ff90bf1 100644 --- a/mobly/suite_runner.py +++ b/mobly/suite_runner.py @@ -193,21 +193,63 @@ def _parse_cli_args(argv): return parser.parse_known_args(argv)[0] -def _find_suite_class(): - """Finds the test suite class in the current module. +def _find_suite_classes_in_module(module): + """Finds all test suite classes in the given module. - Walk through module members and find the subclass of BaseSuite. Only - one subclass is allowed in a module. + Walk through module members and find all classes that is a subclass of + BaseSuite. + + Args: + module: types.ModuleType, the module object to find test suite classes. Returns: - The test suite class in the test module. + A list of test suite classes. """ test_suites = [] - main_module_members = sys.modules['__main__'] - for _, module_member in main_module_members.__dict__.items(): + for _, module_member in module.__dict__.items(): if inspect.isclass(module_member): if issubclass(module_member, base_suite.BaseSuite): test_suites.append(module_member) + return test_suites + + +def _find_suite_class(): + """Finds the test suite class. + + First search for test suite classes in the __main__ module. If no test suite + class is found, search in the module that is calling + `suite_runner.run_suite_class`. + + Walk through module members and find the subclass of BaseSuite. Only + one subclass is allowed. + + Returns: + The test suite class in the test module. + """ + # Try to find test suites in __main__ module first. + test_suites = _find_suite_classes_in_module(sys.modules['__main__']) + + # Try to find test suites in the module of the caller of `run_suite_class`. + if len(test_suites) == 0: + logging.debug( + 'No suite class found in the __main__ module, trying to find it in the ' + 'module of the caller of suite_runner.run_suite_class method.' + ) + stacks = inspect.stack() + if len(stacks) < 2: + logging.debug( + 'Failed to get the caller stack of run_suite_class. Got stacks: %s', + stacks, + ) + else: + run_suite_class_caller_frame_info = inspect.stack()[2] + caller_frame = run_suite_class_caller_frame_info.frame + module = inspect.getmodule(caller_frame) + if module is None: + logging.debug('Failed to find module for frame %s', caller_frame) + else: + test_suites = _find_suite_classes_in_module(module) + if len(test_suites) != 1: logging.error( 'Expected 1 test class per file, found %s.', @@ -273,7 +315,8 @@ def run_suite_class(argv=None): log_dir=config.log_path, testbed_name=config.testbed_name ) suite = suite_class(runner, config) - + test_selector = _parse_raw_test_selector(cli_args.tests) + suite.set_test_selector(test_selector) suite_record = SuiteInfoRecord(test_suite_class=suite_class.__name__) console_level = logging.DEBUG if cli_args.verbose else logging.INFO @@ -357,8 +400,8 @@ def compute_selected_tests(test_classes, selected_tests): that class are selected. Args: - test_classes: list of strings, names of all the classes that are part - of a suite. + test_classes: list of `type[base_test.BaseTestClass]`, all the test classes + that are part of a suite. selected_tests: list of strings, list of tests to execute. If empty, all classes `test_classes` are selected. E.g. @@ -396,6 +439,50 @@ def compute_selected_tests(test_classes, selected_tests): # The user is selecting some tests to run. Parse the selectors. # Dict from test_name class name to list of tests to execute (or None for all # tests). + test_class_name_to_tests = _parse_raw_test_selector(selected_tests) + + # Now compute the tests to run for each test class. + # Dict from test class name to class instance. + class_name_to_class = {cls.__name__: cls for cls in test_classes} + for test_class_name, tests in test_class_name_to_tests.items(): + test_class = class_name_to_class.get(test_class_name) + if not test_class: + raise Error('Unknown test_class name %s' % test_class_name) + class_to_tests[test_class] = tests + + return class_to_tests + + +def _parse_raw_test_selector(selected_tests): + """Parses test selector from CLI arguments. + + This function transforms a list of selector strings (such as FooTest or + FooTest.test_method_a) to a dict where keys are test_name classes, and + values are lists of selected tests in those classes. None means all tests in + that class are selected. + + Args: + selected_tests: list of strings, list of tests to execute. E.g. + + .. code-block:: python + + ['FooTest', 'BarTest', 'BazTest.test_method_a', 'BazTest.test_method_b'] + + Returns: + A dict. Keys are test class names, values are lists of test names within + class. E.g. the example in `selected_tests` would translate to: + + .. code-block:: python + { + 'FooTest': None, + 'BarTest': None, + 'BazTest': ['test_method_a', 'test_method_b'], + } + + This returns None if `selected_tests` is None. + """ + if selected_tests is None: + return None test_class_name_to_tests = collections.OrderedDict() for test_name in selected_tests: if '.' in test_name: # Has a test method @@ -412,13 +499,4 @@ def compute_selected_tests(test_classes, selected_tests): else: # No test method; run all tests in this class. test_class_name_to_tests[test_name] = None - # Now transform class names to class objects. - # Dict from test_name class name to instance. - class_name_to_class = {cls.__name__: cls for cls in test_classes} - for test_class_name, tests in test_class_name_to_tests.items(): - test_class = class_name_to_class.get(test_class_name) - if not test_class: - raise Error('Unknown test_name class %s' % test_class_name) - class_to_tests[test_class] = tests - - return class_to_tests + return test_class_name_to_tests diff --git a/tests/lib/integration_test_suite.py b/tests/lib/integration_test_suite.py new file mode 100644 index 00000000..dd95ab04 --- /dev/null +++ b/tests/lib/integration_test_suite.py @@ -0,0 +1,31 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mobly import base_suite +from mobly import suite_runner +from tests.lib import integration_test + + +class IntegrationTestSuite(base_suite.BaseSuite): + + def setup_suite(self, config): + self.add_test_class(integration_test.IntegrationTest) + + +def main(): + suite_runner.run_suite_class() + + +if __name__ == "__main__": + main() diff --git a/tests/mobly/base_suite_test.py b/tests/mobly/base_suite_test.py new file mode 100644 index 00000000..43f92133 --- /dev/null +++ b/tests/mobly/base_suite_test.py @@ -0,0 +1,139 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import os +import shutil +import sys +import tempfile +import unittest +from unittest import mock + +from mobly import base_suite +from mobly import base_test +from mobly import suite_runner +from mobly import test_runner +from mobly import config_parser +from tests.lib import integration2_test +from tests.lib import integration_test + + +class FakeTest1(base_test.BaseTestClass): + + def test_a(self): + pass + + def test_b(self): + pass + + def test_c(self): + pass + + +class FakeTest2(base_test.BaseTestClass): + + def test_2(self): + pass + + +class FakeTestSuite(base_suite.BaseSuite): + + def setup_suite(self, config): + self.add_test_class(FakeTest1, config) + self.add_test_class(FakeTest2, config) + + +class FakeTestSuiteWithFilteredTests(base_suite.BaseSuite): + + def setup_suite(self, config): + self.add_test_class(FakeTest1, config, ['test_a', 'test_b']) + self.add_test_class(FakeTest2, config, ['test_2']) + + +class BaseSuiteTest(unittest.TestCase): + + def setUp(self): + super().setUp() + self.mock_config = mock.Mock(autospec=config_parser.TestRunConfig) + self.mock_test_runner = mock.Mock(autospec=test_runner.TestRunner) + + def test_setup_suite(self): + suite = FakeTestSuite(self.mock_test_runner, self.mock_config) + suite.set_test_selector(None) + + suite.setup_suite(self.mock_config) + + self.mock_test_runner.add_test_class.assert_has_calls( + [ + mock.call(self.mock_config, FakeTest1, mock.ANY, mock.ANY), + mock.call(self.mock_config, FakeTest2, mock.ANY, mock.ANY), + ], + ) + + def test_setup_suite_with_test_selector(self): + suite = FakeTestSuite(self.mock_test_runner, self.mock_config) + test_selector = { + 'FakeTest1': ['test_a', 'test_b'], + 'FakeTest2': None, + } + + suite.set_test_selector(test_selector) + suite.setup_suite(self.mock_config) + + self.mock_test_runner.add_test_class.assert_has_calls( + [ + mock.call( + self.mock_config, FakeTest1, ['test_a', 'test_b'], mock.ANY + ), + mock.call(self.mock_config, FakeTest2, None, mock.ANY), + ], + ) + + def test_setup_suite_test_selector_takes_precedence(self): + suite = FakeTestSuiteWithFilteredTests( + self.mock_test_runner, self.mock_config + ) + test_selector = { + 'FakeTest1': ['test_a', 'test_c'], + 'FakeTest2': None, + } + + suite.set_test_selector(test_selector) + suite.setup_suite(self.mock_config) + + self.mock_test_runner.add_test_class.assert_has_calls( + [ + mock.call( + self.mock_config, FakeTest1, ['test_a', 'test_c'], mock.ANY + ), + mock.call(self.mock_config, FakeTest2, None, mock.ANY), + ], + ) + + def test_setup_suite_with_skip_test_class(self): + suite = FakeTestSuite(self.mock_test_runner, self.mock_config) + test_selector = {'FakeTest1': None} + + suite.set_test_selector(test_selector) + suite.setup_suite(self.mock_config) + + self.mock_test_runner.add_test_class.assert_has_calls( + [ + mock.call(self.mock_config, FakeTest1, None, mock.ANY), + ], + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/mobly/suite_runner_test.py b/tests/mobly/suite_runner_test.py index 5035d640..deb0b20d 100755 --- a/tests/mobly/suite_runner_test.py +++ b/tests/mobly/suite_runner_test.py @@ -26,15 +26,20 @@ from mobly import base_test from mobly import records from mobly import suite_runner +from mobly import test_runner from mobly import utils from tests.lib import integration2_test from tests.lib import integration_test +from tests.lib import integration_test_suite import yaml class FakeTest1(base_test.BaseTestClass): pass + def test_a(self): + pass + class SuiteRunnerTest(unittest.TestCase): @@ -148,6 +153,10 @@ def test_run_suite_class(self, mock_exit): class FakeTestSuite(base_suite.BaseSuite): + def set_test_selector(self, test_selector): + mock_called.set_test_selector(test_selector) + super().set_test_selector(test_selector) + def setup_suite(self, config): mock_called.setup_suite() super().setup_suite(config) @@ -168,12 +177,103 @@ def teardown_suite(self): mock_called.setup_suite.assert_called_once_with() mock_called.teardown_suite.assert_called_once_with() mock_exit.assert_not_called() + mock_called.set_test_selector.assert_called_once_with(None) + + @mock.patch('sys.exit') + @mock.patch.object(records, 'TestSummaryWriter', autospec=True) + @mock.patch.object(suite_runner, '_find_suite_class', autospec=True) + @mock.patch.object(test_runner, 'TestRunner') + def test_run_suite_class_with_test_selection_by_class( + self, mock_test_runner_class, mock_find_suite_class, *_ + ): + mock_test_runner = mock_test_runner_class.return_value + mock_test_runner.results.is_all_pass = True + tmp_file_path = self._gen_tmp_config_file() + mock_cli_args = [ + 'test_binary', + f'--config={tmp_file_path}', + '--tests=FakeTest1', + ] + mock_called = mock.MagicMock() + + class FakeTestSuite(base_suite.BaseSuite): + + def set_test_selector(self, test_selector): + mock_called.set_test_selector(test_selector) + super().set_test_selector(test_selector) + + def setup_suite(self, config): + self.add_test_class(FakeTest1) + + mock_find_suite_class.return_value = FakeTestSuite + + with mock.patch.object(sys, 'argv', new=mock_cli_args): + suite_runner.run_suite_class() + + mock_called.set_test_selector.assert_called_once_with( + {'FakeTest1': None}, + ) + + @mock.patch('sys.exit') + @mock.patch.object(records, 'TestSummaryWriter', autospec=True) + @mock.patch.object(suite_runner, '_find_suite_class', autospec=True) + @mock.patch.object(test_runner, 'TestRunner') + def test_run_suite_class_with_test_selection_by_method( + self, mock_test_runner_class, mock_find_suite_class, *_ + ): + mock_test_runner = mock_test_runner_class.return_value + mock_test_runner.results.is_all_pass = True + tmp_file_path = self._gen_tmp_config_file() + mock_cli_args = [ + 'test_binary', + f'--config={tmp_file_path}', + '--tests=FakeTest1.test_a', + ] + mock_called = mock.MagicMock() + + class FakeTestSuite(base_suite.BaseSuite): + + def set_test_selector(self, test_selector): + mock_called.set_test_selector(test_selector) + super().set_test_selector(test_selector) + + def setup_suite(self, config): + self.add_test_class(FakeTest1) + + mock_find_suite_class.return_value = FakeTestSuite + + with mock.patch.object(sys, 'argv', new=mock_cli_args): + suite_runner.run_suite_class() + + mock_called.set_test_selector.assert_called_once_with( + {'FakeTest1': ['test_a']}, + ) + + @mock.patch('sys.exit') + @mock.patch.object(records, 'TestSummaryWriter', autospec=True) + @mock.patch.object(test_runner, 'TestRunner') + @mock.patch.object( + integration_test_suite.IntegrationTestSuite, 'setup_suite', autospec=True + ) + def test_run_suite_class_finds_suite_class_when_not_in_main_module( + self, mock_setup_suite, mock_test_runner_class, *_ + ): + mock_test_runner = mock_test_runner_class.return_value + mock_test_runner.results.is_all_pass = True + mock_test_runner + tmp_file_path = self._gen_tmp_config_file() + mock_cli_args = ['test_binary', f'--config={tmp_file_path}'] + + with mock.patch.object(sys, 'argv', new=mock_cli_args): + integration_test_suite.main() + + mock_setup_suite.assert_called_once() @mock.patch('sys.exit') @mock.patch.object( utils, 'get_current_epoch_time', return_value=1733143236278 ) - def test_run_suite_class_records_suite_class_name(self, mock_time, _): + def test_run_suite_class_records_suite_info(self, mock_time, _): tmp_file_path = self._gen_tmp_config_file() mock_cli_args = ['test_binary', f'--config={tmp_file_path}'] expected_record = suite_runner.SuiteInfoRecord( @@ -230,9 +330,8 @@ def test_print_test_names_with_exception(self): def test_convert_suite_info_record_to_dict(self): suite_class_name = 'FakeTestSuite' suite_version = '1.2.3' - record = suite_runner.SuiteInfoRecord( - test_suite_class=suite_class_name, extras={'version': suite_version} - ) + record = suite_runner.SuiteInfoRecord(test_suite_class=suite_class_name) + record.set_extras({'version': suite_version}) record.suite_begin() record.suite_end()