Skip to content

Commit

Permalink
Determine whether a COM interface is one of the known symbols not onl…
Browse files Browse the repository at this point in the history
…y by its name but also by using its iid (#529)

* change to constructor arguments

* add `_get_existing_friendly_module` and `_get_existing_wrapper_module`

* update `mscorlib` test

* replace `_get_known_symbols` with `_get_known_namespaces`

* update `codegenerator`
add `known_interfaces` to constructor args.
  • Loading branch information
junkmd authored Apr 16, 2024
1 parent 9ccfa6b commit c8f3e2e
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 48 deletions.
118 changes: 76 additions & 42 deletions comtypes/client/_generate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import print_function
import ctypes
import importlib
import inspect
import logging
import os
import sys
import types
from typing import Any, Tuple, List, Optional, Dict, Union as _UnionT
from typing import Any, Tuple, List, Mapping, Optional, Dict, Union as _UnionT
import winreg

from comtypes import GUID, typeinfo
Expand Down Expand Up @@ -121,7 +122,7 @@ def GetModule(tlib: _UnionT[Any, typeinfo.ITypeLib]) -> types.ModuleType:
pathname = None
tlib = _load_tlib(tlib)
logger.debug("GetModule(%s)", tlib.GetLibAttr())
return ModuleGenerator().generate(tlib, pathname)
return ModuleGenerator(tlib, pathname).generate()


def _load_tlib(obj: Any) -> typeinfo.ITypeLib:
Expand Down Expand Up @@ -184,63 +185,91 @@ def _create_module(modulename: str, code: str) -> types.ModuleType:


class ModuleGenerator(object):
def __init__(self) -> None:
self.codegen = codegenerator.CodeGenerator(_get_known_symbols())
def __init__(self, tlib: typeinfo.ITypeLib, pathname: Optional[str]) -> None:
known_symbols, known_interfaces = _get_known_namespaces()
self.codegen = codegenerator.CodeGenerator(known_symbols, known_interfaces)
self.wrapper_name = codegenerator.name_wrapper_module(tlib)
self.friendly_name = codegenerator.name_friendly_module(tlib)
if pathname is None:
self.pathname = tlbparser.get_tlib_filename(tlib)
else:
self.pathname = pathname
self.tlib = tlib

def generate(
self, tlib: typeinfo.ITypeLib, pathname: Optional[str]
) -> types.ModuleType:
def generate(self) -> types.ModuleType:
# create and import the real typelib wrapper module
mod = self._create_wrapper_module(tlib, pathname)
# try to get the friendly-name, if not, returns the real typelib wrapper module
modulename = codegenerator.name_friendly_module(tlib)
if modulename is None:
mod = self._get_existing_wrapper_module()
if mod is None:
mod = self._create_wrapper_module()
if self.friendly_name is None:
return mod
# create and import the friendly-named module
return self._create_friendly_module(tlib, modulename)
mod = self._get_existing_friendly_module()
if mod is not None:
return mod
return self._create_friendly_module()

def _create_friendly_module(
self, tlib: typeinfo.ITypeLib, modulename: str
) -> types.ModuleType:
"""helper which creates and imports the friendly-named module."""
def _get_existing_friendly_module(self) -> Optional[types.ModuleType]:
if self.friendly_name is None:
return
try:
mod = _my_import(modulename)
mod = _my_import(self.friendly_name)
except Exception as details:
logger.info("Could not import %s: %s", modulename, details)
logger.info("Could not import %s: %s", self.friendly_name, details)
else:
return mod

def _create_friendly_module(self) -> types.ModuleType:
"""helper which creates and imports the friendly-named module."""
if self.friendly_name is None:
raise TypeError
# the module is always regenerated if the import fails
logger.info("# Generating %s", modulename)
logger.info("# Generating %s", self.friendly_name)
# determine the Python module name
modname = codegenerator.name_wrapper_module(tlib)
code = self.codegen.generate_friendly_code(modname)
return _create_module(modulename, code)
code = self.codegen.generate_friendly_code(self.wrapper_name)
return _create_module(self.friendly_name, code)

def _create_wrapper_module(
self, tlib: typeinfo.ITypeLib, pathname: Optional[str]
) -> types.ModuleType:
"""helper which creates and imports the real typelib wrapper module."""
modulename = codegenerator.name_wrapper_module(tlib)
if modulename in sys.modules:
return sys.modules[modulename]
def _get_existing_wrapper_module(self) -> Optional[types.ModuleType]:
if self.wrapper_name in sys.modules:
return sys.modules[self.wrapper_name]
try:
return _my_import(modulename)
return _my_import(self.wrapper_name)
except Exception as details:
logger.info("Could not import %s: %s", modulename, details)
logger.info("Could not import %s: %s", self.wrapper_name, details)

def _create_wrapper_module(self) -> types.ModuleType:
"""helper which creates and imports the real typelib wrapper module."""
# generate the module since it doesn't exist or is out of date
logger.info("# Generating %s", modulename)
p = tlbparser.TypeLibParser(tlib)
if pathname is None:
pathname = tlbparser.get_tlib_filename(tlib)
items = list(p.parse().values())
code = self.codegen.generate_wrapper_code(items, filename=pathname)
logger.info("# Generating %s", self.wrapper_name)
items = list(tlbparser.TypeLibParser(self.tlib).parse().values())
code = self.codegen.generate_wrapper_code(items, filename=self.pathname)
for ext_tlib in self.codegen.externals: # generates dependency COM-lib modules
GetModule(ext_tlib)
return _create_module(modulename, code)
return _create_module(self.wrapper_name, code)


_SymbolName = str
_ModuleName = str
_ItfName = str
_ItfIid = str

def _get_known_symbols() -> Dict[str, str]:
known_symbols: Dict[str, str] = {}

def _get_known_namespaces() -> Tuple[
Mapping[_SymbolName, _ModuleName], Mapping[_ItfName, _ItfIid]
]:
"""Returns symbols and interfaces that are already statically defined in `ctypes`
and `comtypes`.
From `ctypes`, all the names are obtained.
From `comtypes`, only the names in each module's `__known_symbols__` are obtained.
Note:
The interfaces that should be included in `__known_symbols__` should be limited
to those that can be said to be bound to the design concept of COM, such as
`IUnknown`, and those defined in `objidl` and `oaidl`.
`comtypes` does NOT aim to statically define all COM object interfaces in
its repository.
"""
known_symbols: Dict[_SymbolName, _ModuleName] = {}
known_interfaces: Dict[_ItfName, _ItfIid] = {}
for mod_name in (
"comtypes.persist",
"comtypes.typeinfo",
Expand All @@ -252,11 +281,16 @@ def _get_known_symbols() -> Dict[str, str]:
mod = importlib.import_module(mod_name)
if hasattr(mod, "__known_symbols__"):
names: List[str] = mod.__known_symbols__
for name in names:
tgt = getattr(mod, name)
if inspect.isclass(tgt) and issubclass(tgt, comtypes.IUnknown):
assert name not in known_interfaces
known_interfaces[name] = str(tgt._iid_)
else:
names = list(mod.__dict__)
for name in names:
known_symbols[name] = mod.__name__
return known_symbols
return known_symbols, known_interfaces


################################################################
Expand Down
11 changes: 8 additions & 3 deletions comtypes/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,17 @@ def test_ptr_itypelib(self):
mod = comtypes.client.GetModule(typeinfo.LoadTypeLibEx("scrrun.dll"))
self.assertIs(mod, Scripting)

def test_imports_IEnumVARIANT_from_other_generated_modules(self):
def test_mscorlib(self):
# NOTE: `codegenerator` generates code that contains unused imports,
# but removing them are attracting wierd bugs in library-wrappers
# which depend on externals.
# NOTE: `mscorlib`, which imports `IEnumVARIANT` from `stdole`.
comtypes.client.GetModule(("{BED7F4EA-1A96-11D2-8F08-00A0C9A6186D}",))
# `mscorlib` imports `stdole` wrapper module and refers`IEnumVARIANT` from it.
mod = comtypes.client.GetModule(("{BED7F4EA-1A96-11D2-8F08-00A0C9A6186D}",))
# NOTE: `ModuleGenerator` treats the `ctypes._Pointer` base class for pointers
# as one of the known symbols, but `mscorlib` has the `_Pointer` com interface.
# Even though they have the same name, `codegenerator` generates code to define
# the `_Pointer` interface, rather than importing `_Pointer` from `ctypes`.
self.assertTrue(issubclass(mod._Pointer, comtypes.IUnknown))

def test_no_replacing_Patch_namespace(self):
# NOTE: An object named `Patch` is defined in some dll.
Expand Down
22 changes: 19 additions & 3 deletions comtypes/tools/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,15 @@ def _get_common_elms(self) -> Tuple[List[_IdlFlagType], str, str]:


class CodeGenerator(object):
def __init__(self, known_symbols=None):
def __init__(self, known_symbols=None, known_interfaces=None) -> None:
self.stream = io.StringIO()
self.imports = ImportedNamespaces()
self.declarations = DeclaredNamespaces()
self.enums = EnumerationNamespaces()
self.unnamed_enum_members: List[Tuple[str, int]] = []
self._to_type_name = TypeNamer()
self.known_symbols = known_symbols or {}
self.known_interfaces = known_interfaces or {}

self.done = set() # type descriptions that have been generated
self.names = set() # names that have been generated
Expand All @@ -437,13 +438,20 @@ def __init__(self, known_symbols=None):
def generate(self, item):
if item in self.done:
return
if isinstance(item, typedesc.ComInterface):
if self._is_known_interface(item):
self.imports.add(item.name, symbols=self.known_symbols)
self.done.add(item)
return
self.done.add(item) # to avoid infinite recursion.
self.ComInterface(item)
return
if isinstance(item, typedesc.StructureHead):
name = getattr(item.struct, "name", None)
else:
name = getattr(item, "name", None)
if name in self.known_symbols:
self.imports.add(name, symbols=self.known_symbols)

self.done.add(item)
if isinstance(item, typedesc.Structure):
self.done.add(item.get_head())
Expand Down Expand Up @@ -1074,6 +1082,14 @@ def ComInterface(self, itf: typedesc.ComInterface) -> None:
self.generate(itf.get_body())
self.names.add(itf.name)

def _is_known_interface(self, item: typedesc.ComInterface) -> bool:
"""Returns whether an interface is statically defined in `comtypes`,
based on its name and iid.
"""
if item.name in self.known_interfaces:
return self.known_interfaces[item.name] == item.iid
return False

def _is_enuminterface(self, itf: typedesc.ComInterface) -> bool:
# Check if this is an IEnumXXX interface
if not itf.name.startswith("IEnum"):
Expand All @@ -1085,7 +1101,7 @@ def _is_enuminterface(self, itf: typedesc.ComInterface) -> bool:
return True

def ComInterfaceHead(self, head: typedesc.ComInterfaceHead) -> None:
if head.itf.name in self.known_symbols:
if self._is_known_interface(head.itf):
return
base = head.itf.base
if head.itf.base is None:
Expand Down

0 comments on commit c8f3e2e

Please sign in to comment.