Skip to content

Commit

Permalink
update codegenerator
Browse files Browse the repository at this point in the history
add `known_interfaces` to constructor args.
  • Loading branch information
junkmd committed Apr 15, 2024
1 parent 56cbb9e commit b24c2c4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
4 changes: 2 additions & 2 deletions comtypes/client/_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ def _create_module(modulename: str, code: str) -> types.ModuleType:

class ModuleGenerator(object):
def __init__(self, tlib: typeinfo.ITypeLib, pathname: Optional[str]) -> None:
known_symbols, _ = _get_known_namespaces()
self.codegen = codegenerator.CodeGenerator(known_symbols)
known_symbols, known_interfaces = _get_known_namespaces()
self.codegen = codegenerator.CodeGenerator(known_symbols, known_interfaces)
self.wrapper_name = codegenerator.name_wrapper_module(tlib)
self.friendly_name = codegenerator.name_friendly_module(tlib)
if pathname is None:
Expand Down
22 changes: 19 additions & 3 deletions comtypes/tools/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,15 @@ def _get_common_elms(self) -> Tuple[List[_IdlFlagType], str, str]:


class CodeGenerator(object):
def __init__(self, known_symbols=None):
def __init__(self, known_symbols=None, known_interfaces=None) -> None:
self.stream = io.StringIO()
self.imports = ImportedNamespaces()
self.declarations = DeclaredNamespaces()
self.enums = EnumerationNamespaces()
self.unnamed_enum_members: List[Tuple[str, int]] = []
self._to_type_name = TypeNamer()
self.known_symbols = known_symbols or {}
self.known_interfaces = known_interfaces or {}

self.done = set() # type descriptions that have been generated
self.names = set() # names that have been generated
Expand All @@ -437,13 +438,20 @@ def __init__(self, known_symbols=None):
def generate(self, item):
if item in self.done:
return
if isinstance(item, typedesc.ComInterface):
if self._is_known_interface(item):
self.imports.add(item.name, symbols=self.known_symbols)
self.done.add(item)
return
self.done.add(item) # to avoid infinite recursion.
self.ComInterface(item)
return
if isinstance(item, typedesc.StructureHead):
name = getattr(item.struct, "name", None)
else:
name = getattr(item, "name", None)
if name in self.known_symbols:
self.imports.add(name, symbols=self.known_symbols)

self.done.add(item)
if isinstance(item, typedesc.Structure):
self.done.add(item.get_head())
Expand Down Expand Up @@ -1074,6 +1082,14 @@ def ComInterface(self, itf: typedesc.ComInterface) -> None:
self.generate(itf.get_body())
self.names.add(itf.name)

def _is_known_interface(self, item: typedesc.ComInterface) -> bool:
"""Returns whether an interface is statically defined in `comtypes`,
based on its name and iid.
"""
if item.name in self.known_interfaces:
return self.known_interfaces[item.name] == item.iid
return False

def _is_enuminterface(self, itf: typedesc.ComInterface) -> bool:
# Check if this is an IEnumXXX interface
if not itf.name.startswith("IEnum"):
Expand All @@ -1085,7 +1101,7 @@ def _is_enuminterface(self, itf: typedesc.ComInterface) -> bool:
return True

def ComInterfaceHead(self, head: typedesc.ComInterfaceHead) -> None:
if head.itf.name in self.known_symbols:
if self._is_known_interface(head.itf):
return
base = head.itf.base
if head.itf.base is None:
Expand Down

0 comments on commit b24c2c4

Please sign in to comment.