Skip to content

Commit

Permalink
make explicit the symbols that imports from the wrapper module into t…
Browse files Browse the repository at this point in the history
…he friendly module (#469)

* add `Library` to generated modules' `__all__`.
because that symbol is public but not included.

* add `typelib_path` to generated modules' `__all__`.
because that symbol is public but not included.

* make `ModuleGenerator` class
that encapsulates `CodeGenerator` instance.

* rename to `generate_wrapper_code`
from `generate_code`

* add `generate_friendly_code`

* add type annotations to `generate_wrapper_code`

* add docstring

* add `get_symbols` methods
to `DeclaredNamespaces` and `ImportedNamespaces`

* update imporing symbols

* add type annotation to return value for `__init__`

* change to using f-string
in `generate_friendly_code`
  • Loading branch information
junkmd authored Jan 17, 2023
1 parent e1ee6f0 commit 532c399
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 57 deletions.
113 changes: 59 additions & 54 deletions comtypes/client/_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,7 @@ def GetModule(tlib: _UnionT[Any, typeinfo.ITypeLib]) -> types.ModuleType:
pathname = None
tlib = _load_tlib(tlib)
logger.debug("GetModule(%s)", tlib.GetLibAttr())
# create and import the real typelib wrapper module
mod = _create_wrapper_module(tlib, pathname)
# try to get the friendly-name, if not, returns the real typelib wrapper module
modulename = codegenerator.name_friendly_module(tlib)
if modulename is None:
return mod
# create and import the friendly-named module
return _create_friendly_module(tlib, modulename)
return ModuleGenerator().generate(tlib, pathname)


def _load_tlib(obj: Any) -> typeinfo.ITypeLib:
Expand Down Expand Up @@ -193,52 +186,64 @@ def _create_module_in_memory(modulename: str, code: str) -> types.ModuleType:
return mod


def _create_friendly_module(
tlib: typeinfo.ITypeLib, modulename: str
) -> types.ModuleType:
"""helper which creates and imports the friendly-named module."""
try:
mod = _my_import(modulename)
except Exception as details:
logger.info("Could not import %s: %s", modulename, details)
else:
return mod
# the module is always regenerated if the import fails
logger.info("# Generating %s", modulename)
# determine the Python module name
modname = codegenerator.name_wrapper_module(tlib).split(".")[-1]
code = "from comtypes.gen import %s\n" % modname
code += "globals().update(%s.__dict__)\n" % modname
code += "__name__ = '%s'" % modulename
if comtypes.client.gen_dir is None:
return _create_module_in_memory(modulename, code)
return _create_module_in_file(modulename, code)


def _create_wrapper_module(
tlib: typeinfo.ITypeLib, pathname: Optional[str]
) -> types.ModuleType:
"""helper which creates and imports the real typelib wrapper module."""
modulename = codegenerator.name_wrapper_module(tlib)
if modulename in sys.modules:
return sys.modules[modulename]
try:
return _my_import(modulename)
except Exception as details:
logger.info("Could not import %s: %s", modulename, details)
# generate the module since it doesn't exist or is out of date
logger.info("# Generating %s", modulename)
p = tlbparser.TypeLibParser(tlib)
if pathname is None:
pathname = tlbparser.get_tlib_filename(tlib)
items = list(p.parse().values())
codegen = codegenerator.CodeGenerator(_get_known_symbols())
code = codegen.generate_code(items, filename=pathname)
for ext_tlib in codegen.externals: # generates dependency COM-lib modules
GetModule(ext_tlib)
if comtypes.client.gen_dir is None:
return _create_module_in_memory(modulename, code)
return _create_module_in_file(modulename, code)
class ModuleGenerator(object):
def __init__(self) -> None:
self.codegen = codegenerator.CodeGenerator(_get_known_symbols())

def generate(
self, tlib: typeinfo.ITypeLib, pathname: Optional[str]
) -> types.ModuleType:
# create and import the real typelib wrapper module
mod = self._create_wrapper_module(tlib, pathname)
# try to get the friendly-name, if not, returns the real typelib wrapper module
modulename = codegenerator.name_friendly_module(tlib)
if modulename is None:
return mod
# create and import the friendly-named module
return self._create_friendly_module(tlib, modulename)

def _create_friendly_module(
self, tlib: typeinfo.ITypeLib, modulename: str
) -> types.ModuleType:
"""helper which creates and imports the friendly-named module."""
try:
mod = _my_import(modulename)
except Exception as details:
logger.info("Could not import %s: %s", modulename, details)
else:
return mod
# the module is always regenerated if the import fails
logger.info("# Generating %s", modulename)
# determine the Python module name
modname = codegenerator.name_wrapper_module(tlib)
code = self.codegen.generate_friendly_code(modname)
if comtypes.client.gen_dir is None:
return _create_module_in_memory(modulename, code)
return _create_module_in_file(modulename, code)

def _create_wrapper_module(
self, tlib: typeinfo.ITypeLib, pathname: Optional[str]
) -> types.ModuleType:
"""helper which creates and imports the real typelib wrapper module."""
modulename = codegenerator.name_wrapper_module(tlib)
if modulename in sys.modules:
return sys.modules[modulename]
try:
return _my_import(modulename)
except Exception as details:
logger.info("Could not import %s: %s", modulename, details)
# generate the module since it doesn't exist or is out of date
logger.info("# Generating %s", modulename)
p = tlbparser.TypeLibParser(tlib)
if pathname is None:
pathname = tlbparser.get_tlib_filename(tlib)
items = list(p.parse().values())
code = self.codegen.generate_wrapper_code(items, filename=pathname)
for ext_tlib in self.codegen.externals: # generates dependency COM-lib modules
GetModule(ext_tlib)
if comtypes.client.gen_dir is None:
return _create_module_in_memory(modulename, code)
return _create_module_in_file(modulename, code)


def _get_known_symbols() -> Dict[str, str]:
Expand Down
82 changes: 79 additions & 3 deletions comtypes/tools/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,17 @@
import os
import sys
import textwrap
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union as _UnionT
from typing import (
Any,
Dict,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Union as _UnionT,
)
import io

import comtypes
Expand Down Expand Up @@ -495,9 +505,20 @@ def _generate_typelib_path(self, filename):
os.path.abspath(os.path.join(comtypes.gen.__path__[0], path))
)
assert os.path.isfile(p)
self.names.add("typelib_path")

def generate_wrapper_code(
self, tdescs: Sequence[Any], filename: Optional[str]
) -> str:
"""Returns the code for the COM type library wrapper module.
def generate_code(self, items, filename):
The returned `Python` code string is containing definitions of interfaces,
coclasses, constants, and structures.
The module will have long name that is derived from the type library guid, lcid
and version numbers.
Such as `comtypes.gen._xxxxxxxx_xxxx_xxxx_xxxx_xxxxxxxxxxxx_l_M_m`.
"""
tlib_mtime = None

if filename is not None:
Expand All @@ -520,7 +541,7 @@ def generate_code(self, items, filename):
self.declarations.add("_lcid", "0", "change this if required")
self._generate_typelib_path(filename)

items = set(items)
items = set(tdescs)
loops = 0
while items:
loops += 1
Expand Down Expand Up @@ -557,6 +578,39 @@ def generate_code(self, items, filename):
print("_check_version(%r, %f)" % (version, tlib_mtime), file=output)
return output.getvalue()

def generate_friendly_code(self, modname: str) -> str:
"""Returns the code for the COM type library friendly module.
The returned `Python` code string is containing `from {modname} import
DefinedInWrapper, ...` and `__all__ = ['DefinedInWrapper', ...]`
The `modname` is the wrapper module name like `comtypes.gen._xxxx..._x_x_x`.
The module will have shorter name that is derived from the type library name.
Such as "comtypes.gen.stdole" and "comtypes.gen.Excel".
"""
output = io.StringIO()
txtwrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
importing_symbols = set(self.names)
importing_symbols.update(self.imports.get_symbols())
importing_symbols.update(self.declarations.get_symbols())
joined_names = ", ".join(str(n) for n in importing_symbols)
symbols = f"from {modname} import {joined_names}"
if len(symbols) > 80:
wrapped_names = "\n".join(txtwrapper.wrap(joined_names))
symbols = f"from {modname} import (\n{wrapped_names}\n)"
print(symbols, file=output)
print(file=output)
print(file=output)
quoted_names = ", ".join(repr(str(n)) for n in self.names)
dunder_all = f"__all__ = [{quoted_names}]"
if len(dunder_all) > 80:
wrapped_quoted_names = "\n".join(txtwrapper.wrap(quoted_names))
dunder_all = f"__all__ = [\n{wrapped_quoted_names}\n]"
print(dunder_all, file=output)
return output.getvalue()

def need_VARIANT_imports(self, value):
text = repr(value)
if "Decimal(" in text:
Expand Down Expand Up @@ -876,6 +930,7 @@ def TypeLib(self, lib: typedesc.TypeLib) -> None:
)
print(file=self.stream)
print(file=self.stream)
self.names.add("Library")

def External(self, ext: typedesc.External) -> None:
modname = name_wrapper_module(ext.tlib)
Expand Down Expand Up @@ -1329,6 +1384,10 @@ def add(self, name1, name2=None, symbols=None):
IUnknown
)
import ctypes.wintypes
>>> assert imports.get_symbols() == {
... 'Decimal', 'GUID', 'COMMETHOD', 'DISPMETHOD', 'IUnknown',
... 'dispid', 'CoClass', 'BSTR', 'DISPPROPERTY'
... }
>>> print(imports.getvalue(for_stub=True))
from ctypes import *
import datetime
Expand Down Expand Up @@ -1381,6 +1440,14 @@ def __contains__(self, item):
return self.data[import_] == from_
return False

def get_symbols(self) -> Set[str]:
names = set()
for key, val in self.data.items():
if val is None or key == "*":
continue
names.add(key)
return names

def _make_line(self, from_, imports, for_stub):
if for_stub:
import_ = ", ".join("%s as %s" % (n, n) for n in imports)
Expand Down Expand Up @@ -1432,9 +1499,18 @@ def add(self, alias, definition, comment=None):
>>> print(declarations.getvalue())
STRING = c_char_p
_lcid = 0 # change this if required
>>> assert declarations.get_symbols() == {
... 'STRING', '_lcid'
... }
"""
self.data[(alias, definition)] = comment

def get_symbols(self) -> Set[str]:
names = set()
for alias, _ in self.data.keys():
names.add(alias)
return names

def getvalue(self):
lines = []
for (alias, definition), comment in self.data.items():
Expand Down

0 comments on commit 532c399

Please sign in to comment.