Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test_comobject(part 2). #710

Merged
merged 2 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions comtypes/_comobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from comtypes.typeinfo import IProvideClassInfo, IProvideClassInfo2

if TYPE_CHECKING:
from ctypes import _FuncPointer, _Pointer
from ctypes import _CArgObject, _FuncPointer, _Pointer

from comtypes import hints # type: ignore
from comtypes._memberspec import _ArgSpecElmType, _ParamFlagType
Expand Down Expand Up @@ -706,7 +706,11 @@ def IUnknown_Release(
return result

def IUnknown_QueryInterface(
self, this: Any, riid: "_Pointer[GUID]", ppvObj: c_void_p, _debug=_debug
self,
this: Any,
riid: "_Pointer[GUID]",
ppvObj: Union[c_void_p, "_CArgObject"],
_debug=_debug,
) -> int:
# XXX This is probably too slow.
# riid[0].hashcode() alone takes 33 us!
Expand All @@ -733,7 +737,7 @@ def QueryInterface(self, interface: Type[_T_IUnknown]) -> _T_IUnknown:
# CopyComPointer(src, dst) calls AddRef!
result = POINTER(interface)()
CopyComPointer(ptr, byref(result))
return result
return result # type: ignore

################################################################
# ISupportErrorInfo::InterfaceSupportsErrorInfo implementation
Expand Down
41 changes: 40 additions & 1 deletion comtypes/test/test_comobject.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import ctypes
import unittest as ut
from ctypes import POINTER, byref, pointer

import comtypes
import comtypes.client
from comtypes import IUnknown
from comtypes import IUnknown, hresult
from comtypes.automation import IDispatch

comtypes.client.GetModule("UIAutomationCore.dll")
comtypes.client.GetModule("scrrun.dll")
Expand Down Expand Up @@ -34,6 +37,42 @@ def test_dispatch_interface(self):
self.assertEqual(dic.GetIDsOfNames("Add"), [1])


class Test_IUnknown_QueryInterface(ut.TestCase):
def test_e_pointer(self):
hr = scrrun.Dictionary().IUnknown_QueryInterface(
None, pointer(scrrun.IDictionary._iid_), ctypes.c_void_p()
)
self.assertEqual(hr, hresult.E_POINTER)

def test_e_no_interface(self):
hr = scrrun.Dictionary().IUnknown_QueryInterface(
None, pointer(uiac.IUIAutomation._iid_), ctypes.c_void_p()
)
self.assertEqual(hr, hresult.E_NOINTERFACE)

def test_valid_pointer(self):
ptr = ctypes.c_void_p()
ctypes.oledll.ole32.CoCreateInstance(
byref(scrrun.Dictionary._reg_clsid_),
None,
comtypes.CLSCTX_SERVER,
byref(scrrun.IDictionary._iid_),
byref(ptr),
)
hr = scrrun.Dictionary().IUnknown_QueryInterface(
None, pointer(scrrun.IDictionary._iid_), ptr
)
self.assertEqual(hr, hresult.S_OK)

def test_valid_interface(self):
dic = POINTER(IDispatch)()
hr = scrrun.Dictionary().IUnknown_QueryInterface(
None, pointer(scrrun.IDictionary._iid_), byref(dic)
)
self.assertEqual(hr, hresult.S_OK)
self.assertEqual(dic.GetTypeInfoCount(), 1) # type: ignore


class Test_IPersist_GetClassID(ut.TestCase):
def test(self):
self.assertEqual(
Expand Down