diff --git a/roborock/data/b01_q10/b01_q10_containers.py b/roborock/data/b01_q10/b01_q10_containers.py index 0e805593..0e8ea3d1 100644 --- a/roborock/data/b01_q10/b01_q10_containers.py +++ b/roborock/data/b01_q10/b01_q10_containers.py @@ -1,6 +1,26 @@ -from ..containers import RoborockBase +"""Data container classes for Q10 B01 devices. + +Many of these classes use the `field(metadata={"dps": ...})` convention to map +fields to device Data Points (DPS). This metadata is utilized by the +`update_from_dps` helper in `roborock.devices.traits.b01.q10.common` to +automatically update objects from raw device responses. +""" +from dataclasses import dataclass, field +from ..containers import RoborockBase +from .b01_q10_code_mappings import ( + B01_Q10_DP, + YXBackType, + YXDeviceCleanTask, + YXDeviceState, + YXDeviceWorkMode, + YXFanLevel, + YXWaterLevel, +) + + +@dataclass class dpCleanRecord(RoborockBase): op: str result: int @@ -8,24 +28,28 @@ class dpCleanRecord(RoborockBase): data: list +@dataclass class dpMultiMap(RoborockBase): op: str result: int data: list +@dataclass class dpGetCarpet(RoborockBase): op: str result: int data: str +@dataclass class dpSelfIdentifyingCarpet(RoborockBase): op: str result: int data: str +@dataclass class dpNetInfo(RoborockBase): wifiName: str ipAdress: str @@ -33,6 +57,7 @@ class dpNetInfo(RoborockBase): signal: int +@dataclass class dpNotDisturbExpand(RoborockBase): disturb_dust_enable: int disturb_light: int @@ -40,14 +65,38 @@ class dpNotDisturbExpand(RoborockBase): disturb_voice: int +@dataclass class dpCurrentCleanRoomIds(RoborockBase): room_id_list: list +@dataclass class dpVoiceVersion(RoborockBase): version: int +@dataclass class dpTimeZone(RoborockBase): timeZoneCity: str timeZoneSec: int + + +@dataclass +class Q10Status(RoborockBase): + """Status for Q10 devices. + + Fields are mapped to DPS values using metadata. Objects of this class can be + automatically updated using the `update_from_dps` helper. + """ + + clean_time: int | None = field(default=None, metadata={"dps": B01_Q10_DP.CLEAN_TIME}) + clean_area: int | None = field(default=None, metadata={"dps": B01_Q10_DP.CLEAN_AREA}) + battery: int | None = field(default=None, metadata={"dps": B01_Q10_DP.BATTERY}) + status: YXDeviceState | None = field(default=None, metadata={"dps": B01_Q10_DP.STATUS}) + fan_level: YXFanLevel | None = field(default=None, metadata={"dps": B01_Q10_DP.FAN_LEVEL}) + water_level: YXWaterLevel | None = field(default=None, metadata={"dps": B01_Q10_DP.WATER_LEVEL}) + clean_count: int | None = field(default=None, metadata={"dps": B01_Q10_DP.CLEAN_COUNT}) + clean_mode: YXDeviceWorkMode | None = field(default=None, metadata={"dps": B01_Q10_DP.CLEAN_MODE}) + clean_task_type: YXDeviceCleanTask | None = field(default=None, metadata={"dps": B01_Q10_DP.CLEAN_TASK_TYPE}) + back_type: YXBackType | None = field(default=None, metadata={"dps": B01_Q10_DP.BACK_TYPE}) + cleaning_progress: int | None = field(default=None, metadata={"dps": B01_Q10_DP.CLEANING_PROGRESS}) diff --git a/roborock/data/containers.py b/roborock/data/containers.py index 57d5e6b2..1c1a3f8f 100644 --- a/roborock/data/containers.py +++ b/roborock/data/containers.py @@ -91,10 +91,10 @@ def from_dict(cls, data: dict[str, Any]): if not isinstance(data, dict): return None field_types = {field.name: field.type for field in dataclasses.fields(cls)} - result: dict[str, Any] = {} + normalized_data: dict[str, Any] = {} for orig_key, value in data.items(): key = _decamelize(orig_key) - if (field_type := field_types.get(key)) is None: + if field_types.get(key) is None: if (log_key := f"{cls.__name__}.{key}") not in RoborockBase._missing_logged: _LOGGER.debug( "Key '%s' (decamelized: '%s') not found in %s fields, skipping", @@ -104,6 +104,19 @@ def from_dict(cls, data: dict[str, Any]): ) RoborockBase._missing_logged.add(log_key) continue + normalized_data[key] = value + + result = RoborockBase.convert_dict(field_types, normalized_data) + return cls(**result) + + @staticmethod + def convert_dict(types_map: dict[Any, type], data: dict[Any, Any]) -> dict[Any, Any]: + """Convert a dictionary of values based on a schema map of types.""" + result: dict[Any, Any] = {} + for key, value in data.items(): + if key not in types_map: + continue + field_type = types_map[key] if value == "None" or value is None: result[key] = None continue @@ -124,7 +137,7 @@ def from_dict(cls, data: dict[str, Any]): _LOGGER.exception(f"Failed to convert {key} with value {value} to type {field_type}") continue - return cls(**result) + return result def as_dict(self) -> dict: return asdict( diff --git a/roborock/devices/device.py b/roborock/devices/device.py index ca1fbf14..29f1fd28 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -197,12 +197,14 @@ async def connect(self) -> None: if self._unsub: raise ValueError("Already connected to the device") unsub = await self._channel.subscribe(self._on_message) - if self.v1_properties is not None: - try: + try: + if self.v1_properties is not None: await self.v1_properties.discover_features() - except RoborockException: - unsub() - raise + elif self.b01_q10_properties is not None: + await self.b01_q10_properties.start() + except RoborockException: + unsub() + raise self._logger.info("Connected to device") self._unsub = unsub @@ -214,6 +216,8 @@ async def close(self) -> None: await self._connect_task except asyncio.CancelledError: pass + if self.b01_q10_properties is not None: + await self.b01_q10_properties.close() if self._unsub: self._unsub() self._unsub = None diff --git a/roborock/devices/rpc/b01_q10_channel.py b/roborock/devices/rpc/b01_q10_channel.py index a482e109..7c018d58 100644 --- a/roborock/devices/rpc/b01_q10_channel.py +++ b/roborock/devices/rpc/b01_q10_channel.py @@ -3,24 +3,41 @@ from __future__ import annotations import logging +from collections.abc import AsyncGenerator +from typing import Any from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP from roborock.devices.transport.mqtt_channel import MqttChannel from roborock.exceptions import RoborockException -from roborock.protocols.b01_q10_protocol import ( - ParamsType, - encode_mqtt_payload, -) +from roborock.protocols.b01_q10_protocol import ParamsType, decode_rpc_response, encode_mqtt_payload _LOGGER = logging.getLogger(__name__) +async def stream_decoded_responses( + mqtt_channel: MqttChannel, +) -> AsyncGenerator[dict[B01_Q10_DP, Any], None]: + """Stream decoded DPS messages received via MQTT.""" + + async for response_message in mqtt_channel.subscribe_stream(): + try: + decoded_dps = decode_rpc_response(response_message) + except RoborockException as ex: + _LOGGER.debug( + "Failed to decode B01 RPC response: %s: %s", + response_message, + ex, + ) + continue + yield decoded_dps + + async def send_command( mqtt_channel: MqttChannel, command: B01_Q10_DP, params: ParamsType, ) -> None: - """Send a command on the MQTT channel, without waiting for a response""" + """Send a command on the MQTT channel, without waiting for a response.""" _LOGGER.debug("Sending B01 MQTT command: cmd=%s params=%s", command, params) roborock_message = encode_mqtt_payload(command, params) _LOGGER.debug("Sending MQTT message: %s", roborock_message) diff --git a/roborock/devices/traits/b01/q10/__init__.py b/roborock/devices/traits/b01/q10/__init__.py index ac897259..cddfaeba 100644 --- a/roborock/devices/traits/b01/q10/__init__.py +++ b/roborock/devices/traits/b01/q10/__init__.py @@ -1,15 +1,23 @@ """Traits for Q10 B01 devices.""" +import asyncio +import logging + +from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP +from roborock.devices.rpc.b01_q10_channel import stream_decoded_responses from roborock.devices.traits import Trait from roborock.devices.transport.mqtt_channel import MqttChannel from .command import CommandTrait +from .status import StatusTrait from .vacuum import VacuumTrait __all__ = [ "Q10PropertiesApi", ] +_LOGGER = logging.getLogger(__name__) + class Q10PropertiesApi(Trait): """API for interacting with B01 devices.""" @@ -17,13 +25,43 @@ class Q10PropertiesApi(Trait): command: CommandTrait """Trait for sending commands to Q10 devices.""" + status: StatusTrait + """Trait for managing the status of Q10 devices.""" + vacuum: VacuumTrait """Trait for sending vacuum related commands to Q10 devices.""" def __init__(self, channel: MqttChannel) -> None: """Initialize the B01Props API.""" + self._channel = channel self.command = CommandTrait(channel) self.vacuum = VacuumTrait(self.command) + self.status = StatusTrait() + self._subscribe_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + """Start any necessary subscriptions for the trait.""" + self._subscribe_task = asyncio.create_task(self._subscribe_loop()) + + async def close(self) -> None: + """Close any resources held by the trait.""" + if self._subscribe_task is not None: + self._subscribe_task.cancel() + try: + await self._subscribe_task + except asyncio.CancelledError: + pass + self._subscribe_task = None + + async def refresh(self) -> None: + """Refresh all traits.""" + await self.command.send(B01_Q10_DP.REQUEST_DPS, params={}) + + async def _subscribe_loop(self) -> None: + """Persistent loop to listen for status updates.""" + async for decoded_dps in stream_decoded_responses(self._channel): + _LOGGER.debug("Received Q10 status update: %s", decoded_dps) + self.status.update_from_dps(decoded_dps) def create(channel: MqttChannel) -> Q10PropertiesApi: diff --git a/roborock/devices/traits/b01/q10/common.py b/roborock/devices/traits/b01/q10/common.py new file mode 100644 index 00000000..62ef66ef --- /dev/null +++ b/roborock/devices/traits/b01/q10/common.py @@ -0,0 +1,40 @@ +"""Common utilities for Q10 traits. + +This module provides infrastructure for mapping Roborock Data Points (DPS) to +Python dataclass fields and handling the lifecycle of data updates from the +device. +""" + +import dataclasses +from typing import Any + +from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP +from roborock.data.containers import RoborockBase + + +class DpsDataConverter: + """Utility to handle the transformation and merging of DPS data into models.""" + + def __init__(self, dps_type_map: dict[B01_Q10_DP, type], dps_field_map: dict[B01_Q10_DP, str]): + """Initialize the converter for a specific RoborockBase-derived class.""" + self._dps_type_map = dps_type_map + self._dps_field_map = dps_field_map + + @classmethod + def from_dataclass(cls, dataclass_type: type[RoborockBase]): + """Initialize the converter for a specific RoborockBase-derived class.""" + dps_type_map: dict[B01_Q10_DP, type] = {} + dps_field_map: dict[B01_Q10_DP, str] = {} + for field_obj in dataclasses.fields(dataclass_type): + if field_obj.metadata and "dps" in field_obj.metadata: + dps_id = field_obj.metadata["dps"] + dps_type_map[dps_id] = field_obj.type + dps_field_map[dps_id] = field_obj.name + return cls(dps_type_map, dps_field_map) + + def update_from_dps(self, target: RoborockBase, decoded_dps: dict[B01_Q10_DP, Any]) -> None: + """Convert and merge raw DPS data into the target object.""" + conversions = RoborockBase.convert_dict(self._dps_type_map, decoded_dps) + for dps_id, value in conversions.items(): + field_name = self._dps_field_map[dps_id] + setattr(target, field_name, value) diff --git a/roborock/devices/traits/b01/q10/status.py b/roborock/devices/traits/b01/q10/status.py new file mode 100644 index 00000000..0fe73221 --- /dev/null +++ b/roborock/devices/traits/b01/q10/status.py @@ -0,0 +1,20 @@ +"""Status trait for Q10 B01 devices.""" + +from roborock.data.b01_q10.b01_q10_containers import Q10Status + +from .common import DpsDataConverter + +_CONVERTER = DpsDataConverter.from_dataclass(Q10Status) + + +class StatusTrait(Q10Status): + """Trait for managing the status of Q10 Roborock devices. + + This is a thin wrapper around Q10Status that provides the Trait interface. + The current values reflect the most recently received data from the device. + New values can be requested through the `Q10PropertiesApi` refresh method. + """ + + def update_from_dps(self, decoded_dps: dict) -> None: + """Update the trait from raw DPS data.""" + _CONVERTER.update_from_dps(self, decoded_dps) diff --git a/roborock/devices/transport/mqtt_channel.py b/roborock/devices/transport/mqtt_channel.py index 498cef13..249633e1 100644 --- a/roborock/devices/transport/mqtt_channel.py +++ b/roborock/devices/transport/mqtt_channel.py @@ -1,7 +1,8 @@ """Modules for communicating with specific Roborock devices over MQTT.""" +import asyncio import logging -from collections.abc import Callable +from collections.abc import AsyncGenerator, Callable from roborock.callbacks import decoder_callback from roborock.data import HomeDataDevice, RRiot, UserData @@ -73,6 +74,17 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab dispatch = decoder_callback(self._decoder, callback, _LOGGER) return await self._mqtt_session.subscribe(self._subscribe_topic, dispatch) + async def subscribe_stream(self) -> AsyncGenerator[RoborockMessage, None]: + """Subscribe to the device's message stream.""" + message_queue: asyncio.Queue[RoborockMessage] = asyncio.Queue() + unsub = await self.subscribe(message_queue.put_nowait) + try: + while True: + message = await message_queue.get() + yield message + finally: + unsub() + async def publish(self, message: RoborockMessage) -> None: """Publish a command message. diff --git a/tests/devices/rpc/test_b01_q10_channel.py b/tests/devices/rpc/test_b01_q10_channel.py new file mode 100644 index 00000000..74e6f224 --- /dev/null +++ b/tests/devices/rpc/test_b01_q10_channel.py @@ -0,0 +1,25 @@ +"""Tests for B01 Q10 channel functions.""" + +import json + +import pytest + +from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP +from roborock.devices.rpc.b01_q10_channel import send_command +from tests.fixtures.channel_fixtures import FakeChannel + + +@pytest.fixture(name="fake_channel") +def fake_channel_fixture() -> FakeChannel: + return FakeChannel() + + +async def test_send_command(fake_channel: FakeChannel) -> None: + """Test sending a command without waiting for response.""" + await send_command(fake_channel, B01_Q10_DP.START_CLEAN, {"cmd": 1}) # type: ignore[arg-type] + + assert len(fake_channel.published_messages) == 1 + message = fake_channel.published_messages[0] + assert message.payload is not None + payload_data = json.loads(message.payload.decode()) + assert payload_data == {"dps": {"201": {"cmd": 1}}} diff --git a/tests/devices/traits/b01/q10/__init__.py b/tests/devices/traits/b01/q10/__init__.py new file mode 100644 index 00000000..78977420 --- /dev/null +++ b/tests/devices/traits/b01/q10/__init__.py @@ -0,0 +1 @@ +"""Tests for the Q10 B01 traits.""" diff --git a/tests/devices/traits/b01/q10/test_status.py b/tests/devices/traits/b01/q10/test_status.py new file mode 100644 index 00000000..a24f65c6 --- /dev/null +++ b/tests/devices/traits/b01/q10/test_status.py @@ -0,0 +1,124 @@ +"""Tests for the Q10 B01 status trait.""" + +import asyncio +import json +import pathlib +from collections.abc import AsyncGenerator +from typing import Any +from unittest.mock import AsyncMock, Mock + +import pytest + +from roborock.data.b01_q10.b01_q10_code_mappings import ( + YXDeviceCleanTask, + YXDeviceState, + YXFanLevel, +) +from roborock.devices.traits.b01.q10 import Q10PropertiesApi, create +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol + +TEST_DATA_DIR = pathlib.Path("tests/protocols/testdata/b01_q10_protocol") + +TESTDATA_DP_STATUS_DP_CLEAN_TASK_TYPE = (TEST_DATA_DIR / "dpStatus-dpCleanTaskType.json").read_bytes() +TESTDATA_DP_REQUEST_DPS = (TEST_DATA_DIR / "dpRequetdps.json").read_bytes() + + +@pytest.fixture +def mock_channel(): + """Fixture for a mocked MQTT channel.""" + mock = AsyncMock() + return mock + + +@pytest.fixture +def message_queue() -> asyncio.Queue[RoborockMessage]: + """Fixture for a message queue used by the mock stream.""" + return asyncio.Queue() + + +@pytest.fixture +def mock_subscribe_stream(mock_channel: AsyncMock, message_queue: asyncio.Queue[RoborockMessage]) -> Mock: + """Fixture to mock the subscribe_stream method to yield from a queue.""" + + async def mock_stream() -> AsyncGenerator[RoborockMessage, None]: + while True: + yield await message_queue.get() + + mock = Mock(return_value=mock_stream()) + mock_channel.subscribe_stream = mock + return mock + + +@pytest.fixture +async def q10_api(mock_channel: AsyncMock, mock_subscribe_stream: Mock) -> AsyncGenerator[Q10PropertiesApi, None]: + """Fixture to create and manage the Q10PropertiesApi.""" + api = create(mock_channel) + await api.start() + yield api + await api.close() + + +def build_message(payload: bytes) -> RoborockMessage: + """Helper to build a RoborockMessage for testing.""" + return RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=payload, + version=b"B01", + ) + + +async def wait_for_attribute_value(obj: Any, attribute: str, value: Any, timeout: float = 2.0) -> None: + """Wait for an attribute on an object to reach a specific value.""" + for _ in range(int(timeout / 0.1)): + if getattr(obj, attribute) == value: + return + await asyncio.sleep(0.1) + pytest.fail(f"Timeout waiting for {attribute} to become {value} on {obj}") + + +async def test_status_trait_streaming( + q10_api: Q10PropertiesApi, + message_queue: asyncio.Queue[RoborockMessage], +) -> None: + """Test that the StatusTrait updates its state from streaming messages.""" + message = build_message(TESTDATA_DP_STATUS_DP_CLEAN_TASK_TYPE) + + assert q10_api.status.status is None + assert q10_api.status.clean_task_type is None + + message_queue.put_nowait(message) + + await wait_for_attribute_value(q10_api.status, "status", YXDeviceState.CHARGING_STATE) + + assert q10_api.status.status == YXDeviceState.CHARGING_STATE + assert q10_api.status.clean_task_type == YXDeviceCleanTask.IDLE + + +async def test_status_trait_refresh( + q10_api: Q10PropertiesApi, + mock_channel: AsyncMock, + message_queue: asyncio.Queue[RoborockMessage], +) -> None: + """Test that the StatusTrait sends a refresh command and updates state.""" + assert q10_api.status.battery is None + assert q10_api.status.status is None + assert q10_api.status.fan_level is None + + message = build_message(TESTDATA_DP_REQUEST_DPS) + + await q10_api.refresh() + mock_channel.publish.assert_called_once() + sent_message = mock_channel.publish.call_args[0][0] + assert sent_message.protocol == RoborockMessageProtocol.RPC_REQUEST + data = json.loads(sent_message.payload) + assert data + assert data.get("dps") + assert data.get("dps").get("102") == {} + + message_queue.put_nowait(message) + + await wait_for_attribute_value(q10_api.status, "battery", 100) + + assert q10_api.status.battery == 100 + assert q10_api.status.status == YXDeviceState.CHARGING_STATE + assert q10_api.status.fan_level == YXFanLevel.NORMAL diff --git a/tests/devices/traits/b01/q10/test_vacuum.py b/tests/devices/traits/b01/q10/test_vacuum.py index c8bdb3a4..af908cdc 100644 --- a/tests/devices/traits/b01/q10/test_vacuum.py +++ b/tests/devices/traits/b01/q10/test_vacuum.py @@ -6,6 +6,7 @@ from roborock.data.b01_q10.b01_q10_code_mappings import YXCleanType, YXFanLevel from roborock.devices.traits.b01.q10 import Q10PropertiesApi +from roborock.devices.traits.b01.q10.status import StatusTrait from roborock.devices.traits.b01.q10.vacuum import VacuumTrait from tests.fixtures.channel_fixtures import FakeChannel @@ -52,3 +53,15 @@ async def test_vacuum_commands( assert message.payload payload_data = json.loads(message.payload.decode()) assert payload_data == {"dps": expected_payload} + + +def test_q10_api_has_status_trait(q10_api: Q10PropertiesApi) -> None: + """Test that Q10PropertiesApi exposes StatusTrait.""" + assert hasattr(q10_api, "status") + assert isinstance(q10_api.status, StatusTrait) + + +def test_q10_api_has_vacuum_trait(q10_api: Q10PropertiesApi) -> None: + """Test that Q10PropertiesApi exposes VacuumTrait.""" + assert hasattr(q10_api, "vacuum") + assert isinstance(q10_api.vacuum, VacuumTrait)