Skip to content

Commit

Permalink
Optimized execution flow (-50% time)
Browse files Browse the repository at this point in the history
+ Reduced execution time by half
+ Multiple functions got inlined
  • Loading branch information
MatrixEditor committed Dec 28, 2023
1 parent 3dac578 commit c773ece
Show file tree
Hide file tree
Showing 12 changed files with 204 additions and 131 deletions.
32 changes: 15 additions & 17 deletions caterpillar/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing import List, Any, Union, Iterable

from caterpillar.abc import _GreedyType, _ContextLike, isgreedy, _StreamType, isprefixed
from caterpillar.abc import _GreedyType, _ContextLike, _StreamType, isprefixed
from caterpillar.context import (
Context,
CTX_PATH,
Expand All @@ -25,22 +25,24 @@
CTX_INDEX,
CTX_OBJECT,
CTX_STREAM,
CTX_SEQ
)
from caterpillar.options import F_SEQUENTIAL
from caterpillar.exception import Stop, StructException, InvalidValueError


class WithoutFlag:
def __init__(self, context: _ContextLike, flag) -> None:
class WithoutContextVar:
def __init__(self, context: _ContextLike, name, value) -> None:
self.context = context
self.old_value = context[name]
self.value = value
self.name = name
self.field = context[CTX_FIELD]
self.flag = flag

def __enter__(self) -> None:
self.field ^= self.flag
self.context[self.name] = self.value

def __exit__(self, exc_type, exc_value, traceback) -> None:
self.field |= self.flag
self.context[self.name] = self.old_value
# We have to apply the right field as instance of the Field class
# might set their own value into the context.
self.context[CTX_FIELD] = self.field
Expand All @@ -58,7 +60,7 @@ def unpack_seq(context: _ContextLike, unpack_one) -> List[Any]:
"""
stream = context[CTX_STREAM]
field = context[CTX_FIELD]
assert field and field.is_seq()
assert field and context[CTX_SEQ]

length: Union[int, _GreedyType] = field.length(context)
base_path = context[CTX_PATH]
Expand All @@ -72,13 +74,13 @@ def unpack_seq(context: _ContextLike, unpack_one) -> List[Any]:
_lst=values,
_field=field,
_obj=context.get(CTX_OBJECT),
_pos=context.get(CTX_POS)
)
greedy = isgreedy(length)
greedy = length is Ellipsis
prefixed = isprefixed(length)
seq_context[CTX_POS] = stream.tell()
if prefixed:
# We have to temporarily remove the array status from the parsing field
with WithoutFlag(context, F_SEQUENTIAL):
with WithoutContextVar(context, CTX_SEQ, False):
field.amount = 1
new_length = length.start.__unpack__(context)
field.amount, length = length, new_length
Expand All @@ -90,7 +92,7 @@ def unpack_seq(context: _ContextLike, unpack_one) -> List[Any]:

for i in range(length) if not greedy else itertools.count():
try:
seq_context[CTX_PATH] = ".".join([base_path, str(i)])
seq_context[CTX_PATH] = f"{base_path}.{i}"
seq_context[CTX_INDEX] = i
values.append(unpack_one(seq_context))
seq_context[CTX_POS] = stream.tell()
Expand Down Expand Up @@ -122,17 +124,13 @@ def pack_seq(seq: List[Any], context: _ContextLike, pack_one) -> None:
stream = context[CTX_STREAM]
field = context[CTX_FIELD]
base_path = context[CTX_PATH]
# Treat the 'obj' as a sequence/iterable
if not isinstance(seq, Iterable):
raise InvalidValueError(f"Expected iterable sequence, got {type(seq)}", context)

# REVISIT: when to use field.length(context)
count = len(seq)
length = field.amount
if isprefixed(length):
struct = length.start
# We have to temporarily alter the field's values,
with WithoutFlag(context, F_SEQUENTIAL):
with WithoutContextVar(context, CTX_SEQ, False):
field.amount = 1
struct.__pack__(count, context)
field.amount = length
Expand Down
5 changes: 1 addition & 4 deletions caterpillar/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,5 @@ def typeof(struct: Union[_StructLike, _ContainsStruct]) -> type:
return __type__() or Any


def isgreedy(obj) -> bool:
return isinstance(obj, _GreedyType)

def isprefixed(obj) -> bool:
return isinstance(obj, _PrefixedType)
return type(obj) is _PrefixedType
1 change: 1 addition & 0 deletions caterpillar/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
CTX_POS = "_pos"
CTX_INDEX = "_index"
CTX_PATH = "_path"
CTX_SEQ = "_is_seq"


class Context(_ContextLike):
Expand Down
3 changes: 2 additions & 1 deletion caterpillar/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from ._base import (
Flag,
Field,
FieldMixin,
FieldStruct,
Expand All @@ -34,6 +33,8 @@
Pass,
CString,
Prefixed,
Int,
UInt,
padding,
char,
boolean,
Expand Down
66 changes: 37 additions & 29 deletions caterpillar/fields/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
)
from caterpillar.context import CTX_OFFSETS, CTX_STREAM
from caterpillar.context import CTX_FIELD, CTX_POS
from caterpillar.context import CTX_VALUE
from caterpillar.context import CTX_VALUE, CTX_SEQ
from caterpillar._common import unpack_seq, pack_seq


Expand Down Expand Up @@ -83,7 +83,7 @@ class Field(_StructLike):
The minus one indicates that no offset has been associated with this field.
"""

flags: Set[Flag]
flags: Dict[int, Flag]
"""
Additional options that can be enabled using the logical OR operator ``|``.
Expand Down Expand Up @@ -149,8 +149,7 @@ def __init__(
# NOTE: we use a custom init method to automatically set flags
self.struct = struct
self.order = order
self.flags = flags or set([F_KEEP_POSITION])
self.flags.update(GLOBAL_FIELD_FLAGS)
self.flags = {hash(x): x for x in flags or set([F_KEEP_POSITION])}
self.bits = bits

self.arch = arch or get_system_arch()
Expand All @@ -177,27 +176,28 @@ def __or__(self, flag: Flag) -> Self: # add flags
if not isinstance(flag, Flag):
raise TypeError(f"Expected a flag, got {type(flag)}")

self.flags.add(flag)
self.flags[hash(flag)] = flag
return self

def __xor__(self, flag: Flag) -> Self: # remove flags:
self.flags.remove(flag)
self.flags.pop(hash(flag), None)
return self

def __matmul__(self, offset: Union[_ContextLambda, int]) -> Self:
self._verify_context_value(offset, int)
self.offset = offset
# This operation automatically removes the "keep_position"
# flag. It has to be set manually.
if self.has_flag(F_KEEP_POSITION) and self.offset != -1:
self.flags.remove(F_KEEP_POSITION)
if self.offset != -1:
self.flags.pop(F_KEEP_POSITION, None)
return self

def __getitem__(self, dim: Union[_ContextLambda, int, _GreedyType]) -> Self:
self._verify_context_value(dim, (_GreedyType, int, _PrefixedType))
self.amount = dim
if self.amount != 0:
self.flags.add(F_SEQUENTIAL)
# pylint: disable-next=protected-access
self.flags[F_SEQUENTIAL._hash_] = F_SEQUENTIAL
return self

def __rshift__(self, switch: Union[_Switch, dict]) -> Self:
Expand Down Expand Up @@ -234,7 +234,8 @@ def is_seq(self) -> bool:
:return: whether this field is sequential
:rtype: bool
"""
return self.has_flag(F_SEQUENTIAL)
# pylint: disable-next=protected-access
return F_SEQUENTIAL._hash_ in self.flags

def is_enabled(self, context: _ContextLike) -> bool:
"""Evaluates the condition of this field.
Expand All @@ -254,9 +255,10 @@ def has_flag(self, flag: Flag) -> bool:
:return: true if this flag has been found
:rtype: bool
"""
return flag in self.flags
# pylint: disable-next=protected-access
return flag._hash_ in self.flags or flag in GLOBAL_FIELD_FLAGS

def length(self, context: _ContextLike) -> Union[int, _GreedyType]:
def length(self, context: _ContextLike) -> Union[int, _GreedyType, _PrefixedType]:
"""Calculates the sequence length of this field.
:param context: the context on which to operate
Expand Down Expand Up @@ -338,21 +340,27 @@ def __unpack__(self, context: _ContextLike) -> Optional[Any]:
:rtype: Optional[Any]
"""
stream: _StreamType = context[CTX_STREAM]
if not self.is_enabled(context):
# handling the result of this function should be treated carefully
return None
if self.condition is not True and not self.is_enabled(context):
# Disabled fields or context lambdas won't unpack any data
return

# pylint: disable-next=protected-access
context[CTX_SEQ] = F_SEQUENTIAL._hash_ in self.flags
# pylint: disable-next=protected-access
keep_pos = F_KEEP_POSITION._hash_ in self.flags
if not callable(self.struct):
fallback = stream.tell()
offset = self.get_offset(context)
start = offset if offset >= 0 else fallback
if not keep_pos:
fallback = stream.tell()

offset = self.offset(context) if callable(self.offset) else self.offset
if offset >= 0:
stream.seek(offset)

context[CTX_FIELD] = self
# Switch is applicable AFTER we parsed the first value
stream.seek(start)
try:
value = self.struct.__unpack__(context)
if not self.has_flag(F_KEEP_POSITION):
if not keep_pos:
stream.seek(fallback)
except StructException as exc:
# Any exception leads to a default value if configured
Expand Down Expand Up @@ -395,10 +403,12 @@ def __pack__(self, obj: Any, context: _ContextLike) -> None:
:raises TypeError: if the value is not iterable but this field is marked
to be sequential
"""
# TODO: revisit code
stream: _StreamType = context[CTX_STREAM]
if not self.is_enabled(context):
# Disabled fields or context lambdas won't pack any data
return
if self.condition is not True:
if not self.is_enabled(context):
# Disabled fields or context lambdas won't pack any data
return

# Setup parsing by specifying the start and end positions
fallback = stream.tell()
Expand Down Expand Up @@ -461,11 +471,12 @@ def __size__(self, context: _ContextLike) -> int:
raise DynamicSizeError("Dynamic sized field!", context)

context[CTX_FIELD] = self
context[CTX_SEQ] = self.is_seq()

# 3. We should gather the element count if this field stores
# a sequential element
count = 1
if self.is_seq():
if context[CTX_SEQ]:
count = self.length(context)
if isinstance(count, _GreedyType):
raise DynamicSizeError(
Expand Down Expand Up @@ -571,9 +582,8 @@ def __pack__(self, obj: Any, context: _ContextLike) -> None:
:param stream: The output stream.
:param context: The current operation context.
"""
field: Field = context[CTX_FIELD]
func = self.pack_single
if field.is_seq():
if context[CTX_SEQ]:
func = self.pack_seq

func(obj, context)
Expand All @@ -586,10 +596,8 @@ def __unpack__(self, context: _ContextLike) -> Any:
:param context: The current operation context.
:return: The unpacked data.
"""
field: Field = context[CTX_FIELD]
if field.is_seq():
if context[CTX_SEQ]:
return self.unpack_seq(context)

return self.unpack_single(context)

def __repr__(self) -> str:
Expand Down
Loading

0 comments on commit c773ece

Please sign in to comment.