Skip to content

Commit

Permalink
Support finding suite class that is not defined in the main module. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mhaoli authored Dec 13, 2024
1 parent 19e9df4 commit 130f6d9
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 12 deletions.
56 changes: 49 additions & 7 deletions mobly/suite_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,21 +137,63 @@ def _parse_cli_args(argv):
return parser.parse_known_args(argv)[0]


def _find_suite_class():
"""Finds the test suite class in the current module.
def _find_suite_classes_in_module(module):
"""Finds all test suite classes in the given module.
Walk through module members and find the subclass of BaseSuite. Only
one subclass is allowed in a module.
Walk through module members and find all classes that is a subclass of
BaseSuite.
Args:
module: types.ModuleType, the module object to find test suite classes.
Returns:
The test suite class in the test module.
A list of test suite classes.
"""
test_suites = []
main_module_members = sys.modules['__main__']
for _, module_member in main_module_members.__dict__.items():
for _, module_member in module.__dict__.items():
if inspect.isclass(module_member):
if issubclass(module_member, base_suite.BaseSuite):
test_suites.append(module_member)
return test_suites


def _find_suite_class():
"""Finds the test suite class.
First search for test suite classes in the __main__ module. If no test suite
class is found, search in the module that is calling
`suite_runner.run_suite_class`.
Walk through module members and find the subclass of BaseSuite. Only
one subclass is allowed.
Returns:
The test suite class in the test module.
"""
# Try to find test suites in __main__ module first.
test_suites = _find_suite_classes_in_module(sys.modules['__main__'])

# Try to find test suites in the module of the caller of `run_suite_class`.
if len(test_suites) == 0:
logging.debug(
'No suite class found in the __main__ module, trying to find it in the '
'module of the caller of suite_runner.run_suite_class method.'
)
stacks = inspect.stack()
if len(stacks) < 2:
logging.debug(
'Failed to get the caller stack of run_suite_class. Got stacks: %s',
stacks,
)
else:
run_suite_class_caller_frame_info = inspect.stack()[2]
caller_frame = run_suite_class_caller_frame_info.frame
module = inspect.getmodule(caller_frame)
if module is None:
logging.debug('Failed to find module for frame %s', caller_frame)
else:
test_suites = _find_suite_classes_in_module(module)

if len(test_suites) != 1:
logging.error(
'Expected 1 test class per file, found %s.',
Expand Down
31 changes: 31 additions & 0 deletions tests/lib/integration_test_suite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2024 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from mobly import base_suite
from mobly import suite_runner
from tests.lib import integration_test


class IntegrationTestSuite(base_suite.BaseSuite):

def setup_suite(self, config):
self.add_test_class(integration_test.IntegrationTest)


def main():
suite_runner.run_suite_class()


if __name__ == "__main__":
main()
30 changes: 25 additions & 5 deletions tests/mobly/suite_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from mobly import test_runner
from tests.lib import integration2_test
from tests.lib import integration_test
from tests.lib import integration_test_suite


class FakeTest1(base_test.BaseTestClass):
Expand Down Expand Up @@ -140,8 +141,7 @@ def test_run_suite_with_failures(self, mock_exit):
mock_exit.assert_called_once_with(1)

@mock.patch('sys.exit')
@mock.patch.object(suite_runner, '_find_suite_class', autospec=True)
def test_run_suite_class(self, mock_find_suite_class, mock_exit):
def test_run_suite_class(self, mock_exit):
tmp_file_path = self._gen_tmp_config_file()
mock_cli_args = ['test_binary', f'--config={tmp_file_path}']
mock_called = mock.MagicMock()
Expand All @@ -161,12 +161,14 @@ def teardown_suite(self):
mock_called.teardown_suite()
super().teardown_suite()

mock_find_suite_class.return_value = FakeTestSuite
sys.modules['__main__'].__dict__[FakeTestSuite.__name__] = FakeTestSuite

with mock.patch.object(sys, 'argv', new=mock_cli_args):
suite_runner.run_suite_class()
try:
suite_runner.run_suite_class()
finally:
del sys.modules['__main__'].__dict__[FakeTestSuite.__name__]

mock_find_suite_class.assert_called_once()
mock_called.setup_suite.assert_called_once_with()
mock_called.teardown_suite.assert_called_once_with()
mock_exit.assert_not_called()
Expand Down Expand Up @@ -240,6 +242,24 @@ def setup_suite(self, config):
{'FakeTest1': ['test_a']},
)

@mock.patch('sys.exit')
@mock.patch.object(test_runner, 'TestRunner')
@mock.patch.object(
integration_test_suite.IntegrationTestSuite, 'setup_suite', autospec=True
)
def test_run_suite_class_finds_suite_class_when_not_in_main_module(
self, mock_setup_suite, mock_test_runner_class, mock_exit
):
mock_test_runner = mock_test_runner_class.return_value
mock_test_runner.results.is_all_pass = True
tmp_file_path = self._gen_tmp_config_file()
mock_cli_args = ['test_binary', f'--config={tmp_file_path}']

with mock.patch.object(sys, 'argv', new=mock_cli_args):
integration_test_suite.main()

mock_setup_suite.assert_called_once()

def test_print_test_names(self):
mock_test_class = mock.MagicMock()
mock_cls_instance = mock.MagicMock()
Expand Down

0 comments on commit 130f6d9

Please sign in to comment.