Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 157 additions & 120 deletions smpclient/transport/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,17 @@ class SMPSerialTransport(SMPTransport):
_POLLING_INTERVAL_S = 0.005
_CONNECTION_RETRY_INTERVAL_S = 0.500

class _ReadBuffer:
"""The state of the read buffer."""

@unique
class State(IntEnum):
SMP = 0
"""An SMP start or continue delimiter has been received and the
`smp_buffer` is being filled with the remainder of the SMP packet.
"""

SER = 1
"""The SMP start delimiter has not been received and the
`ser_buffer` is being filled with data.
"""

def __init__(self) -> None:
self.smp = bytearray([])
"""The buffer for the SMP packet."""

self.ser = bytearray([])
"""The buffer for serial data that is not part of an SMP packet."""
@unique
class BufferState(IntEnum):
SMP = 0
"""An SMP start or continue delimiter has been received and
`_buffer` is being parsed as an SMP packet.
"""

self.state = SMPSerialTransport._ReadBuffer.State.SER
"""The state of the read buffer."""
SERIAL = 1
"""The SMP start delimiter has not been received and
`_buffer` is being parsed as serial data.
"""

def __init__( # noqa: DOC301
self,
Expand Down Expand Up @@ -120,6 +107,7 @@ def __init__( # noqa: DOC301
self._max_smp_encoded_frame_size: Final = max_smp_encoded_frame_size
self._line_length: Final = line_length
self._line_buffers: Final = line_buffers
self._read_timeout: float = timeout or 120
self._conn: Final = Serial(
baudrate=baudrate,
bytesize=bytesize,
Expand All @@ -133,11 +121,29 @@ def __init__( # noqa: DOC301
inter_byte_timeout=inter_byte_timeout,
exclusive=exclusive,
)
self._buffer = SMPSerialTransport._ReadBuffer()

self._smp_packet_queue: asyncio.Queue[bytes] = asyncio.Queue()
"""Contains full SMP packets."""
self._serial_buffer = bytearray()
"""Contains any non-SMP serial data."""
self._buffer: bytearray = bytearray([])
"""Contains all incoming data (serial + SMP intertwined, may be incomplete)."""
self._buffer_state = SMPSerialTransport.BufferState.SERIAL
"""The state of the read buffer."""

logger.debug(f"Initialized {self.__class__.__name__}")

def _reset_state(self) -> None:
"""Reset internal state and queues for a fresh connection."""

self._smp_packet_queue = asyncio.Queue()
self._serial_buffer.clear()
self._buffer = bytearray([])
self._buffer_state = SMPSerialTransport.BufferState.SERIAL

@override
async def connect(self, address: str, timeout_s: float) -> None:
self._reset_state()
self._conn.port = address
logger.debug(f"Connecting to {self._conn.port=}")
start_time: Final = time.time()
Expand Down Expand Up @@ -191,112 +197,143 @@ async def receive(self) -> bytes:

logger.debug("Waiting for response")
while True:
b = await self._read_one_smp_packet(timeout_s=self._read_timeout)
try:
b = await self._readuntil()
decoder.send(b)
except StopIteration as e:
logger.debug(f"Finished receiving {len(e.value)} byte response")
return e.value
except SerialException as e:
logger.error(f"Failed to receive response: {e}")
raise SMPTransportDisconnected(
f"{self.__class__.__name__} disconnected from {self._conn.port}"
)

async def _readuntil(self) -> bytes:
"""Read `bytes` until the `delimiter` then return the `bytes` including the `delimiter`."""
async def _read_one_smp_packet(self, timeout_s: float) -> bytes:
"""Returns one received SMP packet from the queue.
Raises `SMPTransportDisconnected` if disconnected.
Raises `TimeoutError` if timeout is reached."""
if not self._smp_packet_queue.empty():
# There may already be a response in the queue, if for some reason we've received
# multiple responses and haven't read them in-between. This is not standard but
# it is possible, and easier to implement this way.
return self._smp_packet_queue.get_nowait()

await self._read_and_process(read_until_one_smp_packet=True, timeout_s=timeout_s)
if not self._smp_packet_queue.empty():
return self._smp_packet_queue.get_nowait()
else:
raise TimeoutError("No packet received.")

async def read_serial(self, delimiter: bytes | None = None) -> bytes:
"""Drain regular serial traffic (non-SMP bytes) until given delimiter.
Returns all available bytes if no delimiter is given.
May return empty bytes if nothing has been received."""
await self._read_and_process(read_until_one_smp_packet=False, timeout_s=0)
if delimiter is None:
res = bytes(self._serial_buffer)
self._serial_buffer.clear()
return res
else:
try:
first_match, remaining_data = self._serial_buffer.split(delimiter, 1)
except ValueError:
return bytes()
self._serial_buffer = remaining_data
return bytes(first_match)

async def _read_and_process(self, read_until_one_smp_packet: bool, timeout_s: float) -> None:
"""Reads raw data from serial and processes it into SMP packets and regular serial data."""
start_s = time.time()
while True:
try:
data = self._conn.read_all() or b""
except StopIteration:
data = b""
except SerialException as exc:
raise SMPTransportDisconnected(f"Failed to read from {self._conn.port}: {exc}")

if data:
self._buffer.extend(data)
await self._process_buffer()
else:
await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S)

START_DELIMITER: Final = smppacket.SIXTY_NINE
CONTINUE_DELIMITER: Final = smppacket.FOUR_TWENTY
END_DELIMITER: Final = b"\n"
if read_until_one_smp_packet and self._smp_packet_queue.qsize():
break # Packet found; exit early
if time.time() - start_s > timeout_s:
break # Timeout

# fake async until I get around to replacing pyserial
async def _process_buffer(self) -> None:
"""Process buffered data until more bytes are needed."""

i_smp_start = 0
i_smp_end = 0
i_start: int | None = None
i_continue: int | None = None
while True:
if self._buffer.state == SMPSerialTransport._ReadBuffer.State.SER:
# read the entire OS buffer
try:
self._buffer.ser.extend(self._conn.read_all() or [])
except StopIteration:
pass

try: # search the buffer for the index of the start delimiter
i_start = self._buffer.ser.index(START_DELIMITER)
except ValueError:
i_start = None

try: # search the buffer for the index of the continue delimiter
i_continue = self._buffer.ser.index(CONTINUE_DELIMITER)
except ValueError:
i_continue = None

if i_start is not None and i_continue is not None:
i_smp_start = min(i_start, i_continue)
elif i_start is not None:
i_smp_start = i_start
elif i_continue is not None:
i_smp_start = i_continue
else: # no delimiters found yet, clear non SMP data and wait
while True:
try: # search the buffer for newline characters
i = self._buffer.ser.index(b"\n")
try: # log as a string if possible
logger.warning(
f"{self._conn.port}: {self._buffer.ser[:i].decode()}"
)
except UnicodeDecodeError: # log as bytes if not
logger.warning(f"{self._conn.port}: {self._buffer.ser[:i].hex()}")
self._buffer.ser = self._buffer.ser[i + 1 :]
except ValueError:
break
await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S)
continue

if i_smp_start != 0: # log the rest of the serial buffer
try: # log as a string if possible
logger.warning(
f"{self._conn.port}: {self._buffer.ser[:i_smp_start].decode()}"
)
except UnicodeDecodeError: # log as bytes if not
logger.warning(f"{self._conn.port}: {self._buffer.ser[:i_smp_start].hex()}")

self._buffer.smp = self._buffer.ser[i_smp_start:]
self._buffer.ser.clear()
self._buffer.state = SMPSerialTransport._ReadBuffer.State.SMP
i_smp_end = 0

# don't await since the buffer may already contain the end delimiter

elif self._buffer.state == SMPSerialTransport._ReadBuffer.State.SMP:
# read the entire OS buffer
try:
self._buffer.smp.extend(self._conn.read_all() or [])
except StopIteration:
pass

try: # search the buffer for the index of the delimiter
i_smp_end = self._buffer.smp.index(END_DELIMITER, i_smp_end) + len(
END_DELIMITER
)
except ValueError: # delimiter not found yet, wait
await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S)
continue

# out is everything up to and including the delimiter
out = self._buffer.smp[:i_smp_end]
logger.debug(f"Received {len(out)} byte chunk")

# there may be some leftover to save for the next read, but
# it's not necessarily SMP data
self._buffer.ser = self._buffer.smp[i_smp_end:]

self._buffer.state = SMPSerialTransport._ReadBuffer.State.SER

return out
if self._buffer_state == SMPSerialTransport.BufferState.SERIAL:
should_continue = await self._process_buffer_as_serial_data()
else:
should_continue = await self._process_buffer_as_smp_data()

if not should_continue:
break

async def _process_buffer_as_serial_data(self) -> bool:
"""Handle non-SMP data and transition to SMP state when finding SMP frame-start delimiters.
Return True if further data remains to process in the buffer; return False otherwise."""

if len(self._buffer) == 1 and self._could_be_smp_packet_start(self._buffer[0]):
return False # Not enough information to process

smp_packet_start: int = self._find_smp_packet_start(self._buffer)
if smp_packet_start >= 0:
serial_data, remaining_data = (
self._buffer[:smp_packet_start],
self._buffer[smp_packet_start:],
)
self._serial_buffer.extend(serial_data)

self._buffer = remaining_data
self._buffer_state = SMPSerialTransport.BufferState.SMP
return True

# Everything is serial data:
self._serial_buffer.extend(self._buffer)
self._buffer.clear()
return False

async def _process_buffer_as_smp_data(self) -> bool:
"""Handle SMP data and transition to SERIAL state when finding SMP frame-end delimiter.
Return True if further data remains to process in the buffer; return False otherwise."""

smp_packet_end: int = self._buffer.find(smppacket.END_DELIMITER)
if smp_packet_end == -1:
return False
smp_packet_end += len(smppacket.END_DELIMITER)

smp_data, remaining_data = (
self._buffer[:smp_packet_end],
self._buffer[smp_packet_end:],
)
await self._smp_packet_queue.put(bytes(smp_data))

self._buffer = remaining_data
# Even if the remaining data is actually SMP data, then the next serial parse
# will simply put us right back into SMP state - no need to check here.
self._buffer_state = SMPSerialTransport.BufferState.SERIAL

return bool(self._buffer)

def _find_smp_packet_start(self, buf: bytearray) -> int:
"""Return index of the earliest SMP frame-start delimiter, if any; -1 if none found."""

indices = [
i
for i in (
buf.find(smppacket.START_DELIMITER),
buf.find(smppacket.CONTINUE_DELIMITER),
)
if i != -1
]
return min(indices) if indices else -1

def _could_be_smp_packet_start(self, byte: int) -> bool:
"""Return True if the given byte value matches the start of any SMP packet delimiter."""

return byte == smppacket.START_DELIMITER[0] or byte == smppacket.CONTINUE_DELIMITER[0]

@override
async def send_and_receive(self, data: bytes) -> bytes:
Expand Down
Loading
Loading