Skip to content

Commit

Permalink
Turn RegistryEntries into an ABC and create subclasses for frozen…
Browse files Browse the repository at this point in the history
… and non-frozen cases. (#738)

* `__init__` -> `__new__`

* Make `RegistryEntries` a dynamic subtype resolving class.

* Small fix.

* Remove `_iter_ctx_entries`.

* Make `RegistryEntries` abstract class.

* Reduce conditional branches for retrieving server dll.
  • Loading branch information
junkmd authored Jan 14, 2025
1 parent 5e608d4 commit a1a19b7
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 86 deletions.
109 changes: 54 additions & 55 deletions comtypes/server/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,18 @@
"""

import _ctypes
import abc
import logging
import os
import sys
import winreg
from ctypes import WinDLL, WinError
from ctypes.wintypes import HKEY, LONG, LPCWSTR
from typing import Iterator, List, Optional, Tuple, Type, Union
from typing import Iterable, Iterator, List, Optional, Tuple, Type, Union

import comtypes.server.inprocserver # noqa
from comtypes import CLSCTX_INPROC_SERVER, CLSCTX_LOCAL_SERVER
from comtypes.hresult import TYPE_E_CANTLOADLIBRARY, TYPE_E_REGISTRYACCESS
from comtypes.server.inprocserver import _clsid_to_class
from comtypes.server.localserver import run as run_localserver
from comtypes.server.w_getopt import w_getopt
from comtypes.typeinfo import (
Expand Down Expand Up @@ -113,6 +114,11 @@ def __init__(self) -> None:
self._frozen = getattr(sys, "frozen", None)
self._frozendllhandle = getattr(sys, "frozendllhandle", None)

def _generate_reg_entries(self, cls: Type) -> Iterable[_Entry]:
if self._frozen is None:
return InterpRegistryEntries(cls)
return FrozenRegistryEntries(cls, self._frozen, self._frozendllhandle)

def nodebug(self, cls: Type) -> None:
"""Delete logging entries from the registry."""
clsid = cls._reg_clsid_
Expand Down Expand Up @@ -168,13 +174,7 @@ def register(self, cls: Type, executable: Optional[str] = None) -> None:
self._register(cls, executable)

def _register(self, cls: Type, executable: Optional[str] = None) -> None:
table = sorted(
RegistryEntries(
cls,
frozen=self._frozen,
frozendllhandle=self._frozendllhandle,
)
)
table = sorted(self._generate_reg_entries(cls))
_debug("Registering %s", cls)
for hkey, subkey, valuename, value in table:
_debug("[%s\\%s]", _explain(hkey), subkey)
Expand Down Expand Up @@ -210,12 +210,7 @@ def unregister(self, cls: Type, force: bool = False) -> None:
def _unregister(self, cls: Type, force: bool = False) -> None:
# If force==False, we only remove those entries that we
# actually would have written. It seems ATL does the same.
table = [
t[:2]
for t in RegistryEntries(
cls, frozen=self._frozen, frozendllhandle=self._frozendllhandle
)
]
table = [t[:2] for t in self._generate_reg_entries(cls)]
# only unique entries
table = list(set(table))
table.sort()
Expand Down Expand Up @@ -246,14 +241,13 @@ def _unregister(self, cls: Type, force: bool = False) -> None:
_debug("Done")


def _get_serverdll(handle: Optional[int]) -> str:
def _get_serverdll(handle: int) -> str:
"""Return the pathname of the dll hosting the COM object."""
if handle is not None:
return GetModuleFileName(handle, 260)
return _ctypes.__file__
assert isinstance(handle, int)
return GetModuleFileName(handle, 260)


class RegistryEntries(object):
class RegistryEntries(abc.ABC):
"""Iterator of tuples containing registry entries.
The tuples must be (key, subkey, name, value).
Expand All @@ -275,11 +269,15 @@ class RegistryEntries(object):
IDL library name of the type library containing the coclass.
"""

@abc.abstractmethod
def __iter__(self) -> Iterator[_Entry]: ...


class FrozenRegistryEntries(RegistryEntries):
def __init__(
self,
cls: Type,
*,
frozen: Optional[str] = None,
frozen: str,
frozendllhandle: Optional[int] = None,
) -> None:
self._cls = cls
Expand All @@ -290,9 +288,38 @@ def __iter__(self) -> Iterator[_Entry]:
# that's the only required attribute for registration
reg_clsid = str(self._cls._reg_clsid_)
yield from _iter_reg_entries(self._cls, reg_clsid)
yield from _iter_ctx_entries(
self._cls, reg_clsid, self._frozen, self._frozendllhandle
)
clsctx: int = getattr(self._cls, "_reg_clsctx_", 0)
localsvr_ctx = bool(clsctx & CLSCTX_LOCAL_SERVER)
inprocsvr_ctx = bool(clsctx & CLSCTX_INPROC_SERVER)
if localsvr_ctx and self._frozendllhandle is None:
yield from _iter_frozen_local_ctx_entries(self._cls, reg_clsid)
if inprocsvr_ctx and self._frozen == "dll":
assert self._frozendllhandle is not None
frozen_dll = _get_serverdll(self._frozendllhandle)
yield from _iter_inproc_ctx_entries(reg_clsid, frozen_dll)
yield from _iter_inproc_threading_model_entries(self._cls, reg_clsid)
yield from _iter_tlib_entries(self._cls, reg_clsid)


class InterpRegistryEntries(RegistryEntries):
def __init__(self, cls: Type) -> None:
self._cls = cls

def __iter__(self) -> Iterator[_Entry]:
# that's the only required attribute for registration
reg_clsid = str(self._cls._reg_clsid_)
yield from _iter_reg_entries(self._cls, reg_clsid)
clsctx: int = getattr(self._cls, "_reg_clsctx_", 0)
localsvr_ctx = bool(clsctx & CLSCTX_LOCAL_SERVER)
inprocsvr_ctx = bool(clsctx & CLSCTX_INPROC_SERVER)
if localsvr_ctx:
yield from _iter_interp_local_ctx_entries(self._cls, reg_clsid)
if inprocsvr_ctx:
yield from _iter_inproc_ctx_entries(reg_clsid, _ctypes.__file__)
# only for non-frozen inproc servers the PythonPath/PythonClass is needed.
yield from _iter_inproc_python_entries(self._cls, reg_clsid)
yield from _iter_inproc_threading_model_entries(self._cls, reg_clsid)
yield from _iter_tlib_entries(self._cls, reg_clsid)


def _get_full_classname(cls: Type) -> str:
Expand Down Expand Up @@ -347,27 +374,6 @@ def _iter_reg_entries(cls: Type, reg_clsid: str) -> Iterator[_Entry]:
yield (HKCR, f"{reg_novers_progid}\\CLSID", "", reg_clsid) # 3a


def _iter_ctx_entries(
cls: Type, reg_clsid: str, frozen: Optional[str], frozendllhandle: Optional[int]
) -> Iterator[_Entry]:
clsctx: int = getattr(cls, "_reg_clsctx_", 0)
localsvr_ctx = bool(clsctx & CLSCTX_LOCAL_SERVER)
inprocsvr_ctx = bool(clsctx & CLSCTX_INPROC_SERVER)

if localsvr_ctx and frozendllhandle is None:
if frozen is None:
yield from _iter_interp_local_ctx_entries(cls, reg_clsid)
else:
yield from _iter_frozen_local_ctx_entries(cls, reg_clsid)
if inprocsvr_ctx and frozen in (None, "dll"):
yield from _iter_inproc_ctx_entries(cls, reg_clsid, frozendllhandle)
# only for non-frozen inproc servers the PythonPath/PythonClass is needed.
if frozendllhandle is None or not _clsid_to_class:
yield from _iter_inproc_python_entries(cls, reg_clsid)
yield from _iter_inproc_threading_model_entries(cls, reg_clsid)
yield from _iter_tlib_entries(cls, reg_clsid)


def _iter_interp_local_ctx_entries(cls: Type, reg_clsid: str) -> Iterator[_Entry]:
exe = sys.executable
exe = f'"{exe}"' if " " in exe else exe
Expand All @@ -383,17 +389,10 @@ def _iter_frozen_local_ctx_entries(cls: Type, reg_clsid: str) -> Iterator[_Entry
yield (HKCR, rf"CLSID\{reg_clsid}\LocalServer32", "", f"{exe}")


def _iter_inproc_ctx_entries(
cls: Type, reg_clsid: str, frozendllhandle: Optional[int]
) -> Iterator[_Entry]:
def _iter_inproc_ctx_entries(reg_clsid: str, dllfile: str) -> Iterator[_Entry]:
# Register InprocServer32 only when run from script or from
# py2exe dll server, not from py2exe exe server.
yield (
HKCR,
rf"CLSID\{reg_clsid}\InprocServer32",
"",
_get_serverdll(frozendllhandle),
)
yield (HKCR, rf"CLSID\{reg_clsid}\InprocServer32", "", dllfile)


def _iter_inproc_python_entries(cls: Type, reg_clsid: str) -> Iterator[_Entry]:
Expand Down
59 changes: 28 additions & 31 deletions comtypes/test/test_server_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
import comtypes.server.inprocserver
from comtypes import GUID
from comtypes.server import register
from comtypes.server.register import Registrar, RegistryEntries, _get_serverdll
from comtypes.server.register import (
FrozenRegistryEntries,
InterpRegistryEntries,
Registrar,
_get_serverdll,
)

HKCR = winreg.HKEY_CLASSES_ROOT
MULTI_SZ = winreg.REG_MULTI_SZ
Expand Down Expand Up @@ -190,9 +195,6 @@ def test_calls_cls_unregister(self):


class Test_get_serverdll(ut.TestCase):
def test_nonfrozen(self):
self.assertEqual(_ctypes.__file__, _get_serverdll(None))

@mock.patch.object(register, "GetModuleFileName")
def test_frozen(self, GetModuleFileName):
handle, dll_path = 1234, r"path\to\frozen.dll"
Expand All @@ -203,15 +205,15 @@ def test_frozen(self, GetModuleFileName):
self.assertEqual(260, maxsize)


class Test_NonFrozen_RegistryEntries(ut.TestCase):
class Test_InterpRegistryEntries(ut.TestCase):
def test_reg_clsid(self):
reg_clsid = GUID.create_new()

class Cls:
_reg_clsid_ = reg_clsid

expected = [(HKCR, rf"CLSID\{reg_clsid}", "", "")]
self.assertEqual(expected, list(RegistryEntries(Cls)))
self.assertEqual(expected, list(InterpRegistryEntries(Cls)))

def test_reg_desc(self):
reg_clsid = GUID.create_new()
Expand All @@ -222,7 +224,7 @@ class Cls:
_reg_desc_ = reg_desc

expected = [(HKCR, rf"CLSID\{reg_clsid}", "", reg_desc)]
self.assertEqual(expected, list(RegistryEntries(Cls)))
self.assertEqual(expected, list(InterpRegistryEntries(Cls)))

def test_reg_novers_progid(self):
reg_clsid = GUID.create_new()
Expand All @@ -233,7 +235,7 @@ class Cls:
_reg_novers_progid_ = reg_novers_progid

expected = [(HKCR, rf"CLSID\{reg_clsid}", "", "Lib Server")]
self.assertEqual(expected, list(RegistryEntries(Cls)))
self.assertEqual(expected, list(InterpRegistryEntries(Cls)))

def test_progid(self):
reg_clsid = GUID.create_new()
Expand All @@ -249,7 +251,7 @@ class Cls:
(HKCR, reg_progid, "", "Lib Server 1"),
(HKCR, rf"{reg_progid}\CLSID", "", str(reg_clsid)),
]
self.assertEqual(expected, list(RegistryEntries(Cls)))
self.assertEqual(expected, list(InterpRegistryEntries(Cls)))

def test_reg_progid_reg_desc(self):
reg_clsid = GUID.create_new()
Expand All @@ -267,7 +269,7 @@ class Cls:
(HKCR, reg_progid, "", "description for testing"),
(HKCR, rf"{reg_progid}\CLSID", "", str(reg_clsid)),
]
self.assertEqual(expected, list(RegistryEntries(Cls)))
self.assertEqual(expected, list(InterpRegistryEntries(Cls)))

def test_reg_progid_reg_novers_progid(self):
reg_clsid = GUID.create_new()
Expand All @@ -290,7 +292,7 @@ class Cls:
(HKCR, rf"{reg_novers_progid}\CurVer", "", "Lib.Server.1"),
(HKCR, rf"{reg_novers_progid}\CLSID", "", str(reg_clsid)),
]
self.assertEqual(expected, list(RegistryEntries(Cls)))
self.assertEqual(expected, list(InterpRegistryEntries(Cls)))

def test_local_server(self):
reg_clsid = GUID.create_new()
Expand All @@ -306,7 +308,7 @@ class Cls:
(HKCR, clsid_sub, "", ""),
(HKCR, local_srv_sub, "", f"{sys.executable} {__file__}"),
]
self.assertEqual(expected, list(RegistryEntries(Cls)))
self.assertEqual(expected, list(InterpRegistryEntries(Cls)))

def test_inproc_server(self):
reg_clsid = GUID.create_new()
Expand All @@ -325,7 +327,7 @@ class Cls:
(HKCR, inproc_srv_sub, "PythonClass", full_classname),
(HKCR, inproc_srv_sub, "PythonPath", os.path.dirname(__file__)),
]
self.assertEqual(expected, list(RegistryEntries(Cls)))
self.assertEqual(expected, list(InterpRegistryEntries(Cls)))

def test_inproc_server_reg_threading(self):
reg_clsid = GUID.create_new()
Expand All @@ -347,7 +349,7 @@ class Cls:
(HKCR, inproc_srv_sub, "PythonPath", os.path.dirname(__file__)),
(HKCR, inproc_srv_sub, "ThreadingModel", reg_threading),
]
self.assertEqual(expected, list(RegistryEntries(Cls)))
self.assertEqual(expected, list(InterpRegistryEntries(Cls)))

def test_reg_typelib(self):
reg_clsid = GUID.create_new()
Expand All @@ -362,7 +364,7 @@ class Cls:
(HKCR, rf"CLSID\{reg_clsid}", "", ""),
(HKCR, rf"CLSID\{reg_clsid}\Typelib", "", libid),
]
self.assertEqual(expected, list(RegistryEntries(Cls)))
self.assertEqual(expected, list(InterpRegistryEntries(Cls)))

def test_all_entries(self):
reg_clsid = GUID.create_new()
Expand Down Expand Up @@ -403,10 +405,10 @@ class Cls:
(HKCR, inproc_srv_sub, "ThreadingModel", reg_threading),
(HKCR, rf"{clsid_sub}\Typelib", "", libid),
]
self.assertEqual(expected, list(RegistryEntries(Cls)))
self.assertEqual(expected, list(InterpRegistryEntries(Cls)))


class Test_Frozen_RegistryEntries(ut.TestCase):
class Test_FrozenRegistryEntries(ut.TestCase):
SERVERDLL = r"my\target\server.dll"

# We do not test the scenario where `frozen` is `'dll'` but
Expand All @@ -424,7 +426,7 @@ class Cls:
# In such cases, the server does not start because the
# InprocServer32/LocalServer32 keys are not registered.
expected = [(HKCR, rf"CLSID\{reg_clsid}", "", "")]
entries = RegistryEntries(Cls, frozen="dll", frozendllhandle=1234)
entries = FrozenRegistryEntries(Cls, frozen="dll", frozendllhandle=1234)
self.assertEqual(expected, list(entries))

def test_local_windows_exe(self):
Expand All @@ -439,10 +441,12 @@ class Cls:
(HKCR, rf"CLSID\{reg_clsid}", "", ""),
(HKCR, rf"CLSID\{reg_clsid}\LocalServer32", "", sys.executable),
]
self.assertEqual(expected, list(RegistryEntries(Cls, frozen="windows_exe")))
self.assertEqual(
expected, list(FrozenRegistryEntries(Cls, frozen="windows_exe"))
)

@mock.patch.object(register, "_get_serverdll", return_value=SERVERDLL)
def test_inproc_dll_nonempty_clsid_to_class(self, get_serverdll):
def test_inproc_dll(self, get_serverdll):
reg_clsid = GUID.create_new()
reg_clsctx = comtypes.CLSCTX_INPROC_SERVER

Expand All @@ -457,10 +461,8 @@ class Cls:
(HKCR, inproc_srv_sub, "", self.SERVERDLL),
]

with mock.patch.dict(comtypes.server.inprocserver._clsid_to_class):
comtypes.server.inprocserver._clsid_to_class.update({5678: Cls})
entries = RegistryEntries(Cls, frozen="dll", frozendllhandle=1234)
self.assertEqual(expected, list(entries))
entries = FrozenRegistryEntries(Cls, frozen="dll", frozendllhandle=1234)
self.assertEqual(expected, list(entries))
get_serverdll.assert_called_once_with(1234)

@mock.patch.object(register, "_get_serverdll", return_value=SERVERDLL)
Expand All @@ -476,18 +478,13 @@ class Cls:

clsid_sub = rf"CLSID\{reg_clsid}"
inproc_srv_sub = rf"{clsid_sub}\InprocServer32"
full_classname = f"{__name__}.Cls"
expected = [
(HKCR, clsid_sub, "", ""),
(HKCR, inproc_srv_sub, "", self.SERVERDLL),
# 'PythonClass' and 'PythonPath' are not required for
# frozen inproc servers. This may be bugs but they do
# not affect the server behavior.
(HKCR, inproc_srv_sub, "PythonClass", full_classname),
(HKCR, inproc_srv_sub, "PythonPath", os.path.dirname(__file__)),
(HKCR, inproc_srv_sub, "ThreadingModel", reg_threading),
]
self.assertEqual(
expected, list(RegistryEntries(Cls, frozen="dll", frozendllhandle=1234))
expected,
list(FrozenRegistryEntries(Cls, frozen="dll", frozendllhandle=1234)),
)
get_serverdll.assert_called_once_with(1234)

0 comments on commit a1a19b7

Please sign in to comment.