diff --git a/comtypes/client/_generate.py b/comtypes/client/_generate.py index 3924b997..f2e65941 100644 --- a/comtypes/client/_generate.py +++ b/comtypes/client/_generate.py @@ -122,6 +122,9 @@ def GetModule(tlib: _UnionT[Any, typeinfo.ITypeLib]) -> types.ModuleType: pathname = None tlib = _load_tlib(tlib) logger.debug("GetModule(%s)", tlib.GetLibAttr()) + mod = _get_existing_module(tlib) + if mod is not None: + return mod return ModuleGenerator(tlib, pathname).generate() @@ -161,6 +164,36 @@ def _load_tlib(obj: Any) -> typeinfo.ITypeLib: raise TypeError("'%r' is not supported type for loading typelib" % obj) +def _get_existing_module(tlib: typeinfo.ITypeLib) -> Optional[types.ModuleType]: + def _get_friendly(name: str) -> Optional[types.ModuleType]: + try: + mod = _my_import(name) + except Exception as details: + logger.info("Could not import %s: %s", friendly_name, details) + else: + return mod + + def _get_wrapper(name: str) -> Optional[types.ModuleType]: + if name in sys.modules: + return sys.modules[name] + try: + return _my_import(name) + except Exception as details: + logger.info("Could not import %s: %s", name, details) + + wrapper_name = codegenerator.name_wrapper_module(tlib) + friendly_name = codegenerator.name_friendly_module(tlib) + wrapper_module = _get_wrapper(wrapper_name) + if wrapper_module is not None: + if friendly_name is None: + return wrapper_module + else: + friendly_module = _get_friendly(friendly_name) + if friendly_module is not None: + return friendly_module + return None + + def _create_module(modulename: str, code: str) -> types.ModuleType: """Creates the module, then imports it.""" # `modulename` is 'comtypes.gen.xxx' @@ -186,8 +219,6 @@ def _create_module(modulename: str, code: str) -> types.ModuleType: class ModuleGenerator(object): def __init__(self, tlib: typeinfo.ITypeLib, pathname: Optional[str]) -> None: - known_symbols, known_interfaces = _get_known_namespaces() - self.codegen = codegenerator.CodeGenerator(known_symbols, known_interfaces) self.wrapper_name = codegenerator.name_wrapper_module(tlib) self.friendly_name = codegenerator.name_friendly_module(tlib) if pathname is None: @@ -197,54 +228,21 @@ def __init__(self, tlib: typeinfo.ITypeLib, pathname: Optional[str]) -> None: self.tlib = tlib def generate(self) -> types.ModuleType: - # create and import the real typelib wrapper module - mod = self._get_existing_wrapper_module() - if mod is None: - mod = self._create_wrapper_module() - if self.friendly_name is None: - return mod - mod = self._get_existing_friendly_module() - if mod is not None: - return mod - return self._create_friendly_module() - - def _get_existing_friendly_module(self) -> Optional[types.ModuleType]: - if self.friendly_name is None: - return - try: - mod = _my_import(self.friendly_name) - except Exception as details: - logger.info("Could not import %s: %s", self.friendly_name, details) - else: - return mod - - def _create_friendly_module(self) -> types.ModuleType: - """helper which creates and imports the friendly-named module.""" - if self.friendly_name is None: - raise TypeError - # the module is always regenerated if the import fails - logger.info("# Generating %s", self.friendly_name) - # determine the Python module name - code = self.codegen.generate_friendly_code(self.wrapper_name) - return _create_module(self.friendly_name, code) - - def _get_existing_wrapper_module(self) -> Optional[types.ModuleType]: - if self.wrapper_name in sys.modules: - return sys.modules[self.wrapper_name] - try: - return _my_import(self.wrapper_name) - except Exception as details: - logger.info("Could not import %s: %s", self.wrapper_name, details) - - def _create_wrapper_module(self) -> types.ModuleType: - """helper which creates and imports the real typelib wrapper module.""" - # generate the module since it doesn't exist or is out of date + """Generates wrapper and friendly modules.""" + known_symbols, known_interfaces = _get_known_namespaces() + codegen = codegenerator.CodeGenerator(known_symbols, known_interfaces) + codebases: List[Tuple[str, str]] = [] logger.info("# Generating %s", self.wrapper_name) items = list(tlbparser.TypeLibParser(self.tlib).parse().values()) - code = self.codegen.generate_wrapper_code(items, filename=self.pathname) - for ext_tlib in self.codegen.externals: # generates dependency COM-lib modules + wrp_code = codegen.generate_wrapper_code(items, filename=self.pathname) + codebases.append((self.wrapper_name, wrp_code)) + if self.friendly_name is not None: + logger.info("# Generating %s", self.friendly_name) + frd_code = codegen.generate_friendly_code(self.wrapper_name) + codebases.append((self.friendly_name, frd_code)) + for ext_tlib in codegen.externals: # generates dependency COM-lib modules GetModule(ext_tlib) - return _create_module(self.wrapper_name, code) + return [_create_module(name, code) for (name, code) in codebases][-1] _SymbolName = str diff --git a/comtypes/test/test_client_regenerate_modules.py b/comtypes/test/test_client_regenerate_modules.py new file mode 100644 index 00000000..dd2933eb --- /dev/null +++ b/comtypes/test/test_client_regenerate_modules.py @@ -0,0 +1,158 @@ +import contextlib +import importlib +from pathlib import Path +import shutil +import sys +import tempfile +import types +from typing import Iterator +import unittest as ut +from unittest import mock + +import comtypes +import comtypes.client +import comtypes.gen + +comtypes.client.GetModule("scrrun.dll") +from comtypes.gen import Scripting # noqa +from comtypes.gen import stdole # noqa + + +SCRRUN_FRIENDLY = Path(Scripting.__file__) +SCRRUN_WRAPPER = Path(Scripting.__wrapper_module__.__file__) +STDOLE_FRIENDLY = Path(stdole.__file__) +STDOLE_WRAPPER = Path(stdole.__wrapper_module__.__file__) + + +@contextlib.contextmanager +def _mkdtmp_gen_dir() -> Iterator[Path]: + with tempfile.TemporaryDirectory() as t: + tmp_dir = Path(t) + tmp_comtypes_dir = tmp_dir / "comtypes" + tmp_comtypes_dir.mkdir() + (tmp_comtypes_dir / "__init__.py").touch() + tmp_comtypes_gen_dir = tmp_comtypes_dir / "gen" + tmp_comtypes_gen_dir.mkdir() + (tmp_comtypes_gen_dir / "__init__.py").touch() + yield tmp_comtypes_gen_dir + + +@contextlib.contextmanager +def _patch_gen_pkg(new_path: Path) -> Iterator[types.ModuleType]: + new_comtypes_init = (new_path / "comtypes" / "__init__.py").resolve() + assert new_comtypes_init.exists() + new_comtypes_gen_init = (new_path / "comtypes" / "gen" / "__init__.py").resolve() + assert new_comtypes_gen_init.exists() + orig_comtypes = sys.modules["comtypes"] + orig_gen_names = list(filter(lambda k: k.startswith("comtypes.gen"), sys.modules)) + tmp_sys_path = list(sys.path) # copy + with mock.patch.object(sys, "path", tmp_sys_path): + sys.path.insert(0, str(new_path)) + with mock.patch.dict(sys.modules): + # The reason for removing the parent module (in this case, `comtypes`) + # from `sys.modules` is because the child module (in this case, + # `comtypes.gen`) refers to the namespace of the parent module. + # If the parent module exists in `sys.modules`, Python uses that cache + # to import the child module. Therefore, in order to import a new version + # of the child module, it is necessary to temporarily remove the parent + # module from `sys.modules`. + del sys.modules["comtypes"] + for k in orig_gen_names: + del sys.modules[k] + # The module that is imported here is not the one cached in `sys.modules` + # before the patch, but the module that is newly loaded from + # `new_path / 'comtypes' / 'gen' / '__init__.py'`. + new_comtypes_gen = importlib.import_module("comtypes.gen") + assert new_comtypes_gen.__file__ is not None + assert Path(new_comtypes_gen.__file__).resolve() == new_comtypes_gen_init + # The `comtypes` module cached in `sys.modules` as a side effect of + # executing the above line is empty because it is the one loaded from + # `new_path / 'comtypes' / '__init__.py'`. + # If we call the test target as it is, an error will occur due to + # referencing an empty module, so we restore the original `comtypes` + # to `sys.modules`. + sys.modules["comtypes"] = orig_comtypes + assert sys.modules["comtypes.gen"] is new_comtypes_gen + # By making the empty `comtypes.gen` package we created earlier to be + # referenced as the `gen` attribute of `comtypes`, the original + # `comtypes.gen` will not be referenced within the context. + with mock.patch.object(orig_comtypes, "gen", new_comtypes_gen): + yield new_comtypes_gen + + +@contextlib.contextmanager +def patch_gen_dir() -> Iterator[Path]: + with _mkdtmp_gen_dir() as tmp_gen_dir: + with mock.patch.object(comtypes.client, "gen_dir", str(tmp_gen_dir)): + try: + with _patch_gen_pkg(tmp_gen_dir.parent.parent): + yield tmp_gen_dir + finally: + importlib.invalidate_caches() + importlib.reload(comtypes.gen) + importlib.reload(stdole) + importlib.reload(Scripting) + + +class Test(ut.TestCase): + def test_all_modules_are_missing(self): + with patch_gen_dir() as gen_dir: + # ensure `gen_dir` and `sys.modules` are patched. + with self.assertRaises(ImportError): + from comtypes.gen import Scripting as _ # noqa + self.assertFalse((gen_dir / SCRRUN_FRIENDLY.name).exists()) + self.assertFalse((gen_dir / SCRRUN_WRAPPER.name).exists()) + self.assertFalse((gen_dir / STDOLE_FRIENDLY.name).exists()) + self.assertFalse((gen_dir / STDOLE_WRAPPER.name).exists()) + # generate new files and modules. + comtypes.client.GetModule("scrrun.dll") + self.assertTrue((gen_dir / SCRRUN_FRIENDLY.name).exists()) + self.assertTrue((gen_dir / SCRRUN_WRAPPER.name).exists()) + self.assertTrue((gen_dir / STDOLE_FRIENDLY.name).exists()) + self.assertTrue((gen_dir / STDOLE_WRAPPER.name).exists()) + + def test_friendly_module_is_missing(self): + with patch_gen_dir() as gen_dir: + shutil.copy2(SCRRUN_WRAPPER, gen_dir / SCRRUN_WRAPPER.name) + wrp_mtime = (gen_dir / SCRRUN_WRAPPER.name).stat().st_mtime_ns + shutil.copy2(STDOLE_FRIENDLY, gen_dir / STDOLE_FRIENDLY.name) + shutil.copy2(STDOLE_WRAPPER, gen_dir / STDOLE_WRAPPER.name) + comtypes.client.GetModule("scrrun.dll") + self.assertTrue((gen_dir / SCRRUN_FRIENDLY.name).exists()) + # Check the most recent content modification time to confirm whether + # the module file has been regenerated. + self.assertGreater( + (gen_dir / SCRRUN_WRAPPER.name).stat().st_mtime_ns, wrp_mtime + ) + + def test_wrapper_module_is_missing(self): + with patch_gen_dir() as gen_dir: + shutil.copy2(SCRRUN_WRAPPER, gen_dir / SCRRUN_FRIENDLY.name) + frd_mtime = (gen_dir / SCRRUN_FRIENDLY.name).stat().st_mtime_ns + shutil.copy2(STDOLE_FRIENDLY, gen_dir / STDOLE_FRIENDLY.name) + shutil.copy2(STDOLE_WRAPPER, gen_dir / STDOLE_WRAPPER.name) + comtypes.client.GetModule("scrrun.dll") + self.assertTrue((gen_dir / SCRRUN_WRAPPER.name).exists()) + self.assertGreater( + (gen_dir / SCRRUN_FRIENDLY.name).stat().st_mtime_ns, frd_mtime + ) + + def test_dependency_modules_are_missing(self): + with patch_gen_dir() as gen_dir: + shutil.copy2(SCRRUN_WRAPPER, gen_dir / SCRRUN_FRIENDLY.name) + frd_mtime = (gen_dir / SCRRUN_FRIENDLY.name).stat().st_mtime_ns + shutil.copy2(SCRRUN_WRAPPER, gen_dir / SCRRUN_WRAPPER.name) + wrp_mtime = (gen_dir / SCRRUN_WRAPPER.name).stat().st_mtime_ns + comtypes.client.GetModule("scrrun.dll") + self.assertTrue((gen_dir / STDOLE_FRIENDLY.name).exists()) + self.assertTrue((gen_dir / STDOLE_WRAPPER.name).exists()) + self.assertGreater( + (gen_dir / SCRRUN_FRIENDLY.name).stat().st_mtime_ns, frd_mtime + ) + self.assertGreater( + (gen_dir / SCRRUN_WRAPPER.name).stat().st_mtime_ns, wrp_mtime + ) + + +if __name__ == "__main__": + ut.main()