diff --git a/comtypes/client/_events.py b/comtypes/client/_events.py index a2f5397b..80f9a3b9 100644 --- a/comtypes/client/_events.py +++ b/comtypes/client/_events.py @@ -12,12 +12,14 @@ LPVOID, ULONG, ) +from typing import Any, Callable, Optional, Protocol, Type, TypeVar import comtypes import comtypes.automation import comtypes.connectionpoints import comtypes.hresult import comtypes.typeinfo +from comtypes import COMObject, IUnknown from comtypes.client._generate import GetModule logger = logging.getLogger(__name__) @@ -58,21 +60,38 @@ class SECURITY_ATTRIBUTES(Structure): _CloseHandle.restype = BOOL +_T_IUnknown = TypeVar("_T_IUnknown", bound=IUnknown) + + +class _SupportsQueryInterface(Protocol): + def QueryInterface(self, interface: Type[_T_IUnknown]) -> _T_IUnknown: ... + + class _AdviseConnection(object): - def __init__(self, source, interface, receiver): + def __init__( + self, + source: IUnknown, + interface: Type[IUnknown], + receiver: _SupportsQueryInterface, + ) -> None: self.cp = None self.cookie = None self.receiver = None self._connect(source, interface, receiver) - def _connect(self, source, interface, receiver): + def _connect( + self, + source: IUnknown, + interface: Type[IUnknown], + receiver: _SupportsQueryInterface, + ) -> None: cpc = source.QueryInterface(comtypes.connectionpoints.IConnectionPointContainer) self.cp = cpc.FindConnectionPoint(ctypes.byref(interface._iid_)) logger.debug("Start advise %s", interface) self.cookie = self.cp.Advise(receiver) self.receiver = receiver - def disconnect(self): + def disconnect(self) -> None: if self.cookie: self.cp.Unadvise(self.cookie) logger.debug("Unadvised %s", self.cp) @@ -80,7 +99,7 @@ def disconnect(self): self.cookie = None del self.receiver - def __del__(self): + def __del__(self) -> None: try: if self.cookie is not None: self.cp.Unadvise(self.cookie) @@ -89,7 +108,7 @@ def __del__(self): pass -def FindOutgoingInterface(source): +def FindOutgoingInterface(source: IUnknown) -> Type[IUnknown]: """XXX Describe the strategy that is used...""" # If the COM object implements IProvideClassInfo2, it is easy to # find the default outgoing interface. @@ -129,7 +148,7 @@ def FindOutgoingInterface(source): raise TypeError("cannot determine source interface") -def find_single_connection_interface(source): +def find_single_connection_interface(source: IUnknown) -> Optional[Type[IUnknown]]: # Enumerate the connection interfaces. If we find a single one, # return it, if there are more, we give up since we cannot # determine which one to use. @@ -187,11 +206,11 @@ class _SinkMethodFinder(_MethodFinder): event handlers. """ - def __init__(self, inst, sink): + def __init__(self, inst: COMObject, sink: Any) -> None: super(_SinkMethodFinder, self).__init__(inst) self.sink = sink - def find_method(self, fq_name, mthname): + def find_method(self, fq_name: str, mthname: str) -> Callable[..., Any]: impl = self._find_method(fq_name, mthname) # Caller of this method catches AttributeError, # so we need to be careful in the following code @@ -206,7 +225,7 @@ def find_method(self, fq_name, mthname): except AttributeError as details: raise RuntimeError(details) - def _find_method(self, fq_name, mthname): + def _find_method(self, fq_name: str, mthname: str) -> Callable[..., Any]: try: return super(_SinkMethodFinder, self).find_method(fq_name, mthname) except AttributeError: @@ -216,7 +235,7 @@ def _find_method(self, fq_name, mthname): return getattr(self.sink, mthname) -def CreateEventReceiver(interface, handler): +def CreateEventReceiver(interface: Type[IUnknown], handler: Any) -> COMObject: class Sink(comtypes.COMObject): _com_interfaces_ = [interface] diff --git a/comtypes/connectionpoints.py b/comtypes/connectionpoints.py index b22d1414..35cee4d9 100644 --- a/comtypes/connectionpoints.py +++ b/comtypes/connectionpoints.py @@ -1,7 +1,15 @@ -import sys -from ctypes import * +from ctypes import POINTER, Structure, c_ulong +from typing import TYPE_CHECKING, Tuple +from typing import Union as _UnionT -from comtypes import COMMETHOD, GUID, HRESULT, IUnknown, dispid +from comtypes import COMMETHOD, GUID, HRESULT, IUnknown + +if TYPE_CHECKING: + from ctypes import _CArgObject, _Pointer + + from comtypes import hints # noqa # type: ignore + + REFIID = _UnionT[_Pointer[GUID], _CArgObject] _GUID = GUID @@ -22,16 +30,35 @@ class IConnectionPointContainer(IUnknown): _iid_ = GUID("{B196B284-BAB4-101A-B69C-00AA00341D07}") _idlflags_ = [] + if TYPE_CHECKING: + + def EnumConnectionPoints(self) -> "IEnumConnectionPoints": ... + def FindConnectionPoint(self, riid: REFIID) -> "IConnectionPoint": ... + class IConnectionPoint(IUnknown): _iid_ = GUID("{B196B286-BAB4-101A-B69C-00AA00341D07}") _idlflags_ = [] + if TYPE_CHECKING: + + def GetConnectionPointContainer(self) -> IConnectionPointContainer: ... + def Advise(self, pUnkSink: IUnknown) -> int: ... + def Unadvise(self, dwCookie: int) -> hints.Hresult: ... + def EnumConnections(self) -> "IEnumConnections": ... + class IEnumConnections(IUnknown): _iid_ = GUID("{B196B287-BAB4-101A-B69C-00AA00341D07}") _idlflags_ = [] + if TYPE_CHECKING: + + def Next(self, cConnections: int) -> Tuple[tagCONNECTDATA, int]: ... + def Skip(self, cConnections: int) -> hints.Hresult: ... + def Reset(self) -> hints.Hresult: ... + def Clone(self) -> "IEnumConnections": ... + def __iter__(self): return self @@ -46,6 +73,13 @@ class IEnumConnectionPoints(IUnknown): _iid_ = GUID("{B196B285-BAB4-101A-B69C-00AA00341D07}") _idlflags_ = [] + if TYPE_CHECKING: + + def Next(self, cConnections: int) -> Tuple[IConnectionPoint, int]: ... + def Skip(self, cConnections: int) -> hints.Hresult: ... + def Reset(self) -> hints.Hresult: ... + def Clone(self) -> "IEnumConnectionPoints": ... + def __iter__(self): return self