Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent the occurrence of SyntaxError in friendly modules. #533

Merged
merged 9 commits into from
May 5, 2024
Next Next commit
fix ModuleGenerator and unify _create_module_...
junkmd committed Apr 11, 2024
commit 81bdb667912956ec05dbd588de8ec5a5584d2f8b
133 changes: 62 additions & 71 deletions comtypes/client/_generate.py
Original file line number Diff line number Diff line change
@@ -121,7 +121,7 @@ def GetModule(tlib: _UnionT[Any, typeinfo.ITypeLib]) -> types.ModuleType:
pathname = None
tlib = _load_tlib(tlib)
logger.debug("GetModule(%s)", tlib.GetLibAttr())
return ModuleGenerator().generate(tlib, pathname)
return ModuleGenerator(tlib, pathname).generate()


def _load_tlib(obj: Any) -> typeinfo.ITypeLib:
@@ -160,90 +160,81 @@ def _load_tlib(obj: Any) -> typeinfo.ITypeLib:
raise TypeError("'%r' is not supported type for loading typelib" % obj)


def _create_module_in_file(modulename: str, code: str) -> types.ModuleType:
"""create module in file system, and import it"""
def _create_module(modulename: str, code: str) -> types.ModuleType:
"""Creates the module, then imports it."""
# `modulename` is 'comtypes.gen.xxx'
filename = "%s.py" % modulename.split(".")[-1]
with open(os.path.join(comtypes.client.gen_dir, filename), "w") as ofi:
stem = modulename.split(".")[-1]
if comtypes.client.gen_dir is None:
# in memory system
import comtypes.gen as g

mod = types.ModuleType(modulename)
abs_gen_path = os.path.abspath(g.__path__[0]) # type: ignore
mod.__file__ = os.path.join(abs_gen_path, "<memory>")
exec(code, mod.__dict__)
sys.modules[modulename] = mod
setattr(g, stem, mod)
return mod
# in file system
with open(os.path.join(comtypes.client.gen_dir, f"{stem}.py"), "w") as ofi:
print(code, file=ofi)
# clear the import cache to make sure Python sees newly created modules
if hasattr(importlib, "invalidate_caches"):
importlib.invalidate_caches()
importlib.invalidate_caches()
return _my_import(modulename)


def _create_module_in_memory(modulename: str, code: str) -> types.ModuleType:
"""create module in memory system, and import it"""
# `modulename` is 'comtypes.gen.xxx'
import comtypes.gen as g

mod = types.ModuleType(modulename)
abs_gen_path = os.path.abspath(g.__path__[0]) # type: ignore
mod.__file__ = os.path.join(abs_gen_path, "<memory>")
exec(code, mod.__dict__)
sys.modules[modulename] = mod
setattr(g, modulename.split(".")[-1], mod)
return mod


class ModuleGenerator(object):
def __init__(self) -> None:
self.codegen = codegenerator.CodeGenerator(_get_known_symbols())

def generate(
self, tlib: typeinfo.ITypeLib, pathname: Optional[str]
) -> types.ModuleType:
# create and import the real typelib wrapper module
mod = self._create_wrapper_module(tlib, pathname)
# try to get the friendly-name, if not, returns the real typelib wrapper module
modulename = codegenerator.name_friendly_module(tlib)
if modulename is None:
return mod
# create and import the friendly-named module
return self._create_friendly_module(tlib, modulename)
def __init__(self, tlib: typeinfo.ITypeLib, pathname: Optional[str]) -> None:
self.wrapper_name = codegenerator.name_wrapper_module(tlib)
self.friendly_name = codegenerator.name_friendly_module(tlib)
if pathname is None:
self.pathname = tlbparser.get_tlib_filename(tlib)
else:
self.pathname = pathname
self.tlib = tlib

def generate(self) -> types.ModuleType:
# tries to import existing modules
wrapper_module = self._get_existing_wrapper_module()
if wrapper_module is not None:
if self.friendly_name is None:
return wrapper_module
else:
friendly_module = self._get_existing_friendly_module()
if friendly_module is not None:
return friendly_module
# (re)generates wrapper and friendly modules
codegen = codegenerator.CodeGenerator(_get_known_symbols())
codebases: List[Tuple[str, str]] = []
logger.info("# Generating %s", self.wrapper_name)
items = list(tlbparser.TypeLibParser(self.tlib).parse().values())
wrp_code = codegen.generate_wrapper_code(items, filename=self.pathname)
codebases.append((self.wrapper_name, wrp_code))
if self.friendly_name is not None:
logger.info("# Generating %s", self.friendly_name)
frd_code = codegen.generate_friendly_code(self.wrapper_name)
codebases.append((self.friendly_name, frd_code))
for ext_tlib in codegen.externals: # generates dependency COM-lib modules
GetModule(ext_tlib)
return [_create_module(name, code) for (name, code) in codebases][-1]

def _create_friendly_module(
self, tlib: typeinfo.ITypeLib, modulename: str
) -> types.ModuleType:
"""helper which creates and imports the friendly-named module."""
def _get_existing_friendly_module(self) -> Optional[types.ModuleType]:
if self.friendly_name is None:
return
try:
mod = _my_import(modulename)
mod = _my_import(self.friendly_name)
except Exception as details:
logger.info("Could not import %s: %s", modulename, details)
logger.info("Could not import %s: %s", self.friendly_name, details)
else:
return mod
# the module is always regenerated if the import fails
logger.info("# Generating %s", modulename)
# determine the Python module name
modname = codegenerator.name_wrapper_module(tlib)
code = self.codegen.generate_friendly_code(modname)
if comtypes.client.gen_dir is None:
return _create_module_in_memory(modulename, code)
return _create_module_in_file(modulename, code)

def _create_wrapper_module(
self, tlib: typeinfo.ITypeLib, pathname: Optional[str]
) -> types.ModuleType:
"""helper which creates and imports the real typelib wrapper module."""
modulename = codegenerator.name_wrapper_module(tlib)
if modulename in sys.modules:
return sys.modules[modulename]

def _get_existing_wrapper_module(self) -> Optional[types.ModuleType]:
if self.wrapper_name in sys.modules:
return sys.modules[self.wrapper_name]
try:
return _my_import(modulename)
return _my_import(self.wrapper_name)
except Exception as details:
logger.info("Could not import %s: %s", modulename, details)
# generate the module since it doesn't exist or is out of date
logger.info("# Generating %s", modulename)
p = tlbparser.TypeLibParser(tlib)
if pathname is None:
pathname = tlbparser.get_tlib_filename(tlib)
items = list(p.parse().values())
code = self.codegen.generate_wrapper_code(items, filename=pathname)
for ext_tlib in self.codegen.externals: # generates dependency COM-lib modules
GetModule(ext_tlib)
if comtypes.client.gen_dir is None:
return _create_module_in_memory(modulename, code)
return _create_module_in_file(modulename, code)
logger.info("Could not import %s: %s", self.wrapper_name, details)


def _get_known_symbols() -> Dict[str, str]: