Skip to content

Commit

Permalink
refactor the codegenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
junkmd committed Feb 11, 2024
1 parent a288a0b commit 8d2ef20
Showing 1 changed file with 44 additions and 31 deletions.
75 changes: 44 additions & 31 deletions comtypes/tools/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,14 +573,7 @@ def generate_wrapper_code(
print(self.declarations.getvalue(), file=output)
print(file=output)
print(self.stream.getvalue(), file=output)
names = ", ".join(repr(str(n)) for n in self.names)
dunder_all = "__all__ = [%s]" % names
if len(dunder_all) > 80:
wrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
dunder_all = "__all__ = [\n%s\n]" % "\n".join(wrapper.wrap(names))
print(dunder_all, file=output)
print(self._make_dunder_all_part(), file=output)
print(file=output)
if tlib_mtime is not None:
print("_check_version(%r, %f)" % (version, tlib_mtime), file=output)
Expand All @@ -604,37 +597,57 @@ def generate_friendly_code(self, modname: str) -> str:
wrapper_module_name = modname.split(".")[-1]
print("from enum import IntFlag", file=output)
print(file=output)
importing_symbols = set(self.names)
importing_symbols -= set(self.enums.get_symbols())
enum_aliases = {k: v for k, v in self.aliases.items() if v in self.enums}
importing_symbols -= set(enum_aliases)
importing_symbols.update(self.imports.get_symbols())
importing_symbols.update(self.declarations.get_symbols())
joined_names = ", ".join(str(n) for n in importing_symbols)
symbols = f"from {modname} import {joined_names}"
if len(symbols) > 80:
wrapped_names = "\n".join(txtwrapper.wrap(joined_names))
symbols = f"from {modname} import (\n{wrapped_names}\n)"
print(f"from comtypes.gen import {wrapper_module_name}", file=output)
print(symbols, file=output)
print(self._make_friendly_module_import_part(modname), file=output)
print(file=output)
print(file=output)
print(self.enums.getvalue(wrapper_module_name), file=output)
print(file=output)
print(file=output)
enum_aliases = self.enum_aliases
if enum_aliases:
for k, v in enum_aliases.items():
print(f"{k} = {v}", file=output)
else:
print("# no alias for enumerations", file=output)
print(file=output)
print(file=output)
quoted_names = ", ".join(repr(str(n)) for n in self.names)
dunder_all = f"__all__ = [{quoted_names}]"
if len(dunder_all) > 80:
wrapped_quoted_names = "\n".join(txtwrapper.wrap(quoted_names))
dunder_all = f"__all__ = [\n{wrapped_quoted_names}\n]"
print(dunder_all, file=output)
print(file=output)
print(file=output)
print(self._make_dunder_all_part(), file=output)
return output.getvalue()

def _make_dunder_all_part(self) -> str:
joined_names = ", ".join(repr(str(n)) for n in self.names)
dunder_all = f"__all__ = [{joined_names}]"
if len(dunder_all) > 80:
txtwrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
joined_names = "\n".join(txtwrapper.wrap(joined_names))
dunder_all = f"__all__ = [\n{joined_names}\n]"
return dunder_all

def _make_friendly_module_import_part(self, modname: str) -> str:
# The `modname` is the wrapper module name like `comtypes.gen._xxxx..._x_x_x`
txtwrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
symbols = set(self.names)
symbols.update(self.imports.get_symbols())
symbols.update(self.declarations.get_symbols())
symbols -= set(self.enums.get_symbols())
symbols -= set(self.enum_aliases)
joined_names = ", ".join(str(n) for n in symbols)
part = f"from {modname} import {joined_names}"
if len(part) > 80:
txtwrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
joined_names = "\n".join(txtwrapper.wrap(joined_names))
part = f"from {modname} import (\n{joined_names}\n)"
return part

@property
def enum_aliases(self) -> Dict[str, str]:
return {k: v for k, v in self.aliases.items() if v in self.enums}

def need_VARIANT_imports(self, value):
text = repr(value)
if "Decimal(" in text:
Expand Down Expand Up @@ -1575,4 +1588,4 @@ def getvalue(self, wrapper_module_name: str) -> str:
ref = f"{wrapper_module_name}.{member_name}"
lines.append(f" {member_name} = {ref}")
blocks.append("\n".join(lines))
return "\n\n".join(blocks)
return "\n\n\n".join(blocks)

0 comments on commit 8d2ef20

Please sign in to comment.