Skip to content

Commit

Permalink
split helpers module in tools.codegenerator subdirectory package …
Browse files Browse the repository at this point in the history
…part 2 (enthought#549)

* revert `helpers` file name

* remove duplicated stuffs in `codegenerator` package
  • Loading branch information
junkmd authored May 22, 2024
1 parent 9a83390 commit cc071fe
Show file tree
Hide file tree
Showing 5 changed files with 3 additions and 2,017 deletions.
2 changes: 1 addition & 1 deletion comtypes/tools/codegenerator/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from comtypes.tools.codegenerator import namespaces
from comtypes.tools.codegenerator import packing
from comtypes.tools.codegenerator.modulenamer import name_wrapper_module
from comtypes.tools.codegenerator.helpers_ import (
from comtypes.tools.codegenerator.helpers import (
get_real_type,
ASSUME_STRINGS,
ComMethodGenerator,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import keyword
import textwrap
from typing import Any
from typing import Dict, List, Set, Tuple
from typing import List, Tuple
from typing import Iterator
from typing import Union as _UnionT

import comtypes
from comtypes.tools import typedesc
from comtypes.tools.codegenerator.modulenamer import name_wrapper_module


class lcid(object):
Expand Down Expand Up @@ -66,114 +66,9 @@ def get_real_type(tp):
ASSUME_STRINGS = True


def _calc_packing(struct, fields, pack, isStruct):
# Try a certain packing, raise PackingError if field offsets,
# total size ot total alignment is wrong.
if struct.size is None: # incomplete struct
return -1
if struct.name in dont_assert_size:
return None
if struct.bases:
size = struct.bases[0].size
total_align = struct.bases[0].align
else:
size = 0
total_align = 8 # in bits
for i, f in enumerate(fields):
if f.bits: # this code cannot handle bit field sizes.
# print "##XXX FIXME"
return -2 # XXX FIXME
s, a = storage(f.typ)
if pack is not None:
a = min(pack, a)
if size % a:
size += a - size % a
if isStruct:
if size != f.offset:
raise PackingError("field %s offset (%s/%s)" % (f.name, size, f.offset))
size += s
else:
size = max(size, s)
total_align = max(total_align, a)
if total_align != struct.align:
raise PackingError("total alignment (%s/%s)" % (total_align, struct.align))
a = total_align
if pack is not None:
a = min(pack, a)
if size % a:
size += a - size % a
if size != struct.size:
raise PackingError("total size (%s/%s)" % (size, struct.size))


def calc_packing(struct, fields):
# try several packings, starting with unspecified packing
isStruct = isinstance(struct, typedesc.Structure)
for pack in [None, 16 * 8, 8 * 8, 4 * 8, 2 * 8, 1 * 8]:
try:
_calc_packing(struct, fields, pack, isStruct)
except PackingError as details:
continue
else:
if pack is None:
return None

return int(pack / 8)

raise PackingError("PACKING FAILED: %s" % details)


class PackingError(Exception):
pass


# XXX These should be filtered out in gccxmlparser.
dont_assert_size = set(
[
"__si_class_type_info_pseudo",
"__class_type_info_pseudo",
]
)


def storage(t):
# return the size and alignment of a type
if isinstance(t, typedesc.Typedef):
return storage(t.typ)
elif isinstance(t, typedesc.ArrayType):
s, a = storage(t.typ)
return s * (int(t.max) - int(t.min) + 1), a
return int(t.size), int(t.align)


################################################################


def name_wrapper_module(tlib):
"""Determine the name of a typelib wrapper module"""
libattr = tlib.GetLibAttr()
modname = "_%s_%s_%s_%s" % (
str(libattr.guid)[1:-1].replace("-", "_"),
libattr.lcid,
libattr.wMajorVerNum,
libattr.wMinorVerNum,
)
return "comtypes.gen.%s" % modname


def name_friendly_module(tlib):
"""Determine the friendly-name of a typelib module.
If cannot get friendly-name from typelib, returns `None`.
"""
try:
modulename = tlib.GetDocumentation(-1)[0]
except comtypes.COMError:
return
return "comtypes.gen.%s" % modulename


################################################################

_DefValType = _UnionT["lcid", Any, None]
_IdlFlagType = _UnionT[str, dispid, helpstring]

Expand Down Expand Up @@ -460,213 +355,3 @@ def _inspect_PointerType(
if isinstance(t.typ, typedesc.PointerType):
return self._inspect_PointerType(t.typ, count + 1)
return t.typ, count + 1


class ImportedNamespaces(object):
def __init__(self):
self.data = {}

def add(self, name1, name2=None, symbols=None):
"""Adds a namespace will be imported.
Examples:
>>> imports = ImportedNamespaces()
>>> imports.add('datetime')
>>> imports.add('ctypes', '*')
>>> imports.add('decimal', 'Decimal')
>>> imports.add('GUID', symbols={'GUID': 'comtypes'})
>>> for name in ('COMMETHOD', 'DISPMETHOD', 'IUnknown', 'dispid',
... 'CoClass', 'BSTR', 'DISPPROPERTY'):
... imports.add('comtypes', name)
>>> imports.add('ctypes.wintypes')
>>> print(imports.getvalue())
from ctypes import *
import datetime
from decimal import Decimal
from comtypes import (
BSTR, CoClass, COMMETHOD, dispid, DISPMETHOD, DISPPROPERTY, GUID,
IUnknown
)
import ctypes.wintypes
>>> assert imports.get_symbols() == {
... 'Decimal', 'GUID', 'COMMETHOD', 'DISPMETHOD', 'IUnknown',
... 'dispid', 'CoClass', 'BSTR', 'DISPPROPERTY'
... }
"""
if name2 is None:
import_ = name1
if not symbols:
self.data[import_] = None
return
from_ = symbols[import_]
else:
from_, import_ = name1, name2
self.data[import_] = from_

def __contains__(self, item):
"""Returns item has already added.
Examples:
>>> imports = ImportedNamespaces()
>>> imports.add('datetime')
>>> imports.add('ctypes', '*')
>>> 'datetime' in imports
True
>>> ('ctypes', '*') in imports
True
>>> 'os' in imports
False
>>> 'ctypes' in imports
False
>>> ('ctypes', 'c_int') in imports
False
"""
if isinstance(item, tuple):
from_, import_ = item
else:
from_, import_ = None, item
if import_ in self.data:
return self.data[import_] == from_
return False

def get_symbols(self) -> Set[str]:
names = set()
for key, val in self.data.items():
if val is None or key == "*":
continue
names.add(key)
return names

def _make_line(self, from_, imports):
import_ = ", ".join(imports)
code = "from %s import %s" % (from_, import_)
if len(code) <= 80:
return code
wrapper = textwrap.TextWrapper(
subsequent_indent=" ", initial_indent=" ", break_long_words=False
)
import_ = "\n".join(wrapper.wrap(import_))
code = "from %s import (\n%s\n)" % (from_, import_)
return code

def getvalue(self):
ns = {}
lines = []
for key, val in self.data.items():
if val is None:
ns[key] = val
elif key == "*":
lines.append("from %s import *" % val)
else:
ns.setdefault(val, set()).add(key)
for key, val in ns.items():
if val is None:
lines.append("import %s" % key)
else:
names = sorted(val, key=lambda s: s.lower())
lines.append(self._make_line(key, names))
return "\n".join(lines)


class DeclaredNamespaces(object):
def __init__(self):
self.data = {}

def add(self, alias, definition, comment=None):
"""Adds a namespace will be declared.
Examples:
>>> declarations = DeclaredNamespaces()
>>> declarations.add('STRING', 'c_char_p')
>>> declarations.add('_lcid', '0', 'change this if required')
>>> print(declarations.getvalue())
STRING = c_char_p
_lcid = 0 # change this if required
>>> assert declarations.get_symbols() == {
... 'STRING', '_lcid'
... }
"""
self.data[(alias, definition)] = comment

def get_symbols(self) -> Set[str]:
names = set()
for alias, _ in self.data.keys():
names.add(alias)
return names

def getvalue(self):
lines = []
for (alias, definition), comment in self.data.items():
code = "%s = %s" % (alias, definition)
if comment:
code = code + " # %s" % comment
lines.append(code)
return "\n".join(lines)


class EnumerationNamespaces(object):
def __init__(self):
self.data: Dict[str, List[Tuple[str, int]]] = {}

def add(self, enum_name: str, member_name: str, value: int) -> None:
"""Adds a namespace will be enumeration and its member.
Examples:
<BLANKLINE> is necessary for doctest
>>> enums = EnumerationNamespaces()
>>> assert not enums
>>> enums.add('Foo', 'ham', 1)
>>> assert enums
>>> enums.add('Foo', 'spam', 2)
>>> enums.add('Bar', 'bacon', 3)
>>> assert 'Foo' in enums
>>> assert 'Baz' not in enums
>>> print(enums.to_intflags())
class Foo(IntFlag):
ham = 1
spam = 2
<BLANKLINE>
<BLANKLINE>
class Bar(IntFlag):
bacon = 3
>>> print(enums.to_constants())
# values for enumeration 'Foo'
ham = 1
spam = 2
Foo = c_int # enum
<BLANKLINE>
# values for enumeration 'Bar'
bacon = 3
Bar = c_int # enum
"""
self.data.setdefault(enum_name, []).append((member_name, value))

def __contains__(self, item: str) -> bool:
return item in self.data

def __bool__(self) -> bool:
return bool(self.data)

def get_symbols(self) -> Set[str]:
return set(self.data)

def to_constants(self) -> str:
blocks = []
for enum_name, enum_members in self.data.items():
lines = []
lines.append(f"# values for enumeration '{enum_name}'")
for n, v in enum_members:
lines.append(f"{n} = {v}")
lines.append(f"{enum_name} = c_int # enum")
blocks.append("\n".join(lines))
return "\n\n".join(blocks)

def to_intflags(self) -> str:
blocks = []
for enum_name, enum_members in self.data.items():
lines = []
lines.append(f"class {enum_name}(IntFlag):")
for member_name, value in enum_members:
lines.append(f" {member_name} = {value}")
blocks.append("\n".join(lines))
return "\n\n\n".join(blocks)
Loading

0 comments on commit cc071fe

Please sign in to comment.