From 468f050ca085f2e0bd46ec310cba4ec532fc7012 Mon Sep 17 00:00:00 2001 From: Jun Komoda <45822440+junkmd@users.noreply.github.com> Date: Tue, 30 Apr 2024 11:29:22 +0000 Subject: [PATCH 1/3] add conditional branching for `ComInterfaceHead` and `ComInterfaceBody` --- comtypes/tools/codegenerator.py | 25 ++++++++++++++++++++----- comtypes/tools/typedesc.py | 2 +- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index afc14cc7..2b58b1b6 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -441,10 +441,28 @@ def __init__(self, known_symbols=None, known_interfaces=None) -> None: def generate(self, item): if item in self.done: return + if isinstance(item, typedesc.ComInterfaceHead): + if self._is_known_interface(item.itf): + self.imports.add(item.itf.name, symbols=self.known_symbols) + self.done.add(item) + return + self.done.add(item) + self.ComInterfaceHead(item) + return + if isinstance(item, typedesc.ComInterfaceBody): + if self._is_known_interface(item.itf): + self.imports.add(item.itf.name, symbols=self.known_symbols) + self.done.add(item) + return + self.done.add(item) + self.ComInterfaceBody(item) + return if isinstance(item, typedesc.ComInterface): if self._is_known_interface(item): self.imports.add(item.name, symbols=self.known_symbols) self.done.add(item) + self.done.add(item.get_head()) + self.done.add(item.get_body()) return self.done.add(item) # to avoid infinite recursion. self.ComInterface(item) @@ -1104,14 +1122,11 @@ def _is_enuminterface(self, itf: typedesc.ComInterface) -> bool: return True def ComInterfaceHead(self, head: typedesc.ComInterfaceHead) -> None: - if self._is_known_interface(head.itf): - return - base = head.itf.base if head.itf.base is None: # we don't beed to generate IUnknown return - self.generate(base.get_head()) - self.more.add(base) + self.generate(head.itf.base.get_head()) + self.more.add(head.itf.base) basename = self._to_type_name(head.itf.base) self.imports.add("comtypes", "GUID") diff --git a/comtypes/tools/typedesc.py b/comtypes/tools/typedesc.py index 3085f0ea..8bf167cc 100644 --- a/comtypes/tools/typedesc.py +++ b/comtypes/tools/typedesc.py @@ -176,7 +176,7 @@ def __init__( self, name: str, members: List[ComMethod], - base: Any, + base: "Optional[ComInterface]", iid: str, idlflags: List[str], ) -> None: From 3731c94067e59bd1bd49f33c59c2d295b887b58b Mon Sep 17 00:00:00 2001 From: Jun Komoda <45822440+junkmd@users.noreply.github.com> Date: Tue, 30 Apr 2024 11:29:22 +0000 Subject: [PATCH 2/3] add conditional branching for `DispInterfaceHead`, `DispInterfaceBody` and `DispInterface` --- comtypes/tools/codegenerator.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index 2b58b1b6..a2981cff 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -467,6 +467,32 @@ def generate(self, item): self.done.add(item) # to avoid infinite recursion. self.ComInterface(item) return + if isinstance(item, typedesc.DispInterfaceHead): + if self._is_known_interface(item.itf): + self.imports.add(item.itf.name, symbols=self.known_symbols) + self.done.add(item) + return + self.done.add(item) + self.DispInterfaceHead(item) + return + if isinstance(item, typedesc.DispInterfaceBody): + if self._is_known_interface(item.itf): + self.imports.add(item.itf.name, symbols=self.known_symbols) + self.done.add(item) + return + self.done.add(item) + self.DispInterfaceBody(item) + return + if isinstance(item, typedesc.DispInterface): + if self._is_known_interface(item): + self.imports.add(item.name, symbols=self.known_symbols) + self.done.add(item) + self.done.add(item.get_head()) + self.done.add(item.get_body()) + return + self.done.add(item) # to avoid infinite recursion. + self.DispInterface(item) + return if isinstance(item, typedesc.StructureHead): name = getattr(item.struct, "name", None) else: @@ -1103,7 +1129,9 @@ def ComInterface(self, itf: typedesc.ComInterface) -> None: self.generate(itf.get_body()) self.names.add(itf.name) - def _is_known_interface(self, item: typedesc.ComInterface) -> bool: + def _is_known_interface( + self, item: _UnionT[typedesc.ComInterface, typedesc.DispInterface] + ) -> bool: """Returns whether an interface is statically defined in `comtypes`, based on its name and iid. """ From a8d6eb2d07f1f5254e9cbb861964ce2e5705770b Mon Sep 17 00:00:00 2001 From: Jun Komoda <45822440+junkmd@users.noreply.github.com> Date: Tue, 30 Apr 2024 11:29:22 +0000 Subject: [PATCH 3/3] make defining interfaces DRY and add TypeGuard --- comtypes/hints.pyi | 2 + comtypes/tools/codegenerator.py | 105 ++++++++++++++++---------------- 2 files changed, 56 insertions(+), 51 deletions(-) diff --git a/comtypes/hints.pyi b/comtypes/hints.pyi index 30fb4906..7d7e5689 100644 --- a/comtypes/hints.pyi +++ b/comtypes/hints.pyi @@ -23,8 +23,10 @@ else: from typing_extensions import Protocol if sys.version_info >= (3, 10): from typing import Concatenate, ParamSpec, TypeAlias + from typing import TypeGuard as TypeGuard else: from typing_extensions import Concatenate, ParamSpec, TypeAlias + from typing_extensions import TypeGuard as TypeGuard if sys.version_info >= (3, 11): from typing import Self else: diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index a2981cff..4dca21e2 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -421,6 +421,16 @@ def _get_common_elms(self) -> Tuple[List[_IdlFlagType], str, str]: return (idlflags, type_name, self._m.name) +_InterfaceTypeDesc = _UnionT[ + typedesc.ComInterface, + typedesc.ComInterfaceHead, + typedesc.ComInterfaceBody, + typedesc.DispInterface, + typedesc.DispInterfaceHead, + typedesc.DispInterfaceBody, +] + + class CodeGenerator(object): def __init__(self, known_symbols=None, known_interfaces=None) -> None: self.stream = io.StringIO() @@ -441,57 +451,8 @@ def __init__(self, known_symbols=None, known_interfaces=None) -> None: def generate(self, item): if item in self.done: return - if isinstance(item, typedesc.ComInterfaceHead): - if self._is_known_interface(item.itf): - self.imports.add(item.itf.name, symbols=self.known_symbols) - self.done.add(item) - return - self.done.add(item) - self.ComInterfaceHead(item) - return - if isinstance(item, typedesc.ComInterfaceBody): - if self._is_known_interface(item.itf): - self.imports.add(item.itf.name, symbols=self.known_symbols) - self.done.add(item) - return - self.done.add(item) - self.ComInterfaceBody(item) - return - if isinstance(item, typedesc.ComInterface): - if self._is_known_interface(item): - self.imports.add(item.name, symbols=self.known_symbols) - self.done.add(item) - self.done.add(item.get_head()) - self.done.add(item.get_body()) - return - self.done.add(item) # to avoid infinite recursion. - self.ComInterface(item) - return - if isinstance(item, typedesc.DispInterfaceHead): - if self._is_known_interface(item.itf): - self.imports.add(item.itf.name, symbols=self.known_symbols) - self.done.add(item) - return - self.done.add(item) - self.DispInterfaceHead(item) - return - if isinstance(item, typedesc.DispInterfaceBody): - if self._is_known_interface(item.itf): - self.imports.add(item.itf.name, symbols=self.known_symbols) - self.done.add(item) - return - self.done.add(item) - self.DispInterfaceBody(item) - return - if isinstance(item, typedesc.DispInterface): - if self._is_known_interface(item): - self.imports.add(item.name, symbols=self.known_symbols) - self.done.add(item) - self.done.add(item.get_head()) - self.done.add(item.get_body()) - return - self.done.add(item) # to avoid infinite recursion. - self.DispInterface(item) + if self._is_interface_typedesc(item): + self._define_interface(item) return if isinstance(item, typedesc.StructureHead): name = getattr(item.struct, "name", None) @@ -1124,6 +1085,48 @@ def CoClass(self, coclass: typedesc.CoClass) -> None: self.names.add(coclass.name) + def _is_interface_typedesc( + self, item: Any + ) -> "comtypes.hints.TypeGuard[_InterfaceTypeDesc]": + return isinstance( + item, + ( + typedesc.ComInterface, + typedesc.ComInterfaceHead, + typedesc.ComInterfaceBody, + typedesc.DispInterface, + typedesc.DispInterfaceHead, + typedesc.DispInterfaceBody, + ), + ) + + def _define_interface(self, item: _InterfaceTypeDesc) -> None: + if isinstance( + item, + ( + typedesc.ComInterfaceHead, + typedesc.ComInterfaceBody, + typedesc.DispInterfaceHead, + typedesc.DispInterfaceBody, + ), + ): + if self._is_known_interface(item.itf): + self.imports.add(item.itf.name, symbols=self.known_symbols) + self.done.add(item) + return + elif isinstance(item, (typedesc.ComInterface, typedesc.DispInterface)): + if self._is_known_interface(item): + self.imports.add(item.name, symbols=self.known_symbols) + self.done.add(item) + self.done.add(item.get_head()) + self.done.add(item.get_body()) + return + else: + raise TypeError + self.done.add(item) # to avoid infinite recursion. + mth = getattr(self, type(item).__name__) + mth(item) + def ComInterface(self, itf: typedesc.ComInterface) -> None: self.generate(itf.get_head()) self.generate(itf.get_body())