Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace GetModuleFileNameA with GetModuleFileNameW to prevent a TypeError. #733

Merged
merged 4 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions comtypes/server/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,24 @@
python mycomobj.py /nodebug
"""

import ctypes
import logging
import os
import sys
import winreg
from ctypes import WinError, c_ulong, c_wchar_p, create_string_buffer, sizeof, windll
from ctypes import WinError, windll
from typing import Iterator, Tuple

import comtypes
import comtypes.server.inprocserver
from comtypes.hresult import *
from comtypes.server import w_getopt
from comtypes.typeinfo import REGKIND_REGISTER, LoadTypeLibEx, UnRegisterTypeLib
from comtypes.typeinfo import (
REGKIND_REGISTER,
GetModuleFileName,
LoadTypeLibEx,
UnRegisterTypeLib,
)

_debug = logging.getLogger(__name__).debug

Expand All @@ -67,7 +73,7 @@ def _non_zero(retval, func, args):

SHDeleteKey = windll.shlwapi.SHDeleteKeyW
SHDeleteKey.errcheck = _non_zero
SHDeleteKey.argtypes = c_ulong, c_wchar_p
SHDeleteKey.argtypes = ctypes.c_ulong, ctypes.c_wchar_p

Set = set

Expand Down Expand Up @@ -219,9 +225,7 @@ def _get_serverdll():
"""Return the pathname of the dll hosting the COM object."""
handle = getattr(sys, "frozendllhandle", None)
if handle is not None:
buf = create_string_buffer(260)
windll.kernel32.GetModuleFileNameA(handle, buf, sizeof(buf))
return buf[:]
return GetModuleFileName(handle, 260)
import _ctypes

return _ctypes.__file__
Expand Down
23 changes: 10 additions & 13 deletions comtypes/test/test_server_register.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import _ctypes
import ctypes
import os
import sys
import unittest as ut
Expand Down Expand Up @@ -194,18 +193,16 @@ class Test_get_serverdll(ut.TestCase):
def test_nonfrozen(self):
self.assertEqual(_ctypes.__file__, _get_serverdll())

def test_frozen(self):
with mock.patch.object(register, "sys") as _sys:
with mock.patch.object(register, "windll") as _windll:
handle = 1234
_sys.frozendllhandle = handle
self.assertEqual(b"\x00" * 260, _get_serverdll())
GetModuleFileName = _windll.kernel32.GetModuleFileNameA
(((hModule, lpFilename, nSize), _),) = GetModuleFileName.call_args_list
self.assertEqual(handle, hModule)
buf_type = type(ctypes.create_string_buffer(260))
self.assertIsInstance(lpFilename, buf_type)
self.assertEqual(260, nSize)
@mock.patch.object(register, "GetModuleFileName")
@mock.patch.object(register, "sys")
def test_frozen(self, _sys, GetModuleFileName):
handle, dll_path = 1234, r"path\to\frozendll"
_sys.frozendllhandle = handle
GetModuleFileName.return_value = dll_path
self.assertEqual(r"path\to\frozendll", _get_serverdll())
(((hmodule, maxsize), _),) = GetModuleFileName.call_args_list
self.assertEqual(handle, hmodule)
self.assertEqual(260, maxsize)


class Test_NonFrozen_RegistryEntries(ut.TestCase):
Expand Down
26 changes: 19 additions & 7 deletions comtypes/test/test_typeinfo.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import os
import ctypes
import sys
import unittest
from ctypes import POINTER, byref

from comtypes import GUID, COMError
from comtypes.automation import DISPATCH_METHOD
from comtypes.typeinfo import (
LoadTypeLibEx,
TKIND_DISPATCH,
TKIND_INTERFACE,
GetModuleFileName,
LoadRegTypeLib,
LoadTypeLibEx,
QueryPathOfRegTypeLib,
TKIND_INTERFACE,
TKIND_DISPATCH,
TKIND_ENUM,
)


Expand Down Expand Up @@ -94,5 +94,17 @@ def test_TypeInfo(self):
self.assertEqual(guid, ti.GetTypeAttr().guid)


class Test_GetModuleFileName(unittest.TestCase):
def test_null_handler(self):
self.assertEqual(GetModuleFileName(None, 260), sys.executable)

def test_loaded_module_handle(self):
import _ctypes

dll_path = _ctypes.__file__
hmodule = ctypes.WinDLL(dll_path)._handle
self.assertEqual(GetModuleFileName(hmodule, 260), dll_path)


if __name__ == "__main__":
unittest.main()
65 changes: 55 additions & 10 deletions comtypes/typeinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,48 @@
# generated by 'xml2py'
# flags '..\tools\windows.xml -m comtypes -m comtypes.automation -w -r .*TypeLibEx -r .*TypeLib -o typeinfo.py'
# then hacked manually
import ctypes
import sys
from typing import Any, overload, TypeVar, TYPE_CHECKING
from typing import List, Type, Tuple
from typing import Optional, Union as _UnionT
from typing import Callable, Sequence
import weakref

import ctypes
from ctypes import HRESULT, POINTER, _Pointer, byref, c_int, c_void_p, c_wchar_p
from ctypes.wintypes import DWORD, LONG, UINT, ULONG, WCHAR, WORD, INT, SHORT, USHORT
from comtypes import BSTR, _CData, COMMETHOD, GUID, IID, IUnknown, STDMETHOD
from comtypes.automation import DISPID, LCID, SCODE
from comtypes.automation import DISPPARAMS, EXCEPINFO, VARIANT, VARIANTARG, VARTYPE
from ctypes.wintypes import (
DWORD,
HMODULE,
INT,
LONG,
LPWSTR,
SHORT,
UINT,
ULONG,
USHORT,
WCHAR,
WORD,
)
from typing import (
TYPE_CHECKING,
Any,
Callable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
overload,
)
from typing import Union as _UnionT

from comtypes import BSTR, COMMETHOD, GUID, IID, STDMETHOD, IUnknown, _CData
from comtypes.automation import (
DISPID,
DISPPARAMS,
EXCEPINFO,
LCID,
SCODE,
VARIANT,
VARIANTARG,
VARTYPE,
)

if TYPE_CHECKING:
from comtypes import hints # type: ignore
Expand Down Expand Up @@ -666,6 +695,22 @@ def QueryPathOfRegTypeLib(
return pathname.value.split("\0")[0]


_GetModuleFileNameW = ctypes.windll.kernel32.GetModuleFileNameW
_GetModuleFileNameW.argtypes = HMODULE, LPWSTR, DWORD
_GetModuleFileNameW.restype = DWORD


def GetModuleFileName(handle: Optional[int], maxsize: int) -> str:
"""Returns the fullpath of the loaded module specified by the handle.
If the handle is NULL, returns the executable file path of the current process.

https://learn.microsoft.com/ja-jp/windows/win32/api/libloaderapi/nf-libloaderapi-loadlibraryw
"""
buf = ctypes.create_unicode_buffer(maxsize)
length = _GetModuleFileNameW(handle, buf, maxsize)
return buf.value[:length]


################################################################
# Structures

Expand Down