Skip to content

Commit

Permalink
Add enumeration definitions in the friendly modules (#475)
Browse files Browse the repository at this point in the history
* add test for enumerations in friendly module

* add enums generation processes

* add the sprout class `EnumerationNamespaces`

* refactor the `codegenerator`

* remove `wrapper_module_name` arg
from `EnumerationNamespaces.getvalue`
  • Loading branch information
junkmd authored Mar 19, 2024
1 parent aa770ca commit 765daf7
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 30 deletions.
9 changes: 9 additions & 0 deletions comtypes/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,15 @@ def test_progid(self):
self.assertEqual(consts.TextCompare, Scripting.TextCompare)
self.assertEqual(consts.DatabaseCompare, Scripting.DatabaseCompare)

def test_enums_in_friendly_mod(self):
consts = comtypes.client.Constants("scrrun.dll")
comtypes.client.GetModule("scrrun.dll")
from comtypes.gen import Scripting

for e in Scripting.StandardStreamTypes:
self.assertIn(e.name, consts.StandardStreamTypes)
self.assertEqual(consts.StandardStreamTypes[e.name], e.value)

def test_returns_other_than_enum_members(self):
obj = comtypes.client.CreateObject("SAPI.SpVoice")
from comtypes.gen import SpeechLib as sapi
Expand Down
126 changes: 96 additions & 30 deletions comtypes/tools/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,12 +423,14 @@ def __init__(self, known_symbols=None):
self.stream = io.StringIO()
self.imports = ImportedNamespaces()
self.declarations = DeclaredNamespaces()
self.enums = EnumerationNamespaces()
self._to_type_name = TypeNamer()
self.known_symbols = known_symbols or {}

self.done = set() # type descriptions that have been generated
self.names = set() # names that have been generated
self.externals = [] # typelibs imported to generated module
self.aliases: Dict[str, str] = {}
self.last_item_class = False

def generate(self, item):
Expand Down Expand Up @@ -571,14 +573,7 @@ def generate_wrapper_code(
print(self.declarations.getvalue(), file=output)
print(file=output)
print(self.stream.getvalue(), file=output)
names = ", ".join(repr(str(n)) for n in self.names)
dunder_all = "__all__ = [%s]" % names
if len(dunder_all) > 80:
wrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
dunder_all = "__all__ = [\n%s\n]" % "\n".join(wrapper.wrap(names))
print(dunder_all, file=output)
print(self._make_dunder_all_part(), file=output)
print(file=output)
if tlib_mtime is not None:
print("_check_version(%r, %f)" % (version, tlib_mtime), file=output)
Expand All @@ -595,29 +590,59 @@ def generate_friendly_code(self, modname: str) -> str:
Such as "comtypes.gen.stdole" and "comtypes.gen.Excel".
"""
output = io.StringIO()
print("from enum import IntFlag", file=output)
print(file=output)
print(f"import {modname} as __wrapper_module__", file=output)
txtwrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
importing_symbols = set(self.names)
importing_symbols.update(self.imports.get_symbols())
importing_symbols.update(self.declarations.get_symbols())
joined_names = ", ".join(str(n) for n in importing_symbols)
symbols = f"from {modname} import {joined_names}"
if len(symbols) > 80:
wrapped_names = "\n".join(txtwrapper.wrap(joined_names))
symbols = f"from {modname} import (\n{wrapped_names}\n)"
print(symbols, file=output)
print(self._make_friendly_module_import_part(modname), file=output)
print(file=output)
print(file=output)
quoted_names = ", ".join(repr(str(n)) for n in self.names)
dunder_all = f"__all__ = [{quoted_names}]"
if len(dunder_all) > 80:
wrapped_quoted_names = "\n".join(txtwrapper.wrap(quoted_names))
dunder_all = f"__all__ = [\n{wrapped_quoted_names}\n]"
print(dunder_all, file=output)
print(self.enums.getvalue(), file=output)
print(file=output)
print(file=output)
enum_aliases = self.enum_aliases
if enum_aliases:
for k, v in enum_aliases.items():
print(f"{k} = {v}", file=output)
print(file=output)
print(file=output)
print(self._make_dunder_all_part(), file=output)
return output.getvalue()

def _make_dunder_all_part(self) -> str:
joined_names = ", ".join(repr(str(n)) for n in self.names)
dunder_all = f"__all__ = [{joined_names}]"
if len(dunder_all) > 80:
txtwrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
joined_names = "\n".join(txtwrapper.wrap(joined_names))
dunder_all = f"__all__ = [\n{joined_names}\n]"
return dunder_all

def _make_friendly_module_import_part(self, modname: str) -> str:
# The `modname` is the wrapper module name like `comtypes.gen._xxxx..._x_x_x`
txtwrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
symbols = set(self.names)
symbols.update(self.imports.get_symbols())
symbols.update(self.declarations.get_symbols())
symbols -= set(self.enums.get_symbols())
symbols -= set(self.enum_aliases)
joined_names = ", ".join(str(n) for n in symbols)
part = f"from {modname} import {joined_names}"
if len(part) > 80:
txtwrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
joined_names = "\n".join(txtwrapper.wrap(joined_names))
part = f"from {modname} import (\n{joined_names}\n)"
return part

@property
def enum_aliases(self) -> Dict[str, str]:
return {k: v for k, v in self.aliases.items() if v in self.enums}

def need_VARIANT_imports(self, value):
text = repr(value)
if "Decimal(" in text:
Expand Down Expand Up @@ -645,6 +670,8 @@ def EnumValue(self, tp: typedesc.EnumValue) -> None:
print("# Fixing keyword as EnumValue for %s" % tp.name)
tp_name = self._to_type_name(tp)
print("%s = %d" % (tp_name, value), file=self.stream)
if tp.enumeration.name:
self.enums.add(tp.enumeration.name, tp_name)
self.names.add(tp_name)

def Enumeration(self, tp: typedesc.Enumeration) -> None:
Expand All @@ -653,10 +680,6 @@ def Enumeration(self, tp: typedesc.Enumeration) -> None:
print("# values for enumeration '%s'" % tp.name, file=self.stream)
else:
print("# values for unnamed enumeration", file=self.stream)
# Some enumerations have the same name for the enum type
# and an enum value. Excel's XlDisplayShapes is such an example.
# Since we don't have separate namespaces for the type and the values,
# we generate the TYPE last, overwriting the value. XXX
for item in tp.values:
self.generate(item)
if tp.name:
Expand All @@ -675,6 +698,7 @@ def Typedef(self, tp: typedesc.Typedef) -> None:
self.declarations.add(tp.name, definition)
else:
print("%s = %s" % (tp.name, definition), file=self.stream)
self.aliases[tp.name] = definition
self.last_item_class = False
self.names.add(tp.name)

Expand Down Expand Up @@ -1518,3 +1542,45 @@ def getvalue(self):
code = code + " # %s" % comment
lines.append(code)
return "\n".join(lines)


class EnumerationNamespaces(object):
def __init__(self):
self.data: Dict[str, List[str]] = {}

def add(self, enum_name: str, member_name: str) -> None:
"""Adds a namespace will be enumeration and its member.
Examples:
>>> enums = EnumerationNamespaces()
>>> enums.add('Foo', 'ham')
>>> enums.add('Foo', 'spam')
>>> enums.add('Bar', 'bacon')
>>> assert 'Foo' in enums
>>> assert 'Baz' not in enums
>>> print(enums.getvalue()) # <BLANKLINE> is necessary for doctest
class Foo(IntFlag):
ham = __wrapper_module__.ham
spam = __wrapper_module__.spam
<BLANKLINE>
<BLANKLINE>
class Bar(IntFlag):
bacon = __wrapper_module__.bacon
"""
self.data.setdefault(enum_name, []).append(member_name)

def __contains__(self, item: str) -> bool:
return item in self.data

def get_symbols(self) -> Set[str]:
return set(self.data)

def getvalue(self) -> str:
blocks = []
for enum_name, enum_members in self.data.items():
lines = []
lines.append(f"class {enum_name}(IntFlag):")
for member_name in enum_members:
lines.append(f" {member_name} = __wrapper_module__.{member_name}")
blocks.append("\n".join(lines))
return "\n\n\n".join(blocks)

0 comments on commit 765daf7

Please sign in to comment.