Skip to content

Commit

Permalink
refactor: move implementation of xdr primitives to serializer / deser…
Browse files Browse the repository at this point in the history
…ializer subclass.

- Replace enums with intenums
  • Loading branch information
twiggler committed Jan 23, 2025
1 parent 92ede66 commit fba07e5
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 98 deletions.
4 changes: 2 additions & 2 deletions dissect/target/helpers/nfs/nfs3.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum, IntEnum
from enum import IntEnum
from typing import ClassVar, NamedTuple

# See https://datatracker.ietf.org/doc/html/rfc1057
Expand All @@ -20,7 +20,7 @@ class ProcedureDescriptor(NamedTuple):
ReadFileProc = ProcedureDescriptor(100003, 3, 6)


class Nfs3Stat(Enum):
class Nfs3Stat(IntEnum):
OK = 0
ERR_PERM = 1
ERR_NOENT = 2
Expand Down
22 changes: 11 additions & 11 deletions dissect/target/helpers/nfs/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@
SpecData3,
)
from dissect.target.helpers.sunrpc.serializer import (
Deserializer,
Int32Serializer,
OpaqueVarLengthSerializer,
Serializer,
XdrDeserializer,
XdrSerializer,
)
from dissect.target.helpers.sunrpc.sunrpc import Bool


# Used Union because 3.9 does not support '|' here even with future annotations
class MountResultDeserializer(Deserializer[Union[MountOK, MountStat3]]):
class MountResultDeserializer(XdrDeserializer[Union[MountOK, MountStat3]]):
def deserialize(self, payload: io.BytesIO) -> MountOK | MountStat3:
mount_stat = self._read_enum(payload, MountStat3)
if mount_stat != MountStat3.OK:
Expand All @@ -40,7 +40,7 @@ def deserialize(self, payload: io.BytesIO) -> MountOK | MountStat3:
return MountOK(FileHandle3(filehandle_bytes), auth_flavors)


class ReadDirPlusParamsSerializer(Serializer[ReadDirPlusParams]):
class ReadDirPlusParamsSerializer(XdrSerializer[ReadDirPlusParams]):
def serialize(self, params: ReadDirPlusParams) -> bytes:
result = self._write_var_length_opaque(params.dir.opaque)
result += self._write_uint64(params.cookie)
Expand All @@ -51,23 +51,23 @@ def serialize(self, params: ReadDirPlusParams) -> bytes:
return result


class SpecDataSerializer(Deserializer[SpecData3]):
class SpecDataSerializer(XdrDeserializer[SpecData3]):
def deserialize(self, payload: io.BytesIO) -> bytes:
specdata1 = self._read_uint32(payload)
specdata2 = self._read_uint32(payload)

return SpecData3(specdata1, specdata2)


class NfsTimeSerializer(Deserializer[NfsTime3]):
class NfsTimeSerializer(XdrDeserializer[NfsTime3]):
def deserialize(self, payload: io.BytesIO) -> bytes:
seconds = self._read_uint32(payload)
nseconds = self._read_uint32(payload)

return NfsTime3(seconds, nseconds)


class FileAttributesSerializer(Deserializer[FileAttributes3]):
class FileAttributesSerializer(XdrDeserializer[FileAttributes3]):
def deserialize(self, payload: io.BytesIO) -> FileAttributes3:
type = self._read_enum(payload, FileType3)
mode = self._read_uint32(payload)
Expand All @@ -87,7 +87,7 @@ def deserialize(self, payload: io.BytesIO) -> FileAttributes3:
return FileAttributes3(type, mode, nlink, uid, gid, size, used, rdev, fsid, fileid, atime, mtime, ctime)


class EntryPlusSerializer(Deserializer[EntryPlus3]):
class EntryPlusSerializer(XdrDeserializer[EntryPlus3]):
def deserialize(self, payload: io.BytesIO) -> EntryPlus3:
fileid = self._read_uint64(payload)
name = self._read_string(payload)
Expand All @@ -100,7 +100,7 @@ def deserialize(self, payload: io.BytesIO) -> EntryPlus3:


# Used Union because 3.9 does not support '|' here even with future annotations
class ReadDirPlusResultDeserializer(Deserializer[Union[ReadDirPlusResult3, Nfs3Stat]]):
class ReadDirPlusResultDeserializer(XdrDeserializer[Union[ReadDirPlusResult3, Nfs3Stat]]):
def deserialize(self, payload: io.BytesIO) -> ReadDirPlusResult3:
stat = self._read_enum(payload, Nfs3Stat)
if stat != Nfs3Stat.OK:
Expand All @@ -122,15 +122,15 @@ def deserialize(self, payload: io.BytesIO) -> ReadDirPlusResult3:
return ReadDirPlusResult3(dir_attributes, CookieVerf3(cookieverf), entries, eof)


class Read3ArgsSerializer(Serializer[ReadDirPlusParams]):
class Read3ArgsSerializer(XdrSerializer[ReadDirPlusParams]):
def serialize(self, args: Read3args) -> bytes:
result = self._write_var_length_opaque(args.file.opaque)
result += self._write_uint64(args.offset)
result += self._write_uint32(args.count)
return result


class Read3ResultDeserializer(Deserializer[Read3resok]):
class Read3ResultDeserializer(XdrDeserializer[Read3resok]):
def deserialize(self, payload: io.BytesIO) -> Read3resok:
stat = self._read_enum(payload, Nfs3Stat)
if stat != Nfs3Stat.OK:
Expand Down
8 changes: 4 additions & 4 deletions dissect/target/helpers/sunrpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
AuthNullSerializer,
AuthSerializer,
AuthUnixSerializer,
Deserializer,
MessageSerializer,
Serializer,
XdrDeserializer,
XdrSerializer,
)

Credentials = TypeVar("Credentials")
Expand Down Expand Up @@ -102,8 +102,8 @@ def call(
self,
proc_desc: ProcedureDescriptor,
params: Params,
params_serializer: Serializer[Params],
result_deserializer: Deserializer[Results],
params_serializer: XdrSerializer[Params],
result_deserializer: XdrDeserializer[Results],
) -> Results:
"""Synchronously call an RPC procedure and return the result"""

Expand Down
149 changes: 74 additions & 75 deletions dissect/target/helpers/sunrpc/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,68 @@ class Serializer(ABC, Generic[Serializable]):
def serialize(self, _: Serializable) -> bytes:
pass

# Unfortunately xdrlib is deprecated in Python 3.11, so we implement the following serialization methods
# to be used by descendants of the Serializer class.
# See https://datatracker.ietf.org/doc/html/rfc1014 for the XDR specification.
def _write_uint32(self, i: int) -> bytes:

class Deserializer(ABC, Generic[Serializable]):
def deserialize_from_bytes(self, payload: bytes) -> Serializable:
return self.deserialize(io.BytesIO(payload))

@abstractmethod
def deserialize(self, _: io.BytesIO) -> Serializable:
pass


# Unfortunately xd)rlib is deprecated in Python 3.11, so we implement the following serialization methods
# to be used by descendants of the Serializer class.
# See https://datatracker.ietf.org/doc/html/rfc1014 for the XDR specification.


class Int32Serializer(Serializer[int], Deserializer[int]):
def serialize(self, i: int) -> bytes:
return i.to_bytes(length=4, byteorder="big", signed=True)

def deserialize(self, payload: io.BytesIO) -> int:
return int.from_bytes(payload.read(4), byteorder="big", signed=True)


class UInt32Serializer(Serializer[int], Deserializer[int]):
def serialize(self, i: int) -> bytes:
return i.to_bytes(length=4, byteorder="big", signed=False)

def deserialize(self, payload: io.BytesIO) -> int:
return int.from_bytes(payload.read(4), byteorder="big", signed=False)


class OpaqueVarLengthSerializer(Serializer[bytes], Deserializer[bytes]):
def serialize(self, body: bytes) -> bytes:
length = len(body)
result = UInt32Serializer().serialize(length)
result += body

padding_bytes = (ALIGNMENT - (length % ALIGNMENT)) % ALIGNMENT
return result + b"\x00" * padding_bytes

def deserialize(self, payload: io.BytesIO) -> bytes:
length = UInt32Serializer().deserialize(payload)
result = payload.read(length)
padding_bytes = (ALIGNMENT - (length % ALIGNMENT)) % ALIGNMENT
payload.read(padding_bytes)
return result


class StringSerializer(Serializer[str], Deserializer[str]):
def serialize(self, s: str) -> bytes:
return OpaqueVarLengthSerializer().serialize(s.encode("ascii"))

def deserialize(self, payload: io.BytesIO) -> str:
return OpaqueVarLengthSerializer().deserialize(payload).decode("ascii")


class XdrSerializer(Generic[Serializable], Serializer[Serializable]):
def _write_uint32(self, i: int) -> bytes:
return UInt32Serializer().serialize(i)

def _write_int32(self, i: int) -> bytes:
return i.to_bytes(length=4, byteorder="big", signed=True)
return Int32Serializer().serialize(i)

def _write_uint64(self, i: int) -> bytes:
return i.to_bytes(length=8, byteorder="big", signed=False)
Expand All @@ -51,33 +105,18 @@ def _write_var_length(self, elements: list[ElementType], serializer: Serializer[
return result + b"".join(payload)

def _write_var_length_opaque(self, body: bytes) -> bytes:
length = len(body)
result = self._write_uint32(length)
result += body

padding_bytes = (ALIGNMENT - (length % ALIGNMENT)) % ALIGNMENT
return result + b"\x00" * padding_bytes
return OpaqueVarLengthSerializer().serialize(body)

def _write_string(self, s: str) -> bytes:
return self._write_var_length_opaque(s.encode("ascii"))
return StringSerializer().serialize(s)


class Deserializer(ABC, Generic[Serializable]):
def deserialize_from_bytes(self, payload: bytes) -> Serializable:
return self.deserialize(io.BytesIO(payload))

@abstractmethod
def deserialize(self, _: io.BytesIO) -> Serializable:
pass

# Unfortunately xdrlib is deprecated in Python 3.11, so we implement the following serialization methods
# to be used by descendants of the Serializer class.
# See https://datatracker.ietf.org/doc/html/rfc1014 for the XDR specification.
class XdrDeserializer(Generic[Serializable], Deserializer[Serializable]):
def _read_uint32(self, payload: io.BytesIO) -> int:
return int.from_bytes(payload.read(4), byteorder="big", signed=False)
return UInt32Serializer().deserialize(payload)

def _read_int32(self, payload: io.BytesIO) -> int:
return int.from_bytes(payload.read(4), byteorder="big", signed=True)
return Int32Serializer().deserialize(payload)

def _read_uint64(self, payload: io.BytesIO) -> int:
return int.from_bytes(payload.read(8), byteorder="big", signed=False)
Expand All @@ -86,19 +125,15 @@ def _read_enum(self, payload: io.BytesIO, enum: EnumType) -> EnumType:
value = self._read_int32(payload)
return enum(value)

def _read_var_length_opaque(self, payload: io.BytesIO) -> bytes:
length = self._read_uint32(payload)
result = payload.read(length)
padding_bytes = (ALIGNMENT - (length % ALIGNMENT)) % ALIGNMENT
payload.read(padding_bytes)
return result

def _read_var_length(self, payload: io.BytesIO, deserializer: Deserializer[ElementType]) -> list[ElementType]:
length = self._read_uint32(payload)
return [deserializer.deserialize(payload) for _ in range(length)]

def _read_var_length_opaque(self, payload: io.BytesIO) -> bytes:
return OpaqueVarLengthSerializer().deserialize(payload)

def _read_string(self, payload: io.BytesIO) -> str:
return self._read_var_length_opaque(payload).decode("ascii")
return StringSerializer().deserialize(payload)

def _read_optional(self, payload: io.BytesIO, deserializer: Deserializer[ElementType]) -> ElementType | None:
has_value = self._read_enum(payload, sunrpc.Bool)
Expand All @@ -107,42 +142,6 @@ def _read_optional(self, payload: io.BytesIO, deserializer: Deserializer[Element
return deserializer.deserialize(payload)


# RdJ: A bit clunky having to lift the primitives inside the Serializer/Deserializer class
# to enable composition.
# Possible design mistake, Alternatively, make serializers functions, since no state is kept.
# But most of our stuff is OOP, so it would be inconsistent.
class Int32Serializer(Serializer[int], Deserializer[int]):
def serialize(self, i: int) -> bytes:
return self._write_int32(i)

def deserialize(self, payload: io.BytesIO) -> int:
return self._read_int32(payload)


class UInt32Serializer(Serializer[int], Deserializer[int]):
def serialize(self, i: int) -> bytes:
return self._write_uint32(i)

def deserialize(self, payload: io.BytesIO) -> int:
return self._read_uint32(payload)


class StringSerializer(Serializer[str], Deserializer[str]):
def serialize(self, s: str) -> bytes:
return self._write_string(s)

def deserialize(self, payload: io.BytesIO) -> str:
return self._read_string(payload)


class OpaqueVarLengthSerializer(Serializer[bytes], Deserializer[bytes]):
def serialize(self, body: bytes) -> bytes:
return self._write_var_length_opaque(body)

def deserialize(self, payload: io.BytesIO) -> bytes:
return self._read_var_length_opaque(payload)


class ReplyStat(Enum):
MSG_ACCEPTED = 0
MSG_DENIED = 1
Expand All @@ -155,7 +154,7 @@ class AuthFlavor(Enum):
AUTH_DES = 3


class AuthSerializer(Generic[AuthProtocol], Serializer[AuthProtocol], Deserializer[AuthProtocol]):
class AuthSerializer(Generic[AuthProtocol], XdrSerializer[AuthProtocol], XdrDeserializer[AuthProtocol]):
def serialize(self, protocol: AuthProtocol) -> bytes:
flavor = self._flavor()
result = self._write_int32(flavor)
Expand Down Expand Up @@ -219,13 +218,13 @@ def _read_body(self, payload: io.BytesIO) -> sunrpc.AuthUnix:

class MessageSerializer(
Generic[ProcedureParams, ProcedureResults, Credentials, Verifier],
Serializer[sunrpc.Message[ProcedureParams, ProcedureResults, Credentials, Verifier]],
Deserializer[sunrpc.Message[ProcedureParams, ProcedureResults, Credentials, Verifier]],
XdrSerializer[sunrpc.Message[ProcedureParams, ProcedureResults, Credentials, Verifier]],
XdrDeserializer[sunrpc.Message[ProcedureParams, ProcedureResults, Credentials, Verifier]],
):
def __init__(
self,
paramsSerializer: Serializer[ProcedureParams],
resultsDeserializer: Deserializer[ProcedureResults],
paramsSerializer: XdrSerializer[ProcedureParams],
resultsDeserializer: XdrDeserializer[ProcedureResults],
credentialsSerializer: AuthSerializer[Credentials],
verifierSerializer: AuthSerializer[Verifier],
):
Expand Down Expand Up @@ -296,7 +295,7 @@ def _read_mismatch(self, payload: io.BytesIO) -> sunrpc.Mismatch:
return sunrpc.Mismatch(low, high)


class PortMappingSerializer(Serializer[sunrpc.PortMapping]):
class PortMappingSerializer(XdrSerializer[sunrpc.PortMapping]):
def serialize(self, port_mapping: sunrpc.PortMapping) -> bytes:
result = self._write_uint32(port_mapping.program)
result += self._write_uint32(port_mapping.version)
Expand Down
12 changes: 6 additions & 6 deletions dissect/target/helpers/sunrpc/sunrpc.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from enum import IntEnum
from typing import Generic, TypeVar


class Bool(Enum):
class Bool(IntEnum):
FALSE = 0
TRUE = 1


class AcceptStat(Enum):
class AcceptStat(IntEnum):
SUCCESS = 0 # RPC executed successfully
PROG_UNAVAIL = 1 # remote hasn't exported program
PROG_MISMATCH = 2 # remote can't support version #
PROC_UNAVAIL = 3 # program can't support procedure
GARBAGE_ARGS = 4 # procedure can't decode params


class RejectStat(Enum):
class RejectStat(IntEnum):
RPC_MISMATCH = 0
AUTH_ERROR = 1


class AuthStat(Enum):
class AuthStat(IntEnum):
AUTH_BADCRED = 1 # bad credentials (seal broken)
AUTH_REJECTEDCRED = 2 # client must begin new session
AUTH_BADVERF = 3 # bad verifier (seal broken)
Expand Down Expand Up @@ -94,7 +94,7 @@ class Message(Generic[ProcedureParams, ProcedureResults, Credentials, Verifier])
body: CallBody[ProcedureParams, Credentials, Verifier] | AcceptedReply[ProcedureResults, Verifier] | RejectedReply


class Protocol(Enum):
class Protocol(IntEnum):
TCP = 6
UDP = 17

Expand Down

0 comments on commit fba07e5

Please sign in to comment.