From a1a19b7e455f228302c82bf4654357b7c2d5b31b Mon Sep 17 00:00:00 2001 From: Jun Komoda <45822440+junkmd@users.noreply.github.com> Date: Wed, 15 Jan 2025 08:54:15 +0900 Subject: [PATCH] Turn `RegistryEntries` into an `ABC` and create subclasses for frozen and non-frozen cases. (#738) * `__init__` -> `__new__` * Make `RegistryEntries` a dynamic subtype resolving class. * Small fix. * Remove `_iter_ctx_entries`. * Make `RegistryEntries` abstract class. * Reduce conditional branches for retrieving server dll. --- comtypes/server/register.py | 109 +++++++++++++------------- comtypes/test/test_server_register.py | 59 +++++++------- 2 files changed, 82 insertions(+), 86 deletions(-) diff --git a/comtypes/server/register.py b/comtypes/server/register.py index ed47b68a..ccffc289 100644 --- a/comtypes/server/register.py +++ b/comtypes/server/register.py @@ -37,17 +37,18 @@ """ import _ctypes +import abc import logging import os import sys import winreg from ctypes import WinDLL, WinError from ctypes.wintypes import HKEY, LONG, LPCWSTR -from typing import Iterator, List, Optional, Tuple, Type, Union +from typing import Iterable, Iterator, List, Optional, Tuple, Type, Union +import comtypes.server.inprocserver # noqa from comtypes import CLSCTX_INPROC_SERVER, CLSCTX_LOCAL_SERVER from comtypes.hresult import TYPE_E_CANTLOADLIBRARY, TYPE_E_REGISTRYACCESS -from comtypes.server.inprocserver import _clsid_to_class from comtypes.server.localserver import run as run_localserver from comtypes.server.w_getopt import w_getopt from comtypes.typeinfo import ( @@ -113,6 +114,11 @@ def __init__(self) -> None: self._frozen = getattr(sys, "frozen", None) self._frozendllhandle = getattr(sys, "frozendllhandle", None) + def _generate_reg_entries(self, cls: Type) -> Iterable[_Entry]: + if self._frozen is None: + return InterpRegistryEntries(cls) + return FrozenRegistryEntries(cls, self._frozen, self._frozendllhandle) + def nodebug(self, cls: Type) -> None: """Delete logging entries from the registry.""" clsid = cls._reg_clsid_ @@ -168,13 +174,7 @@ def register(self, cls: Type, executable: Optional[str] = None) -> None: self._register(cls, executable) def _register(self, cls: Type, executable: Optional[str] = None) -> None: - table = sorted( - RegistryEntries( - cls, - frozen=self._frozen, - frozendllhandle=self._frozendllhandle, - ) - ) + table = sorted(self._generate_reg_entries(cls)) _debug("Registering %s", cls) for hkey, subkey, valuename, value in table: _debug("[%s\\%s]", _explain(hkey), subkey) @@ -210,12 +210,7 @@ def unregister(self, cls: Type, force: bool = False) -> None: def _unregister(self, cls: Type, force: bool = False) -> None: # If force==False, we only remove those entries that we # actually would have written. It seems ATL does the same. - table = [ - t[:2] - for t in RegistryEntries( - cls, frozen=self._frozen, frozendllhandle=self._frozendllhandle - ) - ] + table = [t[:2] for t in self._generate_reg_entries(cls)] # only unique entries table = list(set(table)) table.sort() @@ -246,14 +241,13 @@ def _unregister(self, cls: Type, force: bool = False) -> None: _debug("Done") -def _get_serverdll(handle: Optional[int]) -> str: +def _get_serverdll(handle: int) -> str: """Return the pathname of the dll hosting the COM object.""" - if handle is not None: - return GetModuleFileName(handle, 260) - return _ctypes.__file__ + assert isinstance(handle, int) + return GetModuleFileName(handle, 260) -class RegistryEntries(object): +class RegistryEntries(abc.ABC): """Iterator of tuples containing registry entries. The tuples must be (key, subkey, name, value). @@ -275,11 +269,15 @@ class RegistryEntries(object): IDL library name of the type library containing the coclass. """ + @abc.abstractmethod + def __iter__(self) -> Iterator[_Entry]: ... + + +class FrozenRegistryEntries(RegistryEntries): def __init__( self, cls: Type, - *, - frozen: Optional[str] = None, + frozen: str, frozendllhandle: Optional[int] = None, ) -> None: self._cls = cls @@ -290,9 +288,38 @@ def __iter__(self) -> Iterator[_Entry]: # that's the only required attribute for registration reg_clsid = str(self._cls._reg_clsid_) yield from _iter_reg_entries(self._cls, reg_clsid) - yield from _iter_ctx_entries( - self._cls, reg_clsid, self._frozen, self._frozendllhandle - ) + clsctx: int = getattr(self._cls, "_reg_clsctx_", 0) + localsvr_ctx = bool(clsctx & CLSCTX_LOCAL_SERVER) + inprocsvr_ctx = bool(clsctx & CLSCTX_INPROC_SERVER) + if localsvr_ctx and self._frozendllhandle is None: + yield from _iter_frozen_local_ctx_entries(self._cls, reg_clsid) + if inprocsvr_ctx and self._frozen == "dll": + assert self._frozendllhandle is not None + frozen_dll = _get_serverdll(self._frozendllhandle) + yield from _iter_inproc_ctx_entries(reg_clsid, frozen_dll) + yield from _iter_inproc_threading_model_entries(self._cls, reg_clsid) + yield from _iter_tlib_entries(self._cls, reg_clsid) + + +class InterpRegistryEntries(RegistryEntries): + def __init__(self, cls: Type) -> None: + self._cls = cls + + def __iter__(self) -> Iterator[_Entry]: + # that's the only required attribute for registration + reg_clsid = str(self._cls._reg_clsid_) + yield from _iter_reg_entries(self._cls, reg_clsid) + clsctx: int = getattr(self._cls, "_reg_clsctx_", 0) + localsvr_ctx = bool(clsctx & CLSCTX_LOCAL_SERVER) + inprocsvr_ctx = bool(clsctx & CLSCTX_INPROC_SERVER) + if localsvr_ctx: + yield from _iter_interp_local_ctx_entries(self._cls, reg_clsid) + if inprocsvr_ctx: + yield from _iter_inproc_ctx_entries(reg_clsid, _ctypes.__file__) + # only for non-frozen inproc servers the PythonPath/PythonClass is needed. + yield from _iter_inproc_python_entries(self._cls, reg_clsid) + yield from _iter_inproc_threading_model_entries(self._cls, reg_clsid) + yield from _iter_tlib_entries(self._cls, reg_clsid) def _get_full_classname(cls: Type) -> str: @@ -347,27 +374,6 @@ def _iter_reg_entries(cls: Type, reg_clsid: str) -> Iterator[_Entry]: yield (HKCR, f"{reg_novers_progid}\\CLSID", "", reg_clsid) # 3a -def _iter_ctx_entries( - cls: Type, reg_clsid: str, frozen: Optional[str], frozendllhandle: Optional[int] -) -> Iterator[_Entry]: - clsctx: int = getattr(cls, "_reg_clsctx_", 0) - localsvr_ctx = bool(clsctx & CLSCTX_LOCAL_SERVER) - inprocsvr_ctx = bool(clsctx & CLSCTX_INPROC_SERVER) - - if localsvr_ctx and frozendllhandle is None: - if frozen is None: - yield from _iter_interp_local_ctx_entries(cls, reg_clsid) - else: - yield from _iter_frozen_local_ctx_entries(cls, reg_clsid) - if inprocsvr_ctx and frozen in (None, "dll"): - yield from _iter_inproc_ctx_entries(cls, reg_clsid, frozendllhandle) - # only for non-frozen inproc servers the PythonPath/PythonClass is needed. - if frozendllhandle is None or not _clsid_to_class: - yield from _iter_inproc_python_entries(cls, reg_clsid) - yield from _iter_inproc_threading_model_entries(cls, reg_clsid) - yield from _iter_tlib_entries(cls, reg_clsid) - - def _iter_interp_local_ctx_entries(cls: Type, reg_clsid: str) -> Iterator[_Entry]: exe = sys.executable exe = f'"{exe}"' if " " in exe else exe @@ -383,17 +389,10 @@ def _iter_frozen_local_ctx_entries(cls: Type, reg_clsid: str) -> Iterator[_Entry yield (HKCR, rf"CLSID\{reg_clsid}\LocalServer32", "", f"{exe}") -def _iter_inproc_ctx_entries( - cls: Type, reg_clsid: str, frozendllhandle: Optional[int] -) -> Iterator[_Entry]: +def _iter_inproc_ctx_entries(reg_clsid: str, dllfile: str) -> Iterator[_Entry]: # Register InprocServer32 only when run from script or from # py2exe dll server, not from py2exe exe server. - yield ( - HKCR, - rf"CLSID\{reg_clsid}\InprocServer32", - "", - _get_serverdll(frozendllhandle), - ) + yield (HKCR, rf"CLSID\{reg_clsid}\InprocServer32", "", dllfile) def _iter_inproc_python_entries(cls: Type, reg_clsid: str) -> Iterator[_Entry]: diff --git a/comtypes/test/test_server_register.py b/comtypes/test/test_server_register.py index b14c371d..dc3a2b2c 100644 --- a/comtypes/test/test_server_register.py +++ b/comtypes/test/test_server_register.py @@ -9,7 +9,12 @@ import comtypes.server.inprocserver from comtypes import GUID from comtypes.server import register -from comtypes.server.register import Registrar, RegistryEntries, _get_serverdll +from comtypes.server.register import ( + FrozenRegistryEntries, + InterpRegistryEntries, + Registrar, + _get_serverdll, +) HKCR = winreg.HKEY_CLASSES_ROOT MULTI_SZ = winreg.REG_MULTI_SZ @@ -190,9 +195,6 @@ def test_calls_cls_unregister(self): class Test_get_serverdll(ut.TestCase): - def test_nonfrozen(self): - self.assertEqual(_ctypes.__file__, _get_serverdll(None)) - @mock.patch.object(register, "GetModuleFileName") def test_frozen(self, GetModuleFileName): handle, dll_path = 1234, r"path\to\frozen.dll" @@ -203,7 +205,7 @@ def test_frozen(self, GetModuleFileName): self.assertEqual(260, maxsize) -class Test_NonFrozen_RegistryEntries(ut.TestCase): +class Test_InterpRegistryEntries(ut.TestCase): def test_reg_clsid(self): reg_clsid = GUID.create_new() @@ -211,7 +213,7 @@ class Cls: _reg_clsid_ = reg_clsid expected = [(HKCR, rf"CLSID\{reg_clsid}", "", "")] - self.assertEqual(expected, list(RegistryEntries(Cls))) + self.assertEqual(expected, list(InterpRegistryEntries(Cls))) def test_reg_desc(self): reg_clsid = GUID.create_new() @@ -222,7 +224,7 @@ class Cls: _reg_desc_ = reg_desc expected = [(HKCR, rf"CLSID\{reg_clsid}", "", reg_desc)] - self.assertEqual(expected, list(RegistryEntries(Cls))) + self.assertEqual(expected, list(InterpRegistryEntries(Cls))) def test_reg_novers_progid(self): reg_clsid = GUID.create_new() @@ -233,7 +235,7 @@ class Cls: _reg_novers_progid_ = reg_novers_progid expected = [(HKCR, rf"CLSID\{reg_clsid}", "", "Lib Server")] - self.assertEqual(expected, list(RegistryEntries(Cls))) + self.assertEqual(expected, list(InterpRegistryEntries(Cls))) def test_progid(self): reg_clsid = GUID.create_new() @@ -249,7 +251,7 @@ class Cls: (HKCR, reg_progid, "", "Lib Server 1"), (HKCR, rf"{reg_progid}\CLSID", "", str(reg_clsid)), ] - self.assertEqual(expected, list(RegistryEntries(Cls))) + self.assertEqual(expected, list(InterpRegistryEntries(Cls))) def test_reg_progid_reg_desc(self): reg_clsid = GUID.create_new() @@ -267,7 +269,7 @@ class Cls: (HKCR, reg_progid, "", "description for testing"), (HKCR, rf"{reg_progid}\CLSID", "", str(reg_clsid)), ] - self.assertEqual(expected, list(RegistryEntries(Cls))) + self.assertEqual(expected, list(InterpRegistryEntries(Cls))) def test_reg_progid_reg_novers_progid(self): reg_clsid = GUID.create_new() @@ -290,7 +292,7 @@ class Cls: (HKCR, rf"{reg_novers_progid}\CurVer", "", "Lib.Server.1"), (HKCR, rf"{reg_novers_progid}\CLSID", "", str(reg_clsid)), ] - self.assertEqual(expected, list(RegistryEntries(Cls))) + self.assertEqual(expected, list(InterpRegistryEntries(Cls))) def test_local_server(self): reg_clsid = GUID.create_new() @@ -306,7 +308,7 @@ class Cls: (HKCR, clsid_sub, "", ""), (HKCR, local_srv_sub, "", f"{sys.executable} {__file__}"), ] - self.assertEqual(expected, list(RegistryEntries(Cls))) + self.assertEqual(expected, list(InterpRegistryEntries(Cls))) def test_inproc_server(self): reg_clsid = GUID.create_new() @@ -325,7 +327,7 @@ class Cls: (HKCR, inproc_srv_sub, "PythonClass", full_classname), (HKCR, inproc_srv_sub, "PythonPath", os.path.dirname(__file__)), ] - self.assertEqual(expected, list(RegistryEntries(Cls))) + self.assertEqual(expected, list(InterpRegistryEntries(Cls))) def test_inproc_server_reg_threading(self): reg_clsid = GUID.create_new() @@ -347,7 +349,7 @@ class Cls: (HKCR, inproc_srv_sub, "PythonPath", os.path.dirname(__file__)), (HKCR, inproc_srv_sub, "ThreadingModel", reg_threading), ] - self.assertEqual(expected, list(RegistryEntries(Cls))) + self.assertEqual(expected, list(InterpRegistryEntries(Cls))) def test_reg_typelib(self): reg_clsid = GUID.create_new() @@ -362,7 +364,7 @@ class Cls: (HKCR, rf"CLSID\{reg_clsid}", "", ""), (HKCR, rf"CLSID\{reg_clsid}\Typelib", "", libid), ] - self.assertEqual(expected, list(RegistryEntries(Cls))) + self.assertEqual(expected, list(InterpRegistryEntries(Cls))) def test_all_entries(self): reg_clsid = GUID.create_new() @@ -403,10 +405,10 @@ class Cls: (HKCR, inproc_srv_sub, "ThreadingModel", reg_threading), (HKCR, rf"{clsid_sub}\Typelib", "", libid), ] - self.assertEqual(expected, list(RegistryEntries(Cls))) + self.assertEqual(expected, list(InterpRegistryEntries(Cls))) -class Test_Frozen_RegistryEntries(ut.TestCase): +class Test_FrozenRegistryEntries(ut.TestCase): SERVERDLL = r"my\target\server.dll" # We do not test the scenario where `frozen` is `'dll'` but @@ -424,7 +426,7 @@ class Cls: # In such cases, the server does not start because the # InprocServer32/LocalServer32 keys are not registered. expected = [(HKCR, rf"CLSID\{reg_clsid}", "", "")] - entries = RegistryEntries(Cls, frozen="dll", frozendllhandle=1234) + entries = FrozenRegistryEntries(Cls, frozen="dll", frozendllhandle=1234) self.assertEqual(expected, list(entries)) def test_local_windows_exe(self): @@ -439,10 +441,12 @@ class Cls: (HKCR, rf"CLSID\{reg_clsid}", "", ""), (HKCR, rf"CLSID\{reg_clsid}\LocalServer32", "", sys.executable), ] - self.assertEqual(expected, list(RegistryEntries(Cls, frozen="windows_exe"))) + self.assertEqual( + expected, list(FrozenRegistryEntries(Cls, frozen="windows_exe")) + ) @mock.patch.object(register, "_get_serverdll", return_value=SERVERDLL) - def test_inproc_dll_nonempty_clsid_to_class(self, get_serverdll): + def test_inproc_dll(self, get_serverdll): reg_clsid = GUID.create_new() reg_clsctx = comtypes.CLSCTX_INPROC_SERVER @@ -457,10 +461,8 @@ class Cls: (HKCR, inproc_srv_sub, "", self.SERVERDLL), ] - with mock.patch.dict(comtypes.server.inprocserver._clsid_to_class): - comtypes.server.inprocserver._clsid_to_class.update({5678: Cls}) - entries = RegistryEntries(Cls, frozen="dll", frozendllhandle=1234) - self.assertEqual(expected, list(entries)) + entries = FrozenRegistryEntries(Cls, frozen="dll", frozendllhandle=1234) + self.assertEqual(expected, list(entries)) get_serverdll.assert_called_once_with(1234) @mock.patch.object(register, "_get_serverdll", return_value=SERVERDLL) @@ -476,18 +478,13 @@ class Cls: clsid_sub = rf"CLSID\{reg_clsid}" inproc_srv_sub = rf"{clsid_sub}\InprocServer32" - full_classname = f"{__name__}.Cls" expected = [ (HKCR, clsid_sub, "", ""), (HKCR, inproc_srv_sub, "", self.SERVERDLL), - # 'PythonClass' and 'PythonPath' are not required for - # frozen inproc servers. This may be bugs but they do - # not affect the server behavior. - (HKCR, inproc_srv_sub, "PythonClass", full_classname), - (HKCR, inproc_srv_sub, "PythonPath", os.path.dirname(__file__)), (HKCR, inproc_srv_sub, "ThreadingModel", reg_threading), ] self.assertEqual( - expected, list(RegistryEntries(Cls, frozen="dll", frozendllhandle=1234)) + expected, + list(FrozenRegistryEntries(Cls, frozen="dll", frozendllhandle=1234)), ) get_serverdll.assert_called_once_with(1234)