Skip to content

Commit

Permalink
Prevent the occurrence of SyntaxError in friendly modules. (#533)
Browse files Browse the repository at this point in the history
* fix `ModuleGenerator` and unify `_create_module_...`

* add `test_client_regenerate_modules.py`

* Remove the responsibility of loading existing modules from `ModuleGenerator`.
Instead, `GetModule` attempts to load the existing modules.

* small fix

* move the lines defining `_get_existing_module` upward.

* small fix
  • Loading branch information
junkmd authored May 5, 2024
1 parent efd900d commit 6b81d07
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 47 deletions.
92 changes: 45 additions & 47 deletions comtypes/client/_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ def GetModule(tlib: _UnionT[Any, typeinfo.ITypeLib]) -> types.ModuleType:
pathname = None
tlib = _load_tlib(tlib)
logger.debug("GetModule(%s)", tlib.GetLibAttr())
mod = _get_existing_module(tlib)
if mod is not None:
return mod
return ModuleGenerator(tlib, pathname).generate()


Expand Down Expand Up @@ -161,6 +164,36 @@ def _load_tlib(obj: Any) -> typeinfo.ITypeLib:
raise TypeError("'%r' is not supported type for loading typelib" % obj)


def _get_existing_module(tlib: typeinfo.ITypeLib) -> Optional[types.ModuleType]:
def _get_friendly(name: str) -> Optional[types.ModuleType]:
try:
mod = _my_import(name)
except Exception as details:
logger.info("Could not import %s: %s", friendly_name, details)
else:
return mod

def _get_wrapper(name: str) -> Optional[types.ModuleType]:
if name in sys.modules:
return sys.modules[name]
try:
return _my_import(name)
except Exception as details:
logger.info("Could not import %s: %s", name, details)

wrapper_name = codegenerator.name_wrapper_module(tlib)
friendly_name = codegenerator.name_friendly_module(tlib)
wrapper_module = _get_wrapper(wrapper_name)
if wrapper_module is not None:
if friendly_name is None:
return wrapper_module
else:
friendly_module = _get_friendly(friendly_name)
if friendly_module is not None:
return friendly_module
return None


def _create_module(modulename: str, code: str) -> types.ModuleType:
"""Creates the module, then imports it."""
# `modulename` is 'comtypes.gen.xxx'
Expand All @@ -186,8 +219,6 @@ def _create_module(modulename: str, code: str) -> types.ModuleType:

class ModuleGenerator(object):
def __init__(self, tlib: typeinfo.ITypeLib, pathname: Optional[str]) -> None:
known_symbols, known_interfaces = _get_known_namespaces()
self.codegen = codegenerator.CodeGenerator(known_symbols, known_interfaces)
self.wrapper_name = codegenerator.name_wrapper_module(tlib)
self.friendly_name = codegenerator.name_friendly_module(tlib)
if pathname is None:
Expand All @@ -197,54 +228,21 @@ def __init__(self, tlib: typeinfo.ITypeLib, pathname: Optional[str]) -> None:
self.tlib = tlib

def generate(self) -> types.ModuleType:
# create and import the real typelib wrapper module
mod = self._get_existing_wrapper_module()
if mod is None:
mod = self._create_wrapper_module()
if self.friendly_name is None:
return mod
mod = self._get_existing_friendly_module()
if mod is not None:
return mod
return self._create_friendly_module()

def _get_existing_friendly_module(self) -> Optional[types.ModuleType]:
if self.friendly_name is None:
return
try:
mod = _my_import(self.friendly_name)
except Exception as details:
logger.info("Could not import %s: %s", self.friendly_name, details)
else:
return mod

def _create_friendly_module(self) -> types.ModuleType:
"""helper which creates and imports the friendly-named module."""
if self.friendly_name is None:
raise TypeError
# the module is always regenerated if the import fails
logger.info("# Generating %s", self.friendly_name)
# determine the Python module name
code = self.codegen.generate_friendly_code(self.wrapper_name)
return _create_module(self.friendly_name, code)

def _get_existing_wrapper_module(self) -> Optional[types.ModuleType]:
if self.wrapper_name in sys.modules:
return sys.modules[self.wrapper_name]
try:
return _my_import(self.wrapper_name)
except Exception as details:
logger.info("Could not import %s: %s", self.wrapper_name, details)

def _create_wrapper_module(self) -> types.ModuleType:
"""helper which creates and imports the real typelib wrapper module."""
# generate the module since it doesn't exist or is out of date
"""Generates wrapper and friendly modules."""
known_symbols, known_interfaces = _get_known_namespaces()
codegen = codegenerator.CodeGenerator(known_symbols, known_interfaces)
codebases: List[Tuple[str, str]] = []
logger.info("# Generating %s", self.wrapper_name)
items = list(tlbparser.TypeLibParser(self.tlib).parse().values())
code = self.codegen.generate_wrapper_code(items, filename=self.pathname)
for ext_tlib in self.codegen.externals: # generates dependency COM-lib modules
wrp_code = codegen.generate_wrapper_code(items, filename=self.pathname)
codebases.append((self.wrapper_name, wrp_code))
if self.friendly_name is not None:
logger.info("# Generating %s", self.friendly_name)
frd_code = codegen.generate_friendly_code(self.wrapper_name)
codebases.append((self.friendly_name, frd_code))
for ext_tlib in codegen.externals: # generates dependency COM-lib modules
GetModule(ext_tlib)
return _create_module(self.wrapper_name, code)
return [_create_module(name, code) for (name, code) in codebases][-1]


_SymbolName = str
Expand Down
158 changes: 158 additions & 0 deletions comtypes/test/test_client_regenerate_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import contextlib
import importlib
from pathlib import Path
import shutil
import sys
import tempfile
import types
from typing import Iterator
import unittest as ut
from unittest import mock

import comtypes
import comtypes.client
import comtypes.gen

comtypes.client.GetModule("scrrun.dll")
from comtypes.gen import Scripting # noqa
from comtypes.gen import stdole # noqa


SCRRUN_FRIENDLY = Path(Scripting.__file__)
SCRRUN_WRAPPER = Path(Scripting.__wrapper_module__.__file__)
STDOLE_FRIENDLY = Path(stdole.__file__)
STDOLE_WRAPPER = Path(stdole.__wrapper_module__.__file__)


@contextlib.contextmanager
def _mkdtmp_gen_dir() -> Iterator[Path]:
with tempfile.TemporaryDirectory() as t:
tmp_dir = Path(t)
tmp_comtypes_dir = tmp_dir / "comtypes"
tmp_comtypes_dir.mkdir()
(tmp_comtypes_dir / "__init__.py").touch()
tmp_comtypes_gen_dir = tmp_comtypes_dir / "gen"
tmp_comtypes_gen_dir.mkdir()
(tmp_comtypes_gen_dir / "__init__.py").touch()
yield tmp_comtypes_gen_dir


@contextlib.contextmanager
def _patch_gen_pkg(new_path: Path) -> Iterator[types.ModuleType]:
new_comtypes_init = (new_path / "comtypes" / "__init__.py").resolve()
assert new_comtypes_init.exists()
new_comtypes_gen_init = (new_path / "comtypes" / "gen" / "__init__.py").resolve()
assert new_comtypes_gen_init.exists()
orig_comtypes = sys.modules["comtypes"]
orig_gen_names = list(filter(lambda k: k.startswith("comtypes.gen"), sys.modules))
tmp_sys_path = list(sys.path) # copy
with mock.patch.object(sys, "path", tmp_sys_path):
sys.path.insert(0, str(new_path))
with mock.patch.dict(sys.modules):
# The reason for removing the parent module (in this case, `comtypes`)
# from `sys.modules` is because the child module (in this case,
# `comtypes.gen`) refers to the namespace of the parent module.
# If the parent module exists in `sys.modules`, Python uses that cache
# to import the child module. Therefore, in order to import a new version
# of the child module, it is necessary to temporarily remove the parent
# module from `sys.modules`.
del sys.modules["comtypes"]
for k in orig_gen_names:
del sys.modules[k]
# The module that is imported here is not the one cached in `sys.modules`
# before the patch, but the module that is newly loaded from
# `new_path / 'comtypes' / 'gen' / '__init__.py'`.
new_comtypes_gen = importlib.import_module("comtypes.gen")
assert new_comtypes_gen.__file__ is not None
assert Path(new_comtypes_gen.__file__).resolve() == new_comtypes_gen_init
# The `comtypes` module cached in `sys.modules` as a side effect of
# executing the above line is empty because it is the one loaded from
# `new_path / 'comtypes' / '__init__.py'`.
# If we call the test target as it is, an error will occur due to
# referencing an empty module, so we restore the original `comtypes`
# to `sys.modules`.
sys.modules["comtypes"] = orig_comtypes
assert sys.modules["comtypes.gen"] is new_comtypes_gen
# By making the empty `comtypes.gen` package we created earlier to be
# referenced as the `gen` attribute of `comtypes`, the original
# `comtypes.gen` will not be referenced within the context.
with mock.patch.object(orig_comtypes, "gen", new_comtypes_gen):
yield new_comtypes_gen


@contextlib.contextmanager
def patch_gen_dir() -> Iterator[Path]:
with _mkdtmp_gen_dir() as tmp_gen_dir:
with mock.patch.object(comtypes.client, "gen_dir", str(tmp_gen_dir)):
try:
with _patch_gen_pkg(tmp_gen_dir.parent.parent):
yield tmp_gen_dir
finally:
importlib.invalidate_caches()
importlib.reload(comtypes.gen)
importlib.reload(stdole)
importlib.reload(Scripting)


class Test(ut.TestCase):
def test_all_modules_are_missing(self):
with patch_gen_dir() as gen_dir:
# ensure `gen_dir` and `sys.modules` are patched.
with self.assertRaises(ImportError):
from comtypes.gen import Scripting as _ # noqa
self.assertFalse((gen_dir / SCRRUN_FRIENDLY.name).exists())
self.assertFalse((gen_dir / SCRRUN_WRAPPER.name).exists())
self.assertFalse((gen_dir / STDOLE_FRIENDLY.name).exists())
self.assertFalse((gen_dir / STDOLE_WRAPPER.name).exists())
# generate new files and modules.
comtypes.client.GetModule("scrrun.dll")
self.assertTrue((gen_dir / SCRRUN_FRIENDLY.name).exists())
self.assertTrue((gen_dir / SCRRUN_WRAPPER.name).exists())
self.assertTrue((gen_dir / STDOLE_FRIENDLY.name).exists())
self.assertTrue((gen_dir / STDOLE_WRAPPER.name).exists())

def test_friendly_module_is_missing(self):
with patch_gen_dir() as gen_dir:
shutil.copy2(SCRRUN_WRAPPER, gen_dir / SCRRUN_WRAPPER.name)
wrp_mtime = (gen_dir / SCRRUN_WRAPPER.name).stat().st_mtime_ns
shutil.copy2(STDOLE_FRIENDLY, gen_dir / STDOLE_FRIENDLY.name)
shutil.copy2(STDOLE_WRAPPER, gen_dir / STDOLE_WRAPPER.name)
comtypes.client.GetModule("scrrun.dll")
self.assertTrue((gen_dir / SCRRUN_FRIENDLY.name).exists())
# Check the most recent content modification time to confirm whether
# the module file has been regenerated.
self.assertGreater(
(gen_dir / SCRRUN_WRAPPER.name).stat().st_mtime_ns, wrp_mtime
)

def test_wrapper_module_is_missing(self):
with patch_gen_dir() as gen_dir:
shutil.copy2(SCRRUN_WRAPPER, gen_dir / SCRRUN_FRIENDLY.name)
frd_mtime = (gen_dir / SCRRUN_FRIENDLY.name).stat().st_mtime_ns
shutil.copy2(STDOLE_FRIENDLY, gen_dir / STDOLE_FRIENDLY.name)
shutil.copy2(STDOLE_WRAPPER, gen_dir / STDOLE_WRAPPER.name)
comtypes.client.GetModule("scrrun.dll")
self.assertTrue((gen_dir / SCRRUN_WRAPPER.name).exists())
self.assertGreater(
(gen_dir / SCRRUN_FRIENDLY.name).stat().st_mtime_ns, frd_mtime
)

def test_dependency_modules_are_missing(self):
with patch_gen_dir() as gen_dir:
shutil.copy2(SCRRUN_WRAPPER, gen_dir / SCRRUN_FRIENDLY.name)
frd_mtime = (gen_dir / SCRRUN_FRIENDLY.name).stat().st_mtime_ns
shutil.copy2(SCRRUN_WRAPPER, gen_dir / SCRRUN_WRAPPER.name)
wrp_mtime = (gen_dir / SCRRUN_WRAPPER.name).stat().st_mtime_ns
comtypes.client.GetModule("scrrun.dll")
self.assertTrue((gen_dir / STDOLE_FRIENDLY.name).exists())
self.assertTrue((gen_dir / STDOLE_WRAPPER.name).exists())
self.assertGreater(
(gen_dir / SCRRUN_FRIENDLY.name).stat().st_mtime_ns, frd_mtime
)
self.assertGreater(
(gen_dir / SCRRUN_WRAPPER.name).stat().st_mtime_ns, wrp_mtime
)


if __name__ == "__main__":
ut.main()

0 comments on commit 6b81d07

Please sign in to comment.