diff --git a/comtypes/client/_generate.py b/comtypes/client/_generate.py index a1e85663..3924b997 100644 --- a/comtypes/client/_generate.py +++ b/comtypes/client/_generate.py @@ -1,11 +1,12 @@ from __future__ import print_function import ctypes import importlib +import inspect import logging import os import sys import types -from typing import Any, Tuple, List, Optional, Dict, Union as _UnionT +from typing import Any, Tuple, List, Mapping, Optional, Dict, Union as _UnionT import winreg from comtypes import GUID, typeinfo @@ -121,7 +122,7 @@ def GetModule(tlib: _UnionT[Any, typeinfo.ITypeLib]) -> types.ModuleType: pathname = None tlib = _load_tlib(tlib) logger.debug("GetModule(%s)", tlib.GetLibAttr()) - return ModuleGenerator().generate(tlib, pathname) + return ModuleGenerator(tlib, pathname).generate() def _load_tlib(obj: Any) -> typeinfo.ITypeLib: @@ -184,63 +185,91 @@ def _create_module(modulename: str, code: str) -> types.ModuleType: class ModuleGenerator(object): - def __init__(self) -> None: - self.codegen = codegenerator.CodeGenerator(_get_known_symbols()) + def __init__(self, tlib: typeinfo.ITypeLib, pathname: Optional[str]) -> None: + known_symbols, known_interfaces = _get_known_namespaces() + self.codegen = codegenerator.CodeGenerator(known_symbols, known_interfaces) + self.wrapper_name = codegenerator.name_wrapper_module(tlib) + self.friendly_name = codegenerator.name_friendly_module(tlib) + if pathname is None: + self.pathname = tlbparser.get_tlib_filename(tlib) + else: + self.pathname = pathname + self.tlib = tlib - def generate( - self, tlib: typeinfo.ITypeLib, pathname: Optional[str] - ) -> types.ModuleType: + def generate(self) -> types.ModuleType: # create and import the real typelib wrapper module - mod = self._create_wrapper_module(tlib, pathname) - # try to get the friendly-name, if not, returns the real typelib wrapper module - modulename = codegenerator.name_friendly_module(tlib) - if modulename is None: + mod = self._get_existing_wrapper_module() + if mod is None: + mod = self._create_wrapper_module() + if self.friendly_name is None: return mod - # create and import the friendly-named module - return self._create_friendly_module(tlib, modulename) + mod = self._get_existing_friendly_module() + if mod is not None: + return mod + return self._create_friendly_module() - def _create_friendly_module( - self, tlib: typeinfo.ITypeLib, modulename: str - ) -> types.ModuleType: - """helper which creates and imports the friendly-named module.""" + def _get_existing_friendly_module(self) -> Optional[types.ModuleType]: + if self.friendly_name is None: + return try: - mod = _my_import(modulename) + mod = _my_import(self.friendly_name) except Exception as details: - logger.info("Could not import %s: %s", modulename, details) + logger.info("Could not import %s: %s", self.friendly_name, details) else: return mod + + def _create_friendly_module(self) -> types.ModuleType: + """helper which creates and imports the friendly-named module.""" + if self.friendly_name is None: + raise TypeError # the module is always regenerated if the import fails - logger.info("# Generating %s", modulename) + logger.info("# Generating %s", self.friendly_name) # determine the Python module name - modname = codegenerator.name_wrapper_module(tlib) - code = self.codegen.generate_friendly_code(modname) - return _create_module(modulename, code) + code = self.codegen.generate_friendly_code(self.wrapper_name) + return _create_module(self.friendly_name, code) - def _create_wrapper_module( - self, tlib: typeinfo.ITypeLib, pathname: Optional[str] - ) -> types.ModuleType: - """helper which creates and imports the real typelib wrapper module.""" - modulename = codegenerator.name_wrapper_module(tlib) - if modulename in sys.modules: - return sys.modules[modulename] + def _get_existing_wrapper_module(self) -> Optional[types.ModuleType]: + if self.wrapper_name in sys.modules: + return sys.modules[self.wrapper_name] try: - return _my_import(modulename) + return _my_import(self.wrapper_name) except Exception as details: - logger.info("Could not import %s: %s", modulename, details) + logger.info("Could not import %s: %s", self.wrapper_name, details) + + def _create_wrapper_module(self) -> types.ModuleType: + """helper which creates and imports the real typelib wrapper module.""" # generate the module since it doesn't exist or is out of date - logger.info("# Generating %s", modulename) - p = tlbparser.TypeLibParser(tlib) - if pathname is None: - pathname = tlbparser.get_tlib_filename(tlib) - items = list(p.parse().values()) - code = self.codegen.generate_wrapper_code(items, filename=pathname) + logger.info("# Generating %s", self.wrapper_name) + items = list(tlbparser.TypeLibParser(self.tlib).parse().values()) + code = self.codegen.generate_wrapper_code(items, filename=self.pathname) for ext_tlib in self.codegen.externals: # generates dependency COM-lib modules GetModule(ext_tlib) - return _create_module(modulename, code) + return _create_module(self.wrapper_name, code) + +_SymbolName = str +_ModuleName = str +_ItfName = str +_ItfIid = str -def _get_known_symbols() -> Dict[str, str]: - known_symbols: Dict[str, str] = {} + +def _get_known_namespaces() -> Tuple[ + Mapping[_SymbolName, _ModuleName], Mapping[_ItfName, _ItfIid] +]: + """Returns symbols and interfaces that are already statically defined in `ctypes` + and `comtypes`. + From `ctypes`, all the names are obtained. + From `comtypes`, only the names in each module's `__known_symbols__` are obtained. + + Note: + The interfaces that should be included in `__known_symbols__` should be limited + to those that can be said to be bound to the design concept of COM, such as + `IUnknown`, and those defined in `objidl` and `oaidl`. + `comtypes` does NOT aim to statically define all COM object interfaces in + its repository. + """ + known_symbols: Dict[_SymbolName, _ModuleName] = {} + known_interfaces: Dict[_ItfName, _ItfIid] = {} for mod_name in ( "comtypes.persist", "comtypes.typeinfo", @@ -252,11 +281,16 @@ def _get_known_symbols() -> Dict[str, str]: mod = importlib.import_module(mod_name) if hasattr(mod, "__known_symbols__"): names: List[str] = mod.__known_symbols__ + for name in names: + tgt = getattr(mod, name) + if inspect.isclass(tgt) and issubclass(tgt, comtypes.IUnknown): + assert name not in known_interfaces + known_interfaces[name] = str(tgt._iid_) else: names = list(mod.__dict__) for name in names: known_symbols[name] = mod.__name__ - return known_symbols + return known_symbols, known_interfaces ################################################################ diff --git a/comtypes/test/test_client.py b/comtypes/test/test_client.py index bcf0195e..7ffa06fd 100644 --- a/comtypes/test/test_client.py +++ b/comtypes/test/test_client.py @@ -59,12 +59,17 @@ def test_ptr_itypelib(self): mod = comtypes.client.GetModule(typeinfo.LoadTypeLibEx("scrrun.dll")) self.assertIs(mod, Scripting) - def test_imports_IEnumVARIANT_from_other_generated_modules(self): + def test_mscorlib(self): # NOTE: `codegenerator` generates code that contains unused imports, # but removing them are attracting wierd bugs in library-wrappers # which depend on externals. - # NOTE: `mscorlib`, which imports `IEnumVARIANT` from `stdole`. - comtypes.client.GetModule(("{BED7F4EA-1A96-11D2-8F08-00A0C9A6186D}",)) + # `mscorlib` imports `stdole` wrapper module and refers`IEnumVARIANT` from it. + mod = comtypes.client.GetModule(("{BED7F4EA-1A96-11D2-8F08-00A0C9A6186D}",)) + # NOTE: `ModuleGenerator` treats the `ctypes._Pointer` base class for pointers + # as one of the known symbols, but `mscorlib` has the `_Pointer` com interface. + # Even though they have the same name, `codegenerator` generates code to define + # the `_Pointer` interface, rather than importing `_Pointer` from `ctypes`. + self.assertTrue(issubclass(mod._Pointer, comtypes.IUnknown)) def test_no_replacing_Patch_namespace(self): # NOTE: An object named `Patch` is defined in some dll. diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index f16ca607..273ed1d4 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -419,7 +419,7 @@ def _get_common_elms(self) -> Tuple[List[_IdlFlagType], str, str]: class CodeGenerator(object): - def __init__(self, known_symbols=None): + def __init__(self, known_symbols=None, known_interfaces=None) -> None: self.stream = io.StringIO() self.imports = ImportedNamespaces() self.declarations = DeclaredNamespaces() @@ -427,6 +427,7 @@ def __init__(self, known_symbols=None): self.unnamed_enum_members: List[Tuple[str, int]] = [] self._to_type_name = TypeNamer() self.known_symbols = known_symbols or {} + self.known_interfaces = known_interfaces or {} self.done = set() # type descriptions that have been generated self.names = set() # names that have been generated @@ -437,13 +438,20 @@ def __init__(self, known_symbols=None): def generate(self, item): if item in self.done: return + if isinstance(item, typedesc.ComInterface): + if self._is_known_interface(item): + self.imports.add(item.name, symbols=self.known_symbols) + self.done.add(item) + return + self.done.add(item) # to avoid infinite recursion. + self.ComInterface(item) + return if isinstance(item, typedesc.StructureHead): name = getattr(item.struct, "name", None) else: name = getattr(item, "name", None) if name in self.known_symbols: self.imports.add(name, symbols=self.known_symbols) - self.done.add(item) if isinstance(item, typedesc.Structure): self.done.add(item.get_head()) @@ -1074,6 +1082,14 @@ def ComInterface(self, itf: typedesc.ComInterface) -> None: self.generate(itf.get_body()) self.names.add(itf.name) + def _is_known_interface(self, item: typedesc.ComInterface) -> bool: + """Returns whether an interface is statically defined in `comtypes`, + based on its name and iid. + """ + if item.name in self.known_interfaces: + return self.known_interfaces[item.name] == item.iid + return False + def _is_enuminterface(self, itf: typedesc.ComInterface) -> bool: # Check if this is an IEnumXXX interface if not itf.name.startswith("IEnum"): @@ -1085,7 +1101,7 @@ def _is_enuminterface(self, itf: typedesc.ComInterface) -> bool: return True def ComInterfaceHead(self, head: typedesc.ComInterfaceHead) -> None: - if head.itf.name in self.known_symbols: + if self._is_known_interface(head.itf): return base = head.itf.base if head.itf.base is None: