diff --git a/comtypes/test/test_client.py b/comtypes/test/test_client.py index 6ab7468c..bcf0195e 100644 --- a/comtypes/test/test_client.py +++ b/comtypes/test/test_client.py @@ -82,7 +82,12 @@ def test_the_name_of_the_enum_member_and_the_coclass_are_duplicated(self): # the definition of an enumeration, the generation of the module will fail. # See also https://github.com/enthought/comtypes/issues/524 with contextlib.redirect_stdout(None): # supress warnings - comtypes.client.GetModule("mshtml.tlb") + mshtml = comtypes.client.GetModule("mshtml.tlb") + # When the member of an enumeration and a CoClass have the same name, + # the defined later one is assigned to the name in the module. + # By asserting whether the CoClass is assigned to that name, it ensures + # that the member of the enumeration is defined earlier. + self.assertTrue(issubclass(mshtml.htmlInputImage, comtypes.CoClass)) def test_abstracted_wrapper_module_in_friendly_module(self): mod = comtypes.client.GetModule("scrrun.dll") diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index edc3362f..f16ca607 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -424,6 +424,7 @@ def __init__(self, known_symbols=None): self.imports = ImportedNamespaces() self.declarations = DeclaredNamespaces() self.enums = EnumerationNamespaces() + self.unnamed_enum_members: List[Tuple[str, int]] = [] self._to_type_name = TypeNamer() self.known_symbols = known_symbols or {} @@ -572,6 +573,12 @@ def generate_wrapper_code( print(file=output) print(self.declarations.getvalue(), file=output) print(file=output) + if self.unnamed_enum_members: + print("# values for unnamed enumeration", file=output) + for n, v in self.unnamed_enum_members: + print(f"{n} = {v}", file=output) + print(file=output) + print(self.enums.to_constants(), file=output) print(self.stream.getvalue(), file=output) print(self._make_dunder_all_part(), file=output) print(file=output) @@ -596,7 +603,7 @@ def generate_friendly_code(self, modname: str) -> str: print(self._make_friendly_module_import_part(modname), file=output) print(file=output) print(file=output) - print(self.enums.getvalue(), file=output) + print(self.enums.to_intflags(), file=output) print(file=output) print(file=output) enum_aliases = self.enum_aliases @@ -669,21 +676,17 @@ def EnumValue(self, tp: typedesc.EnumValue) -> None: if __warn_on_munge__: print("# Fixing keyword as EnumValue for %s" % tp.name) tp_name = self._to_type_name(tp) - print("%s = %d" % (tp_name, value), file=self.stream) if tp.enumeration.name: self.enums.add(tp.enumeration.name, tp_name, value) + else: + self.unnamed_enum_members.append((tp_name, value)) self.names.add(tp_name) def Enumeration(self, tp: typedesc.Enumeration) -> None: self.last_item_class = False - if tp.name: - print("# values for enumeration '%s'" % tp.name, file=self.stream) - else: - print("# values for unnamed enumeration", file=self.stream) for item in tp.values: self.generate(item) if tp.name: - print("%s = c_int # enum" % tp.name, file=self.stream) self.names.add(tp.name) def Typedef(self, tp: typedesc.Typedef) -> None: @@ -1552,13 +1555,14 @@ def add(self, enum_name: str, member_name: str, value: int) -> None: """Adds a namespace will be enumeration and its member. Examples: + is necessary for doctest >>> enums = EnumerationNamespaces() >>> enums.add('Foo', 'ham', 1) >>> enums.add('Foo', 'spam', 2) >>> enums.add('Bar', 'bacon', 3) >>> assert 'Foo' in enums >>> assert 'Baz' not in enums - >>> print(enums.getvalue()) # is necessary for doctest + >>> print(enums.to_intflags()) class Foo(IntFlag): ham = 1 spam = 2 @@ -1566,6 +1570,15 @@ class Foo(IntFlag): class Bar(IntFlag): bacon = 3 + >>> print(enums.to_constants()) + # values for enumeration 'Foo' + ham = 1 + spam = 2 + Foo = c_int # enum + + # values for enumeration 'Bar' + bacon = 3 + Bar = c_int # enum """ self.data.setdefault(enum_name, []).append((member_name, value)) @@ -1575,7 +1588,18 @@ def __contains__(self, item: str) -> bool: def get_symbols(self) -> Set[str]: return set(self.data) - def getvalue(self) -> str: + def to_constants(self) -> str: + blocks = [] + for enum_name, enum_members in self.data.items(): + lines = [] + lines.append(f"# values for enumeration '{enum_name}'") + for n, v in enum_members: + lines.append(f"{n} = {v}") + lines.append(f"{enum_name} = c_int # enum") + blocks.append("\n".join(lines)) + return "\n\n".join(blocks) + + def to_intflags(self) -> str: blocks = [] for enum_name, enum_members in self.data.items(): lines = []