From 8d2ef20caee01973f4c5a9381432f8fb409572ac Mon Sep 17 00:00:00 2001 From: junkmd Date: Sun, 11 Feb 2024 00:35:23 +0000 Subject: [PATCH] refactor the `codegenerator` --- comtypes/tools/codegenerator.py | 75 +++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 31 deletions(-) diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index 403faaa71..19256d435 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -573,14 +573,7 @@ def generate_wrapper_code( print(self.declarations.getvalue(), file=output) print(file=output) print(self.stream.getvalue(), file=output) - names = ", ".join(repr(str(n)) for n in self.names) - dunder_all = "__all__ = [%s]" % names - if len(dunder_all) > 80: - wrapper = textwrap.TextWrapper( - subsequent_indent=" ", initial_indent=" ", break_long_words=False - ) - dunder_all = "__all__ = [\n%s\n]" % "\n".join(wrapper.wrap(names)) - print(dunder_all, file=output) + print(self._make_dunder_all_part(), file=output) print(file=output) if tlib_mtime is not None: print("_check_version(%r, %f)" % (version, tlib_mtime), file=output) @@ -604,37 +597,57 @@ def generate_friendly_code(self, modname: str) -> str: wrapper_module_name = modname.split(".")[-1] print("from enum import IntFlag", file=output) print(file=output) - importing_symbols = set(self.names) - importing_symbols -= set(self.enums.get_symbols()) - enum_aliases = {k: v for k, v in self.aliases.items() if v in self.enums} - importing_symbols -= set(enum_aliases) - importing_symbols.update(self.imports.get_symbols()) - importing_symbols.update(self.declarations.get_symbols()) - joined_names = ", ".join(str(n) for n in importing_symbols) - symbols = f"from {modname} import {joined_names}" - if len(symbols) > 80: - wrapped_names = "\n".join(txtwrapper.wrap(joined_names)) - symbols = f"from {modname} import (\n{wrapped_names}\n)" print(f"from comtypes.gen import {wrapper_module_name}", file=output) - print(symbols, file=output) + print(self._make_friendly_module_import_part(modname), file=output) print(file=output) print(file=output) print(self.enums.getvalue(wrapper_module_name), file=output) + print(file=output) + print(file=output) + enum_aliases = self.enum_aliases if enum_aliases: for k, v in enum_aliases.items(): print(f"{k} = {v}", file=output) - else: - print("# no alias for enumerations", file=output) - print(file=output) - print(file=output) - quoted_names = ", ".join(repr(str(n)) for n in self.names) - dunder_all = f"__all__ = [{quoted_names}]" - if len(dunder_all) > 80: - wrapped_quoted_names = "\n".join(txtwrapper.wrap(quoted_names)) - dunder_all = f"__all__ = [\n{wrapped_quoted_names}\n]" - print(dunder_all, file=output) + print(file=output) + print(file=output) + print(self._make_dunder_all_part(), file=output) return output.getvalue() + def _make_dunder_all_part(self) -> str: + joined_names = ", ".join(repr(str(n)) for n in self.names) + dunder_all = f"__all__ = [{joined_names}]" + if len(dunder_all) > 80: + txtwrapper = textwrap.TextWrapper( + subsequent_indent=" ", initial_indent=" ", break_long_words=False + ) + joined_names = "\n".join(txtwrapper.wrap(joined_names)) + dunder_all = f"__all__ = [\n{joined_names}\n]" + return dunder_all + + def _make_friendly_module_import_part(self, modname: str) -> str: + # The `modname` is the wrapper module name like `comtypes.gen._xxxx..._x_x_x` + txtwrapper = textwrap.TextWrapper( + subsequent_indent=" ", initial_indent=" ", break_long_words=False + ) + symbols = set(self.names) + symbols.update(self.imports.get_symbols()) + symbols.update(self.declarations.get_symbols()) + symbols -= set(self.enums.get_symbols()) + symbols -= set(self.enum_aliases) + joined_names = ", ".join(str(n) for n in symbols) + part = f"from {modname} import {joined_names}" + if len(part) > 80: + txtwrapper = textwrap.TextWrapper( + subsequent_indent=" ", initial_indent=" ", break_long_words=False + ) + joined_names = "\n".join(txtwrapper.wrap(joined_names)) + part = f"from {modname} import (\n{joined_names}\n)" + return part + + @property + def enum_aliases(self) -> Dict[str, str]: + return {k: v for k, v in self.aliases.items() if v in self.enums} + def need_VARIANT_imports(self, value): text = repr(value) if "Decimal(" in text: @@ -1575,4 +1588,4 @@ def getvalue(self, wrapper_module_name: str) -> str: ref = f"{wrapper_module_name}.{member_name}" lines.append(f" {member_name} = {ref}") blocks.append("\n".join(lines)) - return "\n\n".join(blocks) + return "\n\n\n".join(blocks)