From 7309fb57af13ab0a1fc7a66dcba55de7665c8361 Mon Sep 17 00:00:00 2001 From: junkmd Date: Mon, 26 Feb 2024 09:01:20 +0900 Subject: [PATCH 1/5] add test for enumerations in friendly module --- comtypes/test/test_client.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/comtypes/test/test_client.py b/comtypes/test/test_client.py index d44bfd22..629ed4d3 100644 --- a/comtypes/test/test_client.py +++ b/comtypes/test/test_client.py @@ -229,6 +229,15 @@ def test_progid(self): self.assertEqual(consts.TextCompare, Scripting.TextCompare) self.assertEqual(consts.DatabaseCompare, Scripting.DatabaseCompare) + def test_enums_in_friendly_mod(self): + consts = comtypes.client.Constants("scrrun.dll") + comtypes.client.GetModule("scrrun.dll") + from comtypes.gen import Scripting + + for e in Scripting.StandardStreamTypes: + self.assertIn(e.name, consts.StandardStreamTypes) + self.assertEqual(consts.StandardStreamTypes[e.name], e.value) + def test_returns_other_than_enum_members(self): obj = comtypes.client.CreateObject("SAPI.SpVoice") from comtypes.gen import SpeechLib as sapi From ffde27fe75968c02aa760381c9a1046f176fa5e1 Mon Sep 17 00:00:00 2001 From: junkmd Date: Mon, 26 Feb 2024 09:01:20 +0900 Subject: [PATCH 2/5] add enums generation processes --- comtypes/tools/codegenerator.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index 82c3bcc3..f683f003 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -429,6 +429,8 @@ def __init__(self, known_symbols=None): self.done = set() # type descriptions that have been generated self.names = set() # names that have been generated self.externals = [] # typelibs imported to generated module + self.enums: Dict[str, List[str]] = {} + self.aliases: Dict[str, str] = {} self.last_item_class = False def generate(self, item): @@ -599,7 +601,12 @@ def generate_friendly_code(self, modname: str) -> str: txtwrapper = textwrap.TextWrapper( subsequent_indent=" ", initial_indent=" ", break_long_words=False ) + print("from enum import IntFlag", file=output) + print(file=output) importing_symbols = set(self.names) + importing_symbols -= set(self.enums) + enum_aliases = {k: v for k, v in self.aliases.items() if v in self.enums} + importing_symbols -= set(enum_aliases) importing_symbols.update(self.imports.get_symbols()) importing_symbols.update(self.declarations.get_symbols()) joined_names = ", ".join(str(n) for n in importing_symbols) @@ -610,6 +617,19 @@ def generate_friendly_code(self, modname: str) -> str: print(symbols, file=output) print(file=output) print(file=output) + for enum_name, enum_members in self.enums.items(): + print(f"class {enum_name}(IntFlag):", file=output) + for m_name in enum_members: + print(f" {m_name} = __wrapper_module__.{m_name}", file=output) + print(file=output) + print(file=output) + if enum_aliases: + for k, v in enum_aliases.items(): + print(f"{k} = {v}", file=output) + else: + print("# no alias for enumerations", file=output) + print(file=output) + print(file=output) quoted_names = ", ".join(repr(str(n)) for n in self.names) dunder_all = f"__all__ = [{quoted_names}]" if len(dunder_all) > 80: @@ -645,6 +665,7 @@ def EnumValue(self, tp: typedesc.EnumValue) -> None: print("# Fixing keyword as EnumValue for %s" % tp.name) tp_name = self._to_type_name(tp) print("%s = %d" % (tp_name, value), file=self.stream) + self.enums.setdefault(tp.enumeration.name, []).append(tp_name) self.names.add(tp_name) def Enumeration(self, tp: typedesc.Enumeration) -> None: @@ -653,10 +674,6 @@ def Enumeration(self, tp: typedesc.Enumeration) -> None: print("# values for enumeration '%s'" % tp.name, file=self.stream) else: print("# values for unnamed enumeration", file=self.stream) - # Some enumerations have the same name for the enum type - # and an enum value. Excel's XlDisplayShapes is such an example. - # Since we don't have separate namespaces for the type and the values, - # we generate the TYPE last, overwriting the value. XXX for item in tp.values: self.generate(item) if tp.name: @@ -675,6 +692,7 @@ def Typedef(self, tp: typedesc.Typedef) -> None: self.declarations.add(tp.name, definition) else: print("%s = %s" % (tp.name, definition), file=self.stream) + self.aliases[tp.name] = definition self.last_item_class = False self.names.add(tp.name) From 6fea40eff1728afc9ef4b6f78c4c10e64355e4b9 Mon Sep 17 00:00:00 2001 From: junkmd Date: Mon, 26 Feb 2024 09:01:20 +0900 Subject: [PATCH 3/5] add the sprout class `EnumerationNamespaces` --- comtypes/tools/codegenerator.py | 56 +++++++++++++++++++++++++++------ 1 file changed, 47 insertions(+), 9 deletions(-) diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index f683f003..59bfe759 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -423,13 +423,13 @@ def __init__(self, known_symbols=None): self.stream = io.StringIO() self.imports = ImportedNamespaces() self.declarations = DeclaredNamespaces() + self.enums = EnumerationNamespaces() self._to_type_name = TypeNamer() self.known_symbols = known_symbols or {} self.done = set() # type descriptions that have been generated self.names = set() # names that have been generated self.externals = [] # typelibs imported to generated module - self.enums: Dict[str, List[str]] = {} self.aliases: Dict[str, str] = {} self.last_item_class = False @@ -604,7 +604,7 @@ def generate_friendly_code(self, modname: str) -> str: print("from enum import IntFlag", file=output) print(file=output) importing_symbols = set(self.names) - importing_symbols -= set(self.enums) + importing_symbols -= set(self.enums.get_symbols()) enum_aliases = {k: v for k, v in self.aliases.items() if v in self.enums} importing_symbols -= set(enum_aliases) importing_symbols.update(self.imports.get_symbols()) @@ -617,12 +617,7 @@ def generate_friendly_code(self, modname: str) -> str: print(symbols, file=output) print(file=output) print(file=output) - for enum_name, enum_members in self.enums.items(): - print(f"class {enum_name}(IntFlag):", file=output) - for m_name in enum_members: - print(f" {m_name} = __wrapper_module__.{m_name}", file=output) - print(file=output) - print(file=output) + print(self.enums.getvalue("__wrapper_module__"), file=output) if enum_aliases: for k, v in enum_aliases.items(): print(f"{k} = {v}", file=output) @@ -665,7 +660,8 @@ def EnumValue(self, tp: typedesc.EnumValue) -> None: print("# Fixing keyword as EnumValue for %s" % tp.name) tp_name = self._to_type_name(tp) print("%s = %d" % (tp_name, value), file=self.stream) - self.enums.setdefault(tp.enumeration.name, []).append(tp_name) + if tp.enumeration.name: + self.enums.add(tp.enumeration.name, tp_name) self.names.add(tp_name) def Enumeration(self, tp: typedesc.Enumeration) -> None: @@ -1536,3 +1532,45 @@ def getvalue(self): code = code + " # %s" % comment lines.append(code) return "\n".join(lines) + + +class EnumerationNamespaces(object): + def __init__(self): + self.data: Dict[str, List[str]] = {} + + def add(self, enum_name: str, member_name: str) -> None: + """Adds a namespace will be enumeration and its member. + + Examples: + >>> enums = EnumerationNamespaces() + >>> enums.add('Foo', 'ham') + >>> enums.add('Foo', 'spam') + >>> enums.add('Bar', 'bacon') + >>> assert 'Foo' in enums + >>> assert 'Baz' not in enums + >>> print(enums.getvalue('_0123')) # is necessary for doctest + class Foo(IntFlag): + ham = _0123.ham + spam = _0123.spam + + class Bar(IntFlag): + bacon = _0123.bacon + """ + self.data.setdefault(enum_name, []).append(member_name) + + def __contains__(self, item: str) -> bool: + return item in self.data + + def get_symbols(self) -> Set[str]: + return set(self.data) + + def getvalue(self, wrapper_module_name: str) -> str: + blocks = [] + for enum_name, enum_members in self.data.items(): + lines = [] + lines.append(f"class {enum_name}(IntFlag):") + for member_name in enum_members: + ref = f"{wrapper_module_name}.{member_name}" + lines.append(f" {member_name} = {ref}") + blocks.append("\n".join(lines)) + return "\n\n".join(blocks) From 1d8ad6ba24eab131b71f1c2f8a2eca849adb0378 Mon Sep 17 00:00:00 2001 From: junkmd Date: Mon, 26 Feb 2024 09:01:21 +0900 Subject: [PATCH 4/5] refactor the `codegenerator` --- comtypes/tools/codegenerator.py | 80 ++++++++++++++++++--------------- 1 file changed, 45 insertions(+), 35 deletions(-) diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index 59bfe759..84abf967 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -573,14 +573,7 @@ def generate_wrapper_code( print(self.declarations.getvalue(), file=output) print(file=output) print(self.stream.getvalue(), file=output) - names = ", ".join(repr(str(n)) for n in self.names) - dunder_all = "__all__ = [%s]" % names - if len(dunder_all) > 80: - wrapper = textwrap.TextWrapper( - subsequent_indent=" ", initial_indent=" ", break_long_words=False - ) - dunder_all = "__all__ = [\n%s\n]" % "\n".join(wrapper.wrap(names)) - print(dunder_all, file=output) + print(self._make_dunder_all_part(), file=output) print(file=output) if tlib_mtime is not None: print("_check_version(%r, %f)" % (version, tlib_mtime), file=output) @@ -597,42 +590,59 @@ def generate_friendly_code(self, modname: str) -> str: Such as "comtypes.gen.stdole" and "comtypes.gen.Excel". """ output = io.StringIO() - print(f"import {modname} as __wrapper_module__", file=output) - txtwrapper = textwrap.TextWrapper( - subsequent_indent=" ", initial_indent=" ", break_long_words=False - ) print("from enum import IntFlag", file=output) print(file=output) - importing_symbols = set(self.names) - importing_symbols -= set(self.enums.get_symbols()) - enum_aliases = {k: v for k, v in self.aliases.items() if v in self.enums} - importing_symbols -= set(enum_aliases) - importing_symbols.update(self.imports.get_symbols()) - importing_symbols.update(self.declarations.get_symbols()) - joined_names = ", ".join(str(n) for n in importing_symbols) - symbols = f"from {modname} import {joined_names}" - if len(symbols) > 80: - wrapped_names = "\n".join(txtwrapper.wrap(joined_names)) - symbols = f"from {modname} import (\n{wrapped_names}\n)" - print(symbols, file=output) + print(f"import {modname} as __wrapper_module__", file=output) + print(self._make_friendly_module_import_part(modname), file=output) print(file=output) print(file=output) print(self.enums.getvalue("__wrapper_module__"), file=output) + print(file=output) + print(file=output) + enum_aliases = self.enum_aliases if enum_aliases: for k, v in enum_aliases.items(): print(f"{k} = {v}", file=output) - else: - print("# no alias for enumerations", file=output) - print(file=output) - print(file=output) - quoted_names = ", ".join(repr(str(n)) for n in self.names) - dunder_all = f"__all__ = [{quoted_names}]" - if len(dunder_all) > 80: - wrapped_quoted_names = "\n".join(txtwrapper.wrap(quoted_names)) - dunder_all = f"__all__ = [\n{wrapped_quoted_names}\n]" - print(dunder_all, file=output) + print(file=output) + print(file=output) + print(self._make_dunder_all_part(), file=output) return output.getvalue() + def _make_dunder_all_part(self) -> str: + joined_names = ", ".join(repr(str(n)) for n in self.names) + dunder_all = f"__all__ = [{joined_names}]" + if len(dunder_all) > 80: + txtwrapper = textwrap.TextWrapper( + subsequent_indent=" ", initial_indent=" ", break_long_words=False + ) + joined_names = "\n".join(txtwrapper.wrap(joined_names)) + dunder_all = f"__all__ = [\n{joined_names}\n]" + return dunder_all + + def _make_friendly_module_import_part(self, modname: str) -> str: + # The `modname` is the wrapper module name like `comtypes.gen._xxxx..._x_x_x` + txtwrapper = textwrap.TextWrapper( + subsequent_indent=" ", initial_indent=" ", break_long_words=False + ) + symbols = set(self.names) + symbols.update(self.imports.get_symbols()) + symbols.update(self.declarations.get_symbols()) + symbols -= set(self.enums.get_symbols()) + symbols -= set(self.enum_aliases) + joined_names = ", ".join(str(n) for n in symbols) + part = f"from {modname} import {joined_names}" + if len(part) > 80: + txtwrapper = textwrap.TextWrapper( + subsequent_indent=" ", initial_indent=" ", break_long_words=False + ) + joined_names = "\n".join(txtwrapper.wrap(joined_names)) + part = f"from {modname} import (\n{joined_names}\n)" + return part + + @property + def enum_aliases(self) -> Dict[str, str]: + return {k: v for k, v in self.aliases.items() if v in self.enums} + def need_VARIANT_imports(self, value): text = repr(value) if "Decimal(" in text: @@ -1573,4 +1583,4 @@ def getvalue(self, wrapper_module_name: str) -> str: ref = f"{wrapper_module_name}.{member_name}" lines.append(f" {member_name} = {ref}") blocks.append("\n".join(lines)) - return "\n\n".join(blocks) + return "\n\n\n".join(blocks) From 12b2c1f71d0ece608a0c687bc5743f77d57bdafd Mon Sep 17 00:00:00 2001 From: junkmd Date: Mon, 26 Feb 2024 09:01:21 +0900 Subject: [PATCH 5/5] remove `wrapper_module_name` arg from `EnumerationNamespaces.getvalue` --- comtypes/tools/codegenerator.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index 84abf967..7540edc5 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -596,7 +596,7 @@ def generate_friendly_code(self, modname: str) -> str: print(self._make_friendly_module_import_part(modname), file=output) print(file=output) print(file=output) - print(self.enums.getvalue("__wrapper_module__"), file=output) + print(self.enums.getvalue(), file=output) print(file=output) print(file=output) enum_aliases = self.enum_aliases @@ -1558,13 +1558,14 @@ def add(self, enum_name: str, member_name: str) -> None: >>> enums.add('Bar', 'bacon') >>> assert 'Foo' in enums >>> assert 'Baz' not in enums - >>> print(enums.getvalue('_0123')) # is necessary for doctest + >>> print(enums.getvalue()) # is necessary for doctest class Foo(IntFlag): - ham = _0123.ham - spam = _0123.spam + ham = __wrapper_module__.ham + spam = __wrapper_module__.spam + class Bar(IntFlag): - bacon = _0123.bacon + bacon = __wrapper_module__.bacon """ self.data.setdefault(enum_name, []).append(member_name) @@ -1574,13 +1575,12 @@ def __contains__(self, item: str) -> bool: def get_symbols(self) -> Set[str]: return set(self.data) - def getvalue(self, wrapper_module_name: str) -> str: + def getvalue(self) -> str: blocks = [] for enum_name, enum_members in self.data.items(): lines = [] lines.append(f"class {enum_name}(IntFlag):") for member_name in enum_members: - ref = f"{wrapper_module_name}.{member_name}" - lines.append(f" {member_name} = {ref}") + lines.append(f" {member_name} = __wrapper_module__.{member_name}") blocks.append("\n".join(lines)) return "\n\n\n".join(blocks)