Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add enumeration definitions in the friendly modules #475

Merged
merged 5 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions comtypes/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,15 @@ def test_progid(self):
self.assertEqual(consts.TextCompare, Scripting.TextCompare)
self.assertEqual(consts.DatabaseCompare, Scripting.DatabaseCompare)

def test_enums_in_friendly_mod(self):
consts = comtypes.client.Constants("scrrun.dll")
comtypes.client.GetModule("scrrun.dll")
from comtypes.gen import Scripting

for e in Scripting.StandardStreamTypes:
self.assertIn(e.name, consts.StandardStreamTypes)
self.assertEqual(consts.StandardStreamTypes[e.name], e.value)

def test_returns_other_than_enum_members(self):
obj = comtypes.client.CreateObject("SAPI.SpVoice")
from comtypes.gen import SpeechLib as sapi
Expand Down
126 changes: 96 additions & 30 deletions comtypes/tools/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,12 +423,14 @@ def __init__(self, known_symbols=None):
self.stream = io.StringIO()
self.imports = ImportedNamespaces()
self.declarations = DeclaredNamespaces()
self.enums = EnumerationNamespaces()
self._to_type_name = TypeNamer()
self.known_symbols = known_symbols or {}

self.done = set() # type descriptions that have been generated
self.names = set() # names that have been generated
self.externals = [] # typelibs imported to generated module
self.aliases: Dict[str, str] = {}
self.last_item_class = False

def generate(self, item):
Expand Down Expand Up @@ -571,14 +573,7 @@ def generate_wrapper_code(
print(self.declarations.getvalue(), file=output)
print(file=output)
print(self.stream.getvalue(), file=output)
names = ", ".join(repr(str(n)) for n in self.names)
dunder_all = "__all__ = [%s]" % names
if len(dunder_all) > 80:
wrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
dunder_all = "__all__ = [\n%s\n]" % "\n".join(wrapper.wrap(names))
print(dunder_all, file=output)
print(self._make_dunder_all_part(), file=output)
print(file=output)
if tlib_mtime is not None:
print("_check_version(%r, %f)" % (version, tlib_mtime), file=output)
Expand All @@ -595,29 +590,59 @@ def generate_friendly_code(self, modname: str) -> str:
Such as "comtypes.gen.stdole" and "comtypes.gen.Excel".
"""
output = io.StringIO()
print("from enum import IntFlag", file=output)
print(file=output)
print(f"import {modname} as __wrapper_module__", file=output)
txtwrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
importing_symbols = set(self.names)
importing_symbols.update(self.imports.get_symbols())
importing_symbols.update(self.declarations.get_symbols())
joined_names = ", ".join(str(n) for n in importing_symbols)
symbols = f"from {modname} import {joined_names}"
if len(symbols) > 80:
wrapped_names = "\n".join(txtwrapper.wrap(joined_names))
symbols = f"from {modname} import (\n{wrapped_names}\n)"
print(symbols, file=output)
print(self._make_friendly_module_import_part(modname), file=output)
print(file=output)
print(file=output)
quoted_names = ", ".join(repr(str(n)) for n in self.names)
dunder_all = f"__all__ = [{quoted_names}]"
if len(dunder_all) > 80:
wrapped_quoted_names = "\n".join(txtwrapper.wrap(quoted_names))
dunder_all = f"__all__ = [\n{wrapped_quoted_names}\n]"
print(dunder_all, file=output)
print(self.enums.getvalue(), file=output)
print(file=output)
print(file=output)
enum_aliases = self.enum_aliases
if enum_aliases:
for k, v in enum_aliases.items():
print(f"{k} = {v}", file=output)
print(file=output)
print(file=output)
print(self._make_dunder_all_part(), file=output)
return output.getvalue()

def _make_dunder_all_part(self) -> str:
joined_names = ", ".join(repr(str(n)) for n in self.names)
dunder_all = f"__all__ = [{joined_names}]"
if len(dunder_all) > 80:
txtwrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
joined_names = "\n".join(txtwrapper.wrap(joined_names))
dunder_all = f"__all__ = [\n{joined_names}\n]"
return dunder_all

def _make_friendly_module_import_part(self, modname: str) -> str:
# The `modname` is the wrapper module name like `comtypes.gen._xxxx..._x_x_x`
txtwrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
symbols = set(self.names)
symbols.update(self.imports.get_symbols())
symbols.update(self.declarations.get_symbols())
symbols -= set(self.enums.get_symbols())
symbols -= set(self.enum_aliases)
joined_names = ", ".join(str(n) for n in symbols)
part = f"from {modname} import {joined_names}"
if len(part) > 80:
txtwrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
joined_names = "\n".join(txtwrapper.wrap(joined_names))
part = f"from {modname} import (\n{joined_names}\n)"
return part

@property
def enum_aliases(self) -> Dict[str, str]:
return {k: v for k, v in self.aliases.items() if v in self.enums}

def need_VARIANT_imports(self, value):
text = repr(value)
if "Decimal(" in text:
Expand Down Expand Up @@ -645,6 +670,8 @@ def EnumValue(self, tp: typedesc.EnumValue) -> None:
print("# Fixing keyword as EnumValue for %s" % tp.name)
tp_name = self._to_type_name(tp)
print("%s = %d" % (tp_name, value), file=self.stream)
if tp.enumeration.name:
self.enums.add(tp.enumeration.name, tp_name)
self.names.add(tp_name)

def Enumeration(self, tp: typedesc.Enumeration) -> None:
Expand All @@ -653,10 +680,6 @@ def Enumeration(self, tp: typedesc.Enumeration) -> None:
print("# values for enumeration '%s'" % tp.name, file=self.stream)
else:
print("# values for unnamed enumeration", file=self.stream)
# Some enumerations have the same name for the enum type
# and an enum value. Excel's XlDisplayShapes is such an example.
# Since we don't have separate namespaces for the type and the values,
# we generate the TYPE last, overwriting the value. XXX
for item in tp.values:
self.generate(item)
if tp.name:
Expand All @@ -675,6 +698,7 @@ def Typedef(self, tp: typedesc.Typedef) -> None:
self.declarations.add(tp.name, definition)
else:
print("%s = %s" % (tp.name, definition), file=self.stream)
self.aliases[tp.name] = definition
self.last_item_class = False
self.names.add(tp.name)

Expand Down Expand Up @@ -1518,3 +1542,45 @@ def getvalue(self):
code = code + " # %s" % comment
lines.append(code)
return "\n".join(lines)


class EnumerationNamespaces(object):
def __init__(self):
self.data: Dict[str, List[str]] = {}

def add(self, enum_name: str, member_name: str) -> None:
"""Adds a namespace will be enumeration and its member.

Examples:
>>> enums = EnumerationNamespaces()
>>> enums.add('Foo', 'ham')
>>> enums.add('Foo', 'spam')
>>> enums.add('Bar', 'bacon')
>>> assert 'Foo' in enums
>>> assert 'Baz' not in enums
>>> print(enums.getvalue()) # <BLANKLINE> is necessary for doctest
class Foo(IntFlag):
ham = __wrapper_module__.ham
spam = __wrapper_module__.spam
<BLANKLINE>
<BLANKLINE>
class Bar(IntFlag):
bacon = __wrapper_module__.bacon
"""
self.data.setdefault(enum_name, []).append(member_name)

def __contains__(self, item: str) -> bool:
return item in self.data

def get_symbols(self) -> Set[str]:
return set(self.data)

def getvalue(self) -> str:
blocks = []
for enum_name, enum_members in self.data.items():
lines = []
lines.append(f"class {enum_name}(IntFlag):")
for member_name in enum_members:
lines.append(f" {member_name} = __wrapper_module__.{member_name}")
blocks.append("\n".join(lines))
return "\n\n\n".join(blocks)
Loading