diff --git a/comtypes/test/test_client.py b/comtypes/test/test_client.py index d44bfd22..629ed4d3 100644 --- a/comtypes/test/test_client.py +++ b/comtypes/test/test_client.py @@ -229,6 +229,15 @@ def test_progid(self): self.assertEqual(consts.TextCompare, Scripting.TextCompare) self.assertEqual(consts.DatabaseCompare, Scripting.DatabaseCompare) + def test_enums_in_friendly_mod(self): + consts = comtypes.client.Constants("scrrun.dll") + comtypes.client.GetModule("scrrun.dll") + from comtypes.gen import Scripting + + for e in Scripting.StandardStreamTypes: + self.assertIn(e.name, consts.StandardStreamTypes) + self.assertEqual(consts.StandardStreamTypes[e.name], e.value) + def test_returns_other_than_enum_members(self): obj = comtypes.client.CreateObject("SAPI.SpVoice") from comtypes.gen import SpeechLib as sapi diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index 82c3bcc3..7540edc5 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -423,12 +423,14 @@ def __init__(self, known_symbols=None): self.stream = io.StringIO() self.imports = ImportedNamespaces() self.declarations = DeclaredNamespaces() + self.enums = EnumerationNamespaces() self._to_type_name = TypeNamer() self.known_symbols = known_symbols or {} self.done = set() # type descriptions that have been generated self.names = set() # names that have been generated self.externals = [] # typelibs imported to generated module + self.aliases: Dict[str, str] = {} self.last_item_class = False def generate(self, item): @@ -571,14 +573,7 @@ def generate_wrapper_code( print(self.declarations.getvalue(), file=output) print(file=output) print(self.stream.getvalue(), file=output) - names = ", ".join(repr(str(n)) for n in self.names) - dunder_all = "__all__ = [%s]" % names - if len(dunder_all) > 80: - wrapper = textwrap.TextWrapper( - subsequent_indent=" ", initial_indent=" ", break_long_words=False - ) - dunder_all = "__all__ = [\n%s\n]" % "\n".join(wrapper.wrap(names)) - print(dunder_all, file=output) + print(self._make_dunder_all_part(), file=output) print(file=output) if tlib_mtime is not None: print("_check_version(%r, %f)" % (version, tlib_mtime), file=output) @@ -595,29 +590,59 @@ def generate_friendly_code(self, modname: str) -> str: Such as "comtypes.gen.stdole" and "comtypes.gen.Excel". """ output = io.StringIO() + print("from enum import IntFlag", file=output) + print(file=output) print(f"import {modname} as __wrapper_module__", file=output) - txtwrapper = textwrap.TextWrapper( - subsequent_indent=" ", initial_indent=" ", break_long_words=False - ) - importing_symbols = set(self.names) - importing_symbols.update(self.imports.get_symbols()) - importing_symbols.update(self.declarations.get_symbols()) - joined_names = ", ".join(str(n) for n in importing_symbols) - symbols = f"from {modname} import {joined_names}" - if len(symbols) > 80: - wrapped_names = "\n".join(txtwrapper.wrap(joined_names)) - symbols = f"from {modname} import (\n{wrapped_names}\n)" - print(symbols, file=output) + print(self._make_friendly_module_import_part(modname), file=output) print(file=output) print(file=output) - quoted_names = ", ".join(repr(str(n)) for n in self.names) - dunder_all = f"__all__ = [{quoted_names}]" - if len(dunder_all) > 80: - wrapped_quoted_names = "\n".join(txtwrapper.wrap(quoted_names)) - dunder_all = f"__all__ = [\n{wrapped_quoted_names}\n]" - print(dunder_all, file=output) + print(self.enums.getvalue(), file=output) + print(file=output) + print(file=output) + enum_aliases = self.enum_aliases + if enum_aliases: + for k, v in enum_aliases.items(): + print(f"{k} = {v}", file=output) + print(file=output) + print(file=output) + print(self._make_dunder_all_part(), file=output) return output.getvalue() + def _make_dunder_all_part(self) -> str: + joined_names = ", ".join(repr(str(n)) for n in self.names) + dunder_all = f"__all__ = [{joined_names}]" + if len(dunder_all) > 80: + txtwrapper = textwrap.TextWrapper( + subsequent_indent=" ", initial_indent=" ", break_long_words=False + ) + joined_names = "\n".join(txtwrapper.wrap(joined_names)) + dunder_all = f"__all__ = [\n{joined_names}\n]" + return dunder_all + + def _make_friendly_module_import_part(self, modname: str) -> str: + # The `modname` is the wrapper module name like `comtypes.gen._xxxx..._x_x_x` + txtwrapper = textwrap.TextWrapper( + subsequent_indent=" ", initial_indent=" ", break_long_words=False + ) + symbols = set(self.names) + symbols.update(self.imports.get_symbols()) + symbols.update(self.declarations.get_symbols()) + symbols -= set(self.enums.get_symbols()) + symbols -= set(self.enum_aliases) + joined_names = ", ".join(str(n) for n in symbols) + part = f"from {modname} import {joined_names}" + if len(part) > 80: + txtwrapper = textwrap.TextWrapper( + subsequent_indent=" ", initial_indent=" ", break_long_words=False + ) + joined_names = "\n".join(txtwrapper.wrap(joined_names)) + part = f"from {modname} import (\n{joined_names}\n)" + return part + + @property + def enum_aliases(self) -> Dict[str, str]: + return {k: v for k, v in self.aliases.items() if v in self.enums} + def need_VARIANT_imports(self, value): text = repr(value) if "Decimal(" in text: @@ -645,6 +670,8 @@ def EnumValue(self, tp: typedesc.EnumValue) -> None: print("# Fixing keyword as EnumValue for %s" % tp.name) tp_name = self._to_type_name(tp) print("%s = %d" % (tp_name, value), file=self.stream) + if tp.enumeration.name: + self.enums.add(tp.enumeration.name, tp_name) self.names.add(tp_name) def Enumeration(self, tp: typedesc.Enumeration) -> None: @@ -653,10 +680,6 @@ def Enumeration(self, tp: typedesc.Enumeration) -> None: print("# values for enumeration '%s'" % tp.name, file=self.stream) else: print("# values for unnamed enumeration", file=self.stream) - # Some enumerations have the same name for the enum type - # and an enum value. Excel's XlDisplayShapes is such an example. - # Since we don't have separate namespaces for the type and the values, - # we generate the TYPE last, overwriting the value. XXX for item in tp.values: self.generate(item) if tp.name: @@ -675,6 +698,7 @@ def Typedef(self, tp: typedesc.Typedef) -> None: self.declarations.add(tp.name, definition) else: print("%s = %s" % (tp.name, definition), file=self.stream) + self.aliases[tp.name] = definition self.last_item_class = False self.names.add(tp.name) @@ -1518,3 +1542,45 @@ def getvalue(self): code = code + " # %s" % comment lines.append(code) return "\n".join(lines) + + +class EnumerationNamespaces(object): + def __init__(self): + self.data: Dict[str, List[str]] = {} + + def add(self, enum_name: str, member_name: str) -> None: + """Adds a namespace will be enumeration and its member. + + Examples: + >>> enums = EnumerationNamespaces() + >>> enums.add('Foo', 'ham') + >>> enums.add('Foo', 'spam') + >>> enums.add('Bar', 'bacon') + >>> assert 'Foo' in enums + >>> assert 'Baz' not in enums + >>> print(enums.getvalue()) # is necessary for doctest + class Foo(IntFlag): + ham = __wrapper_module__.ham + spam = __wrapper_module__.spam + + + class Bar(IntFlag): + bacon = __wrapper_module__.bacon + """ + self.data.setdefault(enum_name, []).append(member_name) + + def __contains__(self, item: str) -> bool: + return item in self.data + + def get_symbols(self) -> Set[str]: + return set(self.data) + + def getvalue(self) -> str: + blocks = [] + for enum_name, enum_members in self.data.items(): + lines = [] + lines.append(f"class {enum_name}(IntFlag):") + for member_name in enum_members: + lines.append(f" {member_name} = __wrapper_module__.{member_name}") + blocks.append("\n".join(lines)) + return "\n\n\n".join(blocks)