diff --git a/comtypes/hints.pyi b/comtypes/hints.pyi index 30fb4906..7d7e5689 100644 --- a/comtypes/hints.pyi +++ b/comtypes/hints.pyi @@ -23,8 +23,10 @@ else: from typing_extensions import Protocol if sys.version_info >= (3, 10): from typing import Concatenate, ParamSpec, TypeAlias + from typing import TypeGuard as TypeGuard else: from typing_extensions import Concatenate, ParamSpec, TypeAlias + from typing_extensions import TypeGuard as TypeGuard if sys.version_info >= (3, 11): from typing import Self else: diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index afc14cc7..4dca21e2 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -421,6 +421,16 @@ def _get_common_elms(self) -> Tuple[List[_IdlFlagType], str, str]: return (idlflags, type_name, self._m.name) +_InterfaceTypeDesc = _UnionT[ + typedesc.ComInterface, + typedesc.ComInterfaceHead, + typedesc.ComInterfaceBody, + typedesc.DispInterface, + typedesc.DispInterfaceHead, + typedesc.DispInterfaceBody, +] + + class CodeGenerator(object): def __init__(self, known_symbols=None, known_interfaces=None) -> None: self.stream = io.StringIO() @@ -441,13 +451,8 @@ def __init__(self, known_symbols=None, known_interfaces=None) -> None: def generate(self, item): if item in self.done: return - if isinstance(item, typedesc.ComInterface): - if self._is_known_interface(item): - self.imports.add(item.name, symbols=self.known_symbols) - self.done.add(item) - return - self.done.add(item) # to avoid infinite recursion. - self.ComInterface(item) + if self._is_interface_typedesc(item): + self._define_interface(item) return if isinstance(item, typedesc.StructureHead): name = getattr(item.struct, "name", None) @@ -1080,12 +1085,56 @@ def CoClass(self, coclass: typedesc.CoClass) -> None: self.names.add(coclass.name) + def _is_interface_typedesc( + self, item: Any + ) -> "comtypes.hints.TypeGuard[_InterfaceTypeDesc]": + return isinstance( + item, + ( + typedesc.ComInterface, + typedesc.ComInterfaceHead, + typedesc.ComInterfaceBody, + typedesc.DispInterface, + typedesc.DispInterfaceHead, + typedesc.DispInterfaceBody, + ), + ) + + def _define_interface(self, item: _InterfaceTypeDesc) -> None: + if isinstance( + item, + ( + typedesc.ComInterfaceHead, + typedesc.ComInterfaceBody, + typedesc.DispInterfaceHead, + typedesc.DispInterfaceBody, + ), + ): + if self._is_known_interface(item.itf): + self.imports.add(item.itf.name, symbols=self.known_symbols) + self.done.add(item) + return + elif isinstance(item, (typedesc.ComInterface, typedesc.DispInterface)): + if self._is_known_interface(item): + self.imports.add(item.name, symbols=self.known_symbols) + self.done.add(item) + self.done.add(item.get_head()) + self.done.add(item.get_body()) + return + else: + raise TypeError + self.done.add(item) # to avoid infinite recursion. + mth = getattr(self, type(item).__name__) + mth(item) + def ComInterface(self, itf: typedesc.ComInterface) -> None: self.generate(itf.get_head()) self.generate(itf.get_body()) self.names.add(itf.name) - def _is_known_interface(self, item: typedesc.ComInterface) -> bool: + def _is_known_interface( + self, item: _UnionT[typedesc.ComInterface, typedesc.DispInterface] + ) -> bool: """Returns whether an interface is statically defined in `comtypes`, based on its name and iid. """ @@ -1104,14 +1153,11 @@ def _is_enuminterface(self, itf: typedesc.ComInterface) -> bool: return True def ComInterfaceHead(self, head: typedesc.ComInterfaceHead) -> None: - if self._is_known_interface(head.itf): - return - base = head.itf.base if head.itf.base is None: # we don't beed to generate IUnknown return - self.generate(base.get_head()) - self.more.add(base) + self.generate(head.itf.base.get_head()) + self.more.add(head.itf.base) basename = self._to_type_name(head.itf.base) self.imports.add("comtypes", "GUID") diff --git a/comtypes/tools/typedesc.py b/comtypes/tools/typedesc.py index 3085f0ea..8bf167cc 100644 --- a/comtypes/tools/typedesc.py +++ b/comtypes/tools/typedesc.py @@ -176,7 +176,7 @@ def __init__( self, name: str, members: List[ComMethod], - base: Any, + base: "Optional[ComInterface]", iid: str, idlflags: List[str], ) -> None: