Skip to content

Commit

Permalink
Define constants and c_int aliases earlier than others in wrapper m…
Browse files Browse the repository at this point in the history
…odules (#527)
  • Loading branch information
junkmd authored Apr 15, 2024
1 parent 040152f commit 7fc1cce
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 10 deletions.
7 changes: 6 additions & 1 deletion comtypes/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@ def test_the_name_of_the_enum_member_and_the_coclass_are_duplicated(self):
# the definition of an enumeration, the generation of the module will fail.
# See also https://github.com/enthought/comtypes/issues/524
with contextlib.redirect_stdout(None): # supress warnings
comtypes.client.GetModule("mshtml.tlb")
mshtml = comtypes.client.GetModule("mshtml.tlb")
# When the member of an enumeration and a CoClass have the same name,
# the defined later one is assigned to the name in the module.
# By asserting whether the CoClass is assigned to that name, it ensures
# that the member of the enumeration is defined earlier.
self.assertTrue(issubclass(mshtml.htmlInputImage, comtypes.CoClass))

def test_abstracted_wrapper_module_in_friendly_module(self):
mod = comtypes.client.GetModule("scrrun.dll")
Expand Down
42 changes: 33 additions & 9 deletions comtypes/tools/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ def __init__(self, known_symbols=None):
self.imports = ImportedNamespaces()
self.declarations = DeclaredNamespaces()
self.enums = EnumerationNamespaces()
self.unnamed_enum_members: List[Tuple[str, int]] = []
self._to_type_name = TypeNamer()
self.known_symbols = known_symbols or {}

Expand Down Expand Up @@ -572,6 +573,12 @@ def generate_wrapper_code(
print(file=output)
print(self.declarations.getvalue(), file=output)
print(file=output)
if self.unnamed_enum_members:
print("# values for unnamed enumeration", file=output)
for n, v in self.unnamed_enum_members:
print(f"{n} = {v}", file=output)
print(file=output)
print(self.enums.to_constants(), file=output)
print(self.stream.getvalue(), file=output)
print(self._make_dunder_all_part(), file=output)
print(file=output)
Expand All @@ -596,7 +603,7 @@ def generate_friendly_code(self, modname: str) -> str:
print(self._make_friendly_module_import_part(modname), file=output)
print(file=output)
print(file=output)
print(self.enums.getvalue(), file=output)
print(self.enums.to_intflags(), file=output)
print(file=output)
print(file=output)
enum_aliases = self.enum_aliases
Expand Down Expand Up @@ -669,21 +676,17 @@ def EnumValue(self, tp: typedesc.EnumValue) -> None:
if __warn_on_munge__:
print("# Fixing keyword as EnumValue for %s" % tp.name)
tp_name = self._to_type_name(tp)
print("%s = %d" % (tp_name, value), file=self.stream)
if tp.enumeration.name:
self.enums.add(tp.enumeration.name, tp_name, value)
else:
self.unnamed_enum_members.append((tp_name, value))
self.names.add(tp_name)

def Enumeration(self, tp: typedesc.Enumeration) -> None:
self.last_item_class = False
if tp.name:
print("# values for enumeration '%s'" % tp.name, file=self.stream)
else:
print("# values for unnamed enumeration", file=self.stream)
for item in tp.values:
self.generate(item)
if tp.name:
print("%s = c_int # enum" % tp.name, file=self.stream)
self.names.add(tp.name)

def Typedef(self, tp: typedesc.Typedef) -> None:
Expand Down Expand Up @@ -1552,20 +1555,30 @@ def add(self, enum_name: str, member_name: str, value: int) -> None:
"""Adds a namespace will be enumeration and its member.
Examples:
<BLANKLINE> is necessary for doctest
>>> enums = EnumerationNamespaces()
>>> enums.add('Foo', 'ham', 1)
>>> enums.add('Foo', 'spam', 2)
>>> enums.add('Bar', 'bacon', 3)
>>> assert 'Foo' in enums
>>> assert 'Baz' not in enums
>>> print(enums.getvalue()) # <BLANKLINE> is necessary for doctest
>>> print(enums.to_intflags())
class Foo(IntFlag):
ham = 1
spam = 2
<BLANKLINE>
<BLANKLINE>
class Bar(IntFlag):
bacon = 3
>>> print(enums.to_constants())
# values for enumeration 'Foo'
ham = 1
spam = 2
Foo = c_int # enum
<BLANKLINE>
# values for enumeration 'Bar'
bacon = 3
Bar = c_int # enum
"""
self.data.setdefault(enum_name, []).append((member_name, value))

Expand All @@ -1575,7 +1588,18 @@ def __contains__(self, item: str) -> bool:
def get_symbols(self) -> Set[str]:
return set(self.data)

def getvalue(self) -> str:
def to_constants(self) -> str:
blocks = []
for enum_name, enum_members in self.data.items():
lines = []
lines.append(f"# values for enumeration '{enum_name}'")
for n, v in enum_members:
lines.append(f"{n} = {v}")
lines.append(f"{enum_name} = c_int # enum")
blocks.append("\n".join(lines))
return "\n\n".join(blocks)

def to_intflags(self) -> str:
blocks = []
for enum_name, enum_members in self.data.items():
lines = []
Expand Down

0 comments on commit 7fc1cce

Please sign in to comment.