From 39d80961ee0c8617fd9f602533f1251c3e9ed03a Mon Sep 17 00:00:00 2001 From: Perry Kundert Date: Wed, 4 Dec 2024 08:53:00 -0700 Subject: [PATCH] Be more careful with type checking for framer --- pymodbus/client/base.py | 8 ++++++-- pymodbus/server/async_io.py | 14 ++++++++------ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pymodbus/client/base.py b/pymodbus/client/base.py index 98d34aca5..ea8dd443a 100644 --- a/pymodbus/client/base.py +++ b/pymodbus/client/base.py @@ -35,9 +35,11 @@ def __init__( """ ModbusClientMixin.__init__(self) # type: ignore[arg-type] self.comm_params = comm_params + if isinstance(framer, FramerType): + framer = FRAMER_NAME_TO_CLASS[framer] self.ctx = TransactionManager( comm_params, - FRAMER_NAME_TO_CLASS.get(framer, framer)(DecodePDU(False)), + framer(DecodePDU(False)), retries, False, trace_packet, @@ -133,7 +135,9 @@ def __init__( self.slaves: list[int] = [] # Common variables. - self.framer: FramerBase = FRAMER_NAME_TO_CLASS.get(framer, framer)(DecodePDU(False)) + if isinstance(framer, FramerType): + framer = FRAMER_NAME_TO_CLASS[framer] + self.framer: FramerBase = framer(DecodePDU(False)) self.transaction = TransactionManager( self.comm_params, self.framer, diff --git a/pymodbus/server/async_io.py b/pymodbus/server/async_io.py index 41751c534..8bf08083e 100644 --- a/pymodbus/server/async_io.py +++ b/pymodbus/server/async_io.py @@ -10,7 +10,7 @@ from pymodbus.datastore import ModbusServerContext from pymodbus.device import ModbusControlBlock, ModbusDeviceIdentification from pymodbus.exceptions import NoSuchSlaveException -from pymodbus.framer import FRAMER_NAME_TO_CLASS, FramerType, FramerBase +from pymodbus.framer import FRAMER_NAME_TO_CLASS, FramerBase, FramerType from pymodbus.logging import Log from pymodbus.pdu import DecodePDU, ModbusPDU from pymodbus.pdu.pdu import ExceptionResponse @@ -224,7 +224,9 @@ def __init__( if isinstance(identity, ModbusDeviceIdentification): self.control.Identity.update(identity) # Support mapping of FramerType to a Framer class, or a Framer class - self.framer = FRAMER_NAME_TO_CLASS.get(framer, framer) + if isinstance(framer, FramerType): + framer = FRAMER_NAME_TO_CLASS[framer] + self.framer = framer self.serving: asyncio.Future = asyncio.Future() def callback_new_connection(self): @@ -273,7 +275,7 @@ def __init__( self, context: ModbusServerContext, *, - framer=FramerType.SOCKET, + framer: FramerType | type[FramerBase] = FramerType.SOCKET, identity: ModbusDeviceIdentification | None = None, address: tuple[str, int] = ("", 502), ignore_missing_slaves: bool = False, @@ -336,7 +338,7 @@ def __init__( # pylint: disable=too-many-arguments self, context: ModbusServerContext, *, - framer=FramerType.TLS, + framer: FramerType | type[FramerBase] = FramerType.TLS, identity: ModbusDeviceIdentification | None = None, address: tuple[str, int] = ("", 502), sslctx=None, @@ -403,7 +405,7 @@ def __init__( self, context: ModbusServerContext, *, - framer=FramerType.SOCKET, + framer: FramerType | type[FramerBase] = FramerType.SOCKET, identity: ModbusDeviceIdentification | None = None, address: tuple[str, int] = ("", 502), ignore_missing_slaves: bool = False, @@ -463,7 +465,7 @@ def __init__( self, context: ModbusServerContext, *, - framer: FramerType = FramerType.RTU, + framer: FramerType | type[FramerBase] = FramerType.RTU, ignore_missing_slaves: bool = False, identity: ModbusDeviceIdentification | None = None, broadcast_enable: bool = False,