Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement static import for ISequentialStream (#474) #578

Merged
merged 8 commits into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion comtypes/client/_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def _get_known_namespaces() -> Tuple[
Note:
The interfaces that should be included in `__known_symbols__` should be limited
to those that can be said to be bound to the design concept of COM, such as
`IUnknown`, and those defined in `objidl` and `oaidl`.
`IUnknown`, `IDispatch` and `ITypeInfo`.
`comtypes` does NOT aim to statically define all COM object interfaces in
its repository.
"""
Expand All @@ -272,6 +272,7 @@ def _get_known_namespaces() -> Tuple[
"comtypes.persist",
"comtypes.typeinfo",
"comtypes.automation",
"comtypes.stream",
"comtypes",
"ctypes.wintypes",
"ctypes",
Expand Down
65 changes: 65 additions & 0 deletions comtypes/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from ctypes import Array, c_ubyte, c_ulong, HRESULT, POINTER, pointer
from typing import Tuple, TYPE_CHECKING

from comtypes import COMMETHOD, GUID, IUnknown


class ISequentialStream(IUnknown):
"""Defines methods for the stream objects in sequence."""

_iid_ = GUID("{0C733A30-2A1C-11CE-ADE5-00AA0044773D}")
_idlflags_ = []

_methods_ = [
# Note that these functions are called `Read` and `Write` in Microsoft's documentation,
# see https://learn.microsoft.com/en-us/windows/win32/api/objidl/nn-objidl-isequentialstream.
# However, the comtypes code generation detects these as `RemoteRead` and `RemoteWrite`
# for very subtle reasons, see e.g. https://stackoverflow.com/q/19820999/. We will not
# rename these in this manual import for the sake of consistency.
COMMETHOD(
[],
HRESULT,
"RemoteRead",
# This call only works if `pv` is pre-allocated with `cb` bytes,
# which cannot be done by the high level function generated by metaclasses.
# Therefore, we override the high level function to implement this behaviour
# and then delegate the call the raw COM method.
(["out"], POINTER(c_ubyte), "pv"),
(["in"], c_ulong, "cb"),
(["out"], POINTER(c_ulong), "pcbRead"),
),
COMMETHOD(
[],
HRESULT,
"RemoteWrite",
(["in"], POINTER(c_ubyte), "pv"),
(["in"], c_ulong, "cb"),
(["out"], POINTER(c_ulong), "pcbWritten"),
),
]

def RemoteRead(self, cb: int) -> Tuple["Array[c_ubyte]", int]:
"""Reads a specified number of bytes from the stream object into memory
starting at the current seek pointer.
"""
# Behaves as if `pv` is pre-allocated with `cb` bytes by the high level func.
pv = (c_ubyte * cb)()
pcb_read = pointer(c_ulong(0))
self.__com_RemoteRead(pv, c_ulong(cb), pcb_read) # type: ignore
# return both `out` parameters
return pv, pcb_read.contents.value

if TYPE_CHECKING:

def RemoteWrite(self, pv: "Array[c_ubyte]", cb: int) -> int:
"""Writes a specified number of bytes into the stream object starting at
the current seek pointer.
"""
...


# fmt: off
__known_symbols__ = [
'ISequentialStream',
]
# fmt: on
11 changes: 11 additions & 0 deletions comtypes/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ def test_mscorlib(self):
# the `_Pointer` interface, rather than importing `_Pointer` from `ctypes`.
self.assertTrue(issubclass(mod._Pointer, comtypes.IUnknown))

def test_portabledeviceapi(self):
mod = comtypes.client.GetModule("portabledeviceapi.dll")
from comtypes.stream import ISequentialStream

self.assertTrue(issubclass(mod.IStream, ISequentialStream))

def test_no_replacing_Patch_namespace(self):
# NOTE: An object named `Patch` is defined in some dll.
# Depending on how the namespace is defined in the static module,
Expand Down Expand Up @@ -117,6 +123,11 @@ def test_symbols_in_comtypes(self):

self._doit(comtypes)

def test_symbols_in_comtypes_stream(self):
import comtypes.stream

self._doit(comtypes.stream)

def test_symbols_in_comtypes_automation(self):
import comtypes.automation

Expand Down
53 changes: 53 additions & 0 deletions comtypes/test/test_istream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import unittest as ut

from ctypes import POINTER, byref, c_bool, c_ubyte
import comtypes
import comtypes.client

comtypes.client.GetModule("portabledeviceapi.dll")
from comtypes.gen.PortableDeviceApiLib import IStream


def _create_stream() -> IStream:
# Create an IStream
stream = POINTER(IStream)() # type: ignore
comtypes._ole32.CreateStreamOnHGlobal(None, c_bool(True), byref(stream))
return stream # type: ignore


class Test_RemoteWrite(ut.TestCase):
def test_RemoteWrite(self):
stream = _create_stream()
test_data = "Some data".encode("utf-8")
pv = (c_ubyte * len(test_data)).from_buffer(bytearray(test_data))

written = stream.RemoteWrite(pv, len(test_data))

# Verification
self.assertEqual(written, len(test_data))


class Test_RemoteRead(ut.TestCase):
def test_RemoteRead(self):
stream = _create_stream()
test_data = "Some data".encode("utf-8")
pv = (c_ubyte * len(test_data)).from_buffer(bytearray(test_data))
stream.RemoteWrite(pv, len(test_data))

# Make sure the data actually gets written before trying to read back
stream.Commit(0)
# Move the stream back to the beginning
STREAM_SEEK_SET = 0
stream.RemoteSeek(0, STREAM_SEEK_SET)

buffer_size = 1024

read_buffer, data_read = stream.RemoteRead(buffer_size)

# Verification
self.assertEqual(data_read, len(test_data))
self.assertEqual(bytearray(read_buffer)[0:data_read], test_data)


if __name__ == "__main__":
ut.main()