Skip to content

Commit

Permalink
Merge pull request #51 from puddly/puddly/zigpy-api-cleanup
Browse files Browse the repository at this point in the history
Introduce compatibility with recent zigpy changes and improve UART reset stability
  • Loading branch information
DamKast committed Jun 26, 2024
2 parents 6f26720 + 183c22f commit 1d0179e
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 257 deletions.
100 changes: 67 additions & 33 deletions zigpy_zboss/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from zigpy_zboss.utils import OneShotResponseListener

LOGGER = logging.getLogger(__name__)
LISTENER_LOGGER = LOGGER.getChild("listener")
LISTENER_LOGGER.propagate = False

# All of these are in seconds
AFTER_BOOTLOADER_SKIP_BYTE_DELAY = 2.5
Expand All @@ -30,6 +32,10 @@

DEFAULT_TIMEOUT = 5

EXPECTED_DISCONNECT_TIMEOUT = 5.0
MAX_RESET_RECONNECT_ATTEMPTS = 5
RESET_RECONNECT_DELAY = 1.0


class ZBOSS:
"""Class linking zigpy with ZBOSS running on nRF SoC."""
Expand All @@ -52,6 +58,8 @@ def __init__(self, config: conf.ConfigType):
self._rx_fragments = []

self._ncp_debug = None
self._reset_uart_reconnect = asyncio.Lock()
self._disconnected_event = asyncio.Event()

def set_application(self, app):
"""Set the application using the ZBOSS class."""
Expand Down Expand Up @@ -87,30 +95,28 @@ async def connect(self) -> None:
LOGGER.debug(
"Connected to %s at %s baud", self._uart.name, self._uart.baudrate)

def connection_made(self) -> None:
"""Notify that connection has been made.
Called by the UART object when a connection has been made.
"""
pass

def connection_lost(self, exc) -> None:
"""Port has been closed.
Called by the UART object to indicate that the port was closed.
Propagates up to the `ControllerApplication` that owns this ZBOSS
instance.
"""
LOGGER.debug("We were disconnected from %s: %s", self._port_path, exc)
self._uart = None
self._disconnected_event.set()

if self._app is not None and not self._reset_uart_reconnect.locked():
self._app.connection_lost(exc)

def close(self) -> None:
"""Clean up resources, namely the listener queues.
Calling this will reset ZBOSS to the same internal state as a fresh
ZBOSS instance.
"""
self._app = None
self.version = None
if not self._reset_uart_reconnect.locked():
self._app = None
self.version = None

if self._uart is not None:
self._uart.close()
Expand Down Expand Up @@ -152,11 +158,11 @@ def frame_received(self, frame: Frame) -> bool:
continue

if not listener.resolve(command):
LOGGER.debug(f"{command} does not match {listener}")
LISTENER_LOGGER.debug(f"{command} does not match {listener}")
continue

matched = True
LOGGER.debug(f"{command} matches {listener}")
LISTENER_LOGGER.debug(f"{command} matches {listener}")

if isinstance(listener, OneShotResponseListener):
one_shot_matched = True
Expand All @@ -181,6 +187,8 @@ async def request(
raise ValueError(
f"Cannot send a command that isn't a request: {request!r}")

LOGGER.debug("Sending request: %s", request)

frame = request.to_frame()
# If the frame is too long, it needs fragmentation.
fragments = frame.handle_tx_fragmentation()
Expand All @@ -199,13 +207,14 @@ async def _send_frags(self, fragments, response_future, timeout):
await self._send_to_uart(frag, None)

async def _send_to_uart(
self, frame, response_future, timeout=DEFAULT_TIMEOUT):
self, frame, response_future=None, timeout=DEFAULT_TIMEOUT):
"""Send the frame and waits for the response."""
if self._uart is None:
return

try:
await self._uart.send(frame)
if response_future:
if response_future is not None:
async with async_timeout.timeout(timeout):
return await response_future
except asyncio.TimeoutError:
Expand All @@ -229,7 +238,7 @@ def wait_for_responses(
"""
listener = OneShotResponseListener(responses)

LOGGER.debug("Creating one-shot listener %s", listener)
LISTENER_LOGGER.debug("Creating one-shot listener %s", listener)

for header in listener.matching_headers():
self._listeners[header].append(listener)
Expand Down Expand Up @@ -258,7 +267,7 @@ def remove_listener(self, listener: BaseResponseListener) -> None:
if not self._listeners:
return

LOGGER.debug("Removing listener %s", listener)
LISTENER_LOGGER.debug("Removing listener %s", listener)

for header in listener.matching_headers():
try:
Expand All @@ -267,7 +276,7 @@ def remove_listener(self, listener: BaseResponseListener) -> None:
pass

if not self._listeners[header]:
LOGGER.debug(
LISTENER_LOGGER.debug(
"Cleaning up empty listener list for header %s", header
)
del self._listeners[header]
Expand All @@ -278,7 +287,7 @@ def remove_listener(self, listener: BaseResponseListener) -> None:
self._listeners.values()):
counts[type(listener)] += 1

LOGGER.debug(
LISTENER_LOGGER.debug(
f"There are {counts[IndicationListener]} callbacks and"
f" {counts[OneShotResponseListener]} one-shot listeners remaining"
)
Expand All @@ -291,7 +300,7 @@ def register_indication_listeners(
"""
listener = IndicationListener(responses, callback=callback)

LOGGER.debug(f"Creating callback {listener}")
LISTENER_LOGGER.debug(f"Creating callback {listener}")

for header in listener.matching_headers():
self._listeners[header].append(listener)
Expand Down Expand Up @@ -324,20 +333,45 @@ async def version(self):
version[idx] = ".".join([major, minor, revision, commit])
return tuple(version)

async def reset(self, option=t.ResetOptions(0)):
async def reset(
self,
option: t.ResetOptions = t.ResetOptions.NoOptions,
wait_for_reset: bool = True,
):
"""Reset the NCP module (see ResetOptions)."""
if self._app is not None:
tsn = self._app.get_sequence()
else:
tsn = 0
LOGGER.debug("Sending a reset: %s", option)

tsn = self._app.get_sequence() if self._app is not None else 0
req = c.NcpConfig.NCPModuleReset.Req(TSN=tsn, Option=option)
self._uart.reset_flag = True
res = await self._send_to_uart(
req.to_frame(),
self.wait_for_response(
c.NcpConfig.NCPModuleReset.Rsp(partial=True)
),
timeout=10
)
if not res.TSN == 0xFF:
raise ValueError("Should get TSN 0xFF")

async with self._reset_uart_reconnect:
await self._send_to_uart(req.to_frame())

if not wait_for_reset:
return

LOGGER.debug("Waiting for radio to disconnect")

try:
async with async_timeout.timeout(EXPECTED_DISCONNECT_TIMEOUT):
await self._disconnected_event.wait()
except asyncio.TimeoutError:
LOGGER.debug(
"Radio did not disconnect, must be using external UART"
)
return

LOGGER.debug("Radio has disconnected, reconnecting")

for attempt in range(MAX_RESET_RECONNECT_ATTEMPTS):
await asyncio.sleep(RESET_RECONNECT_DELAY)

try:
await self.connect()
break
except Exception as exc:
if attempt == MAX_RESET_RECONNECT_ATTEMPTS - 1:
raise

LOGGER.debug("Failed to reconnect, retrying: %r", exc)
2 changes: 1 addition & 1 deletion zigpy_zboss/types/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,5 +717,5 @@ class Relationship(t.enum8):
STATUS_SCHEMA = (
t.Param("TSN", t.uint8_t, "Transmit Sequence Number"),
t.Param("StatusCat", StatusCategory, "Status category code"),
t.Param("StatusCode", t.uint8_t, "Status code inside category"),
t.Param("StatusCode", StatusCodeGeneric, "Status code inside category"),
)
37 changes: 3 additions & 34 deletions zigpy_zboss/types/named.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,42 +36,11 @@ class BindAddrMode(basic.enum8):
IEEE = 0x03


class ChannelEntry:
class ChannelEntry(Struct):
"""Class representing a channel entry."""

def __new__(cls, page=None, channel_mask=None):
"""Create a channel entry instance."""
instance = super().__new__(cls)

instance.page = basic.uint8_t(page)
instance.channel_mask = channel_mask

return instance

@classmethod
def deserialize(cls, data: bytes) -> "ChannelEntry":
"""Deserialize the object."""
page, data = basic.uint8_t.deserialize(data)
channel_mask, data = Channels.deserialize(data)

return cls(page=page, channel_mask=channel_mask), data

def serialize(self) -> bytes:
"""Serialize the object."""
return self.page.serialize() + self.channel_mask.serialize()

def __eq__(self, other):
"""Return True if channel_masks and pages are equal."""
if not isinstance(other, type(self)):
return NotImplemented

return self.page == other.page and \
self.channel_mask == other.channel_mask

def __repr__(self) -> str:
"""Return a representation of a channel entry."""
return f"{type(self).__name__}(page={self.page!r}," \
f" channels={self.channel_mask!r})"
page: basic.uint8_t
channel_mask: Channels


@dataclasses.dataclass(frozen=True)
Expand Down
40 changes: 3 additions & 37 deletions zigpy_zboss/uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
import zigpy.serial
import async_timeout
import serial # type: ignore
import zigpy_zboss.config as conf
from zigpy_zboss import types as t
from zigpy_zboss.frames import Frame
Expand Down Expand Up @@ -82,48 +81,17 @@ def connection_made(

def connection_lost(self, exc: typing.Optional[Exception]) -> None:
"""Lost connection."""
LOGGER.debug("Connection has been lost: %r", exc)

if self._api is not None:
self._api.connection_lost(exc)
self.close()

# Do not try to reconnect if no exception occured.
if exc is None:
return

if not self._reset_flag:
SERIAL_LOGGER.warning(
f"Unexpected connection lost... {exc}")
self._reconnect_task = asyncio.create_task(self._reconnect())

async def _reconnect(self, timeout=RECONNECT_TIMEOUT):
"""Try to reconnect the disconnected serial port."""
SERIAL_LOGGER.info("Trying to reconnect to the NCP module!")
assert self._api is not None
loop = asyncio.get_running_loop()
async with async_timeout.timeout(timeout):
while True:
try:
_, proto = await zigpy.serial.create_serial_connection(
loop=loop,
protocol_factory=lambda: self,
url=self._port,
baudrate=self._baudrate,
xonxoff=(self._flow_control == "software"),
rtscts=(self._flow_control == "hardware"),
)
self._api._uart = proto
break
except serial.serialutil.SerialException:
await asyncio.sleep(0.1)

def close(self) -> None:
"""Close serial connection."""
self._buffer.clear()
self._ack_seq = 0
self._pack_seq = 0
if self._reconnect_task is not None:
self._reconnect_task.cancel()
self._reconnect_task = None

# Reset transport
if self._transport:
message = "Closing serial port"
Expand Down Expand Up @@ -275,8 +243,6 @@ async def connect(config: conf.ConfigType, api) -> ZbossNcpProtocol:
baudrate = config[conf.CONF_DEVICE_BAUDRATE]
flow_control = config[conf.CONF_DEVICE_FLOW_CONTROL]

LOGGER.debug("Connecting to %s at %s baud", port, baudrate)

_, protocol = await zigpy.serial.create_serial_connection(
loop=loop,
protocol_factory=lambda: ZbossNcpProtocol(config, api),
Expand Down
Loading

0 comments on commit 1d0179e

Please sign in to comment.