diff --git a/comtypes/client/_generate.py b/comtypes/client/_generate.py index 0fa9aca9..9468293e 100644 --- a/comtypes/client/_generate.py +++ b/comtypes/client/_generate.py @@ -121,14 +121,7 @@ def GetModule(tlib: _UnionT[Any, typeinfo.ITypeLib]) -> types.ModuleType: pathname = None tlib = _load_tlib(tlib) logger.debug("GetModule(%s)", tlib.GetLibAttr()) - # create and import the real typelib wrapper module - mod = _create_wrapper_module(tlib, pathname) - # try to get the friendly-name, if not, returns the real typelib wrapper module - modulename = codegenerator.name_friendly_module(tlib) - if modulename is None: - return mod - # create and import the friendly-named module - return _create_friendly_module(tlib, modulename) + return ModuleGenerator().generate(tlib, pathname) def _load_tlib(obj: Any) -> typeinfo.ITypeLib: @@ -193,52 +186,64 @@ def _create_module_in_memory(modulename: str, code: str) -> types.ModuleType: return mod -def _create_friendly_module( - tlib: typeinfo.ITypeLib, modulename: str -) -> types.ModuleType: - """helper which creates and imports the friendly-named module.""" - try: - mod = _my_import(modulename) - except Exception as details: - logger.info("Could not import %s: %s", modulename, details) - else: - return mod - # the module is always regenerated if the import fails - logger.info("# Generating %s", modulename) - # determine the Python module name - modname = codegenerator.name_wrapper_module(tlib).split(".")[-1] - code = "from comtypes.gen import %s\n" % modname - code += "globals().update(%s.__dict__)\n" % modname - code += "__name__ = '%s'" % modulename - if comtypes.client.gen_dir is None: - return _create_module_in_memory(modulename, code) - return _create_module_in_file(modulename, code) - - -def _create_wrapper_module( - tlib: typeinfo.ITypeLib, pathname: Optional[str] -) -> types.ModuleType: - """helper which creates and imports the real typelib wrapper module.""" - modulename = codegenerator.name_wrapper_module(tlib) - if modulename in sys.modules: - return sys.modules[modulename] - try: - return _my_import(modulename) - except Exception as details: - logger.info("Could not import %s: %s", modulename, details) - # generate the module since it doesn't exist or is out of date - logger.info("# Generating %s", modulename) - p = tlbparser.TypeLibParser(tlib) - if pathname is None: - pathname = tlbparser.get_tlib_filename(tlib) - items = list(p.parse().values()) - codegen = codegenerator.CodeGenerator(_get_known_symbols()) - code = codegen.generate_code(items, filename=pathname) - for ext_tlib in codegen.externals: # generates dependency COM-lib modules - GetModule(ext_tlib) - if comtypes.client.gen_dir is None: - return _create_module_in_memory(modulename, code) - return _create_module_in_file(modulename, code) +class ModuleGenerator(object): + def __init__(self) -> None: + self.codegen = codegenerator.CodeGenerator(_get_known_symbols()) + + def generate( + self, tlib: typeinfo.ITypeLib, pathname: Optional[str] + ) -> types.ModuleType: + # create and import the real typelib wrapper module + mod = self._create_wrapper_module(tlib, pathname) + # try to get the friendly-name, if not, returns the real typelib wrapper module + modulename = codegenerator.name_friendly_module(tlib) + if modulename is None: + return mod + # create and import the friendly-named module + return self._create_friendly_module(tlib, modulename) + + def _create_friendly_module( + self, tlib: typeinfo.ITypeLib, modulename: str + ) -> types.ModuleType: + """helper which creates and imports the friendly-named module.""" + try: + mod = _my_import(modulename) + except Exception as details: + logger.info("Could not import %s: %s", modulename, details) + else: + return mod + # the module is always regenerated if the import fails + logger.info("# Generating %s", modulename) + # determine the Python module name + modname = codegenerator.name_wrapper_module(tlib) + code = self.codegen.generate_friendly_code(modname) + if comtypes.client.gen_dir is None: + return _create_module_in_memory(modulename, code) + return _create_module_in_file(modulename, code) + + def _create_wrapper_module( + self, tlib: typeinfo.ITypeLib, pathname: Optional[str] + ) -> types.ModuleType: + """helper which creates and imports the real typelib wrapper module.""" + modulename = codegenerator.name_wrapper_module(tlib) + if modulename in sys.modules: + return sys.modules[modulename] + try: + return _my_import(modulename) + except Exception as details: + logger.info("Could not import %s: %s", modulename, details) + # generate the module since it doesn't exist or is out of date + logger.info("# Generating %s", modulename) + p = tlbparser.TypeLibParser(tlib) + if pathname is None: + pathname = tlbparser.get_tlib_filename(tlib) + items = list(p.parse().values()) + code = self.codegen.generate_wrapper_code(items, filename=pathname) + for ext_tlib in self.codegen.externals: # generates dependency COM-lib modules + GetModule(ext_tlib) + if comtypes.client.gen_dir is None: + return _create_module_in_memory(modulename, code) + return _create_module_in_file(modulename, code) def _get_known_symbols() -> Dict[str, str]: diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index 48a34e42..8bc98686 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -7,7 +7,17 @@ import os import sys import textwrap -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union as _UnionT +from typing import ( + Any, + Dict, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Union as _UnionT, +) import io import comtypes @@ -495,9 +505,20 @@ def _generate_typelib_path(self, filename): os.path.abspath(os.path.join(comtypes.gen.__path__[0], path)) ) assert os.path.isfile(p) + self.names.add("typelib_path") + + def generate_wrapper_code( + self, tdescs: Sequence[Any], filename: Optional[str] + ) -> str: + """Returns the code for the COM type library wrapper module. - def generate_code(self, items, filename): + The returned `Python` code string is containing definitions of interfaces, + coclasses, constants, and structures. + The module will have long name that is derived from the type library guid, lcid + and version numbers. + Such as `comtypes.gen._xxxxxxxx_xxxx_xxxx_xxxx_xxxxxxxxxxxx_l_M_m`. + """ tlib_mtime = None if filename is not None: @@ -520,7 +541,7 @@ def generate_code(self, items, filename): self.declarations.add("_lcid", "0", "change this if required") self._generate_typelib_path(filename) - items = set(items) + items = set(tdescs) loops = 0 while items: loops += 1 @@ -557,6 +578,39 @@ def generate_code(self, items, filename): print("_check_version(%r, %f)" % (version, tlib_mtime), file=output) return output.getvalue() + def generate_friendly_code(self, modname: str) -> str: + """Returns the code for the COM type library friendly module. + + The returned `Python` code string is containing `from {modname} import + DefinedInWrapper, ...` and `__all__ = ['DefinedInWrapper', ...]` + The `modname` is the wrapper module name like `comtypes.gen._xxxx..._x_x_x`. + + The module will have shorter name that is derived from the type library name. + Such as "comtypes.gen.stdole" and "comtypes.gen.Excel". + """ + output = io.StringIO() + txtwrapper = textwrap.TextWrapper( + subsequent_indent=" ", initial_indent=" ", break_long_words=False + ) + importing_symbols = set(self.names) + importing_symbols.update(self.imports.get_symbols()) + importing_symbols.update(self.declarations.get_symbols()) + joined_names = ", ".join(str(n) for n in importing_symbols) + symbols = f"from {modname} import {joined_names}" + if len(symbols) > 80: + wrapped_names = "\n".join(txtwrapper.wrap(joined_names)) + symbols = f"from {modname} import (\n{wrapped_names}\n)" + print(symbols, file=output) + print(file=output) + print(file=output) + quoted_names = ", ".join(repr(str(n)) for n in self.names) + dunder_all = f"__all__ = [{quoted_names}]" + if len(dunder_all) > 80: + wrapped_quoted_names = "\n".join(txtwrapper.wrap(quoted_names)) + dunder_all = f"__all__ = [\n{wrapped_quoted_names}\n]" + print(dunder_all, file=output) + return output.getvalue() + def need_VARIANT_imports(self, value): text = repr(value) if "Decimal(" in text: @@ -876,6 +930,7 @@ def TypeLib(self, lib: typedesc.TypeLib) -> None: ) print(file=self.stream) print(file=self.stream) + self.names.add("Library") def External(self, ext: typedesc.External) -> None: modname = name_wrapper_module(ext.tlib) @@ -1329,6 +1384,10 @@ def add(self, name1, name2=None, symbols=None): IUnknown ) import ctypes.wintypes + >>> assert imports.get_symbols() == { + ... 'Decimal', 'GUID', 'COMMETHOD', 'DISPMETHOD', 'IUnknown', + ... 'dispid', 'CoClass', 'BSTR', 'DISPPROPERTY' + ... } >>> print(imports.getvalue(for_stub=True)) from ctypes import * import datetime @@ -1381,6 +1440,14 @@ def __contains__(self, item): return self.data[import_] == from_ return False + def get_symbols(self) -> Set[str]: + names = set() + for key, val in self.data.items(): + if val is None or key == "*": + continue + names.add(key) + return names + def _make_line(self, from_, imports, for_stub): if for_stub: import_ = ", ".join("%s as %s" % (n, n) for n in imports) @@ -1432,9 +1499,18 @@ def add(self, alias, definition, comment=None): >>> print(declarations.getvalue()) STRING = c_char_p _lcid = 0 # change this if required + >>> assert declarations.get_symbols() == { + ... 'STRING', '_lcid' + ... } """ self.data[(alias, definition)] = comment + def get_symbols(self) -> Set[str]: + names = set() + for alias, _ in self.data.keys(): + names.add(alias) + return names + def getvalue(self): lines = [] for (alias, definition), comment in self.data.items():