diff --git a/comtypes/_comobject.py b/comtypes/_comobject.py index 74d0dbf7..e808b644 100644 --- a/comtypes/_comobject.py +++ b/comtypes/_comobject.py @@ -4,14 +4,16 @@ from ctypes import ( POINTER, FormatError, + OleDLL, Structure, + WinDLL, byref, c_long, c_void_p, - oledll, + c_wchar_p, pointer, - windll, ) +from ctypes.wintypes import INT, LONG, LPVOID, UINT, ULONG, WORD from typing import ( TYPE_CHECKING, Any, @@ -29,6 +31,7 @@ from comtypes import GUID, IPersist, IUnknown, hresult from comtypes._vtbl import _MethodFinder, create_dispimpl, create_vtbl_mapping +from comtypes.automation import DISPID, DISPPARAMS, EXCEPINFO, VARIANT from comtypes.errorinfo import ISupportErrorInfo from comtypes.typeinfo import IProvideClassInfo, IProvideClassInfo2, ITypeInfo @@ -51,9 +54,10 @@ ################################################################ +_kernel32 = WinDLL("kernel32") try: - _InterlockedIncrement = windll.kernel32.InterlockedIncrement - _InterlockedDecrement = windll.kernel32.InterlockedDecrement + _InterlockedIncrement = _kernel32.InterlockedIncrement + _InterlockedDecrement = _kernel32.InterlockedDecrement except AttributeError: import threading @@ -82,14 +86,69 @@ def _InterlockedDecrement(ob: c_long) -> int: _InterlockedIncrement.restype = c_long _InterlockedDecrement.restype = c_long +_oleaut32 = WinDLL("oleaut32") + +_DispGetIDsOfNames = _oleaut32.DispGetIDsOfNames +_DispGetIDsOfNames.argtypes = [ + POINTER(ITypeInfo), + POINTER(c_wchar_p), + UINT, + POINTER(DISPID), +] +_DispGetIDsOfNames.restype = ( + LONG # technically, it is a HRESULT, but we want to avoid the OSError +) + +_DispInvoke = _oleaut32.DispInvoke +_DispInvoke.argtypes = [ + LPVOID, + POINTER(ITypeInfo), + DISPID, + WORD, + POINTER(DISPPARAMS), + POINTER(VARIANT), + POINTER(EXCEPINFO), + POINTER(UINT), +] +_DispInvoke.restype = ( + LONG # technically, it is a HRESULT, but we want to avoid the OSError +) + + +_ole32_nohresult = WinDLL("ole32") +_ole32 = OleDLL("ole32") + +_CoInitialize = _ole32_nohresult.CoInitialize +_CoInitialize.argtypes = [LPVOID] +_CoInitialize.restype = ( + LONG # technically, it is a HRESULT, but we want to avoid the OSError +) + +_CoUninitialize = _ole32_nohresult.CoUninitialize +_CoUninitialize.argtypes = [] +_CoUninitialize.restype = None + +_CoAddRefServerProcess = _ole32.CoAddRefServerProcess +_CoAddRefServerProcess.argtypes = [] +_CoAddRefServerProcess.restype = ULONG + +_CoReleaseServerProcess = _ole32.CoReleaseServerProcess +_CoReleaseServerProcess.argtypes = [] +_CoReleaseServerProcess.restype = ULONG + + +_user32 = WinDLL("user32") + +_PostQuitMessage = _user32.PostQuitMessage +_PostQuitMessage.argtypes = [INT] +_PostQuitMessage.restype = None + class LocalServer(object): _queue: Optional[queue.Queue] = None def run(self, classobjects: Sequence["hints.localserver.ClassFactory"]) -> None: - # Use windll instead of oledll so that we don't get an - # exception on a FAILED hresult: - hr = windll.ole32.CoInitialize(None) + hr = _CoInitialize(None) if hresult.RPC_E_CHANGED_MODE == hr: # we're running in MTA: no message pump needed _debug("Server running in MTA") @@ -100,7 +159,7 @@ def run(self, classobjects: Sequence["hints.localserver.ClassFactory"]) -> None: if hr >= 0: # we need a matching CoUninitialize() call for a successful # CoInitialize(). - windll.ole32.CoUninitialize() + _CoUninitialize() self.run_sta() for obj in classobjects: @@ -116,15 +175,15 @@ def run_mta(self) -> None: self._queue.get() def Lock(self) -> None: - oledll.ole32.CoAddRefServerProcess() + _CoAddRefServerProcess() def Unlock(self) -> None: - rc = oledll.ole32.CoReleaseServerProcess() + rc = _CoReleaseServerProcess() if rc == 0: if self._queue: self._queue.put(42) else: - windll.user32.PostQuitMessage(0) + _PostQuitMessage(0) class InprocServer(object): @@ -401,7 +460,7 @@ def IDispatch_GetIDsOfNames(self, this, riid, rgszNames, cNames, lcid, rgDispId) tinfo = self.__typeinfo except AttributeError: return hresult.E_NOTIMPL - return windll.oleaut32.DispGetIDsOfNames(tinfo, rgszNames, cNames, rgDispId) + return _DispGetIDsOfNames(tinfo, rgszNames, cNames, rgDispId) def IDispatch_Invoke( self, @@ -432,7 +491,7 @@ def IDispatch_Invoke( # an error. interface = self._com_interfaces_[0] ptr = self._com_pointers_[interface._iid_] - return windll.oleaut32.DispInvoke( + return _DispInvoke( ptr, tinfo, dispIdMember,