Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add connection test feature to assist_satellite #126256

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions homeassistant/components/assist_satellite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType

from .connection_test import ConnectionTestView
from .const import DOMAIN, AssistSatelliteEntityFeature
from .entity import (
AssistSatelliteConfiguration,
Expand Down Expand Up @@ -56,6 +57,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
[AssistSatelliteEntityFeature.ANNOUNCE],
)
async_register_websocket_api(hass)
hass.http.register_view(ConnectionTestView())

return True

Expand Down
Binary file not shown.
36 changes: 36 additions & 0 deletions homeassistant/components/assist_satellite/connection_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Assist satellite connection test."""

import logging
from pathlib import Path

from aiohttp import web

from homeassistant.components.http import KEY_HASS, HomeAssistantView
from homeassistant.helpers.dispatcher import async_dispatcher_send

_LOGGER = logging.getLogger(__name__)

CONNECTION_TEST_CONTENT_TYPE = "audio/mpeg"
CONNECTION_TEST_FILENAME = "connection_test.mp3"
CONNECTION_TEST_SIGNAL = "assist_satellite.connection_test_{}"
CONNECTION_TEST_URL_BASE = "/api/assist_satellite/connection_test"


class ConnectionTestView(HomeAssistantView):
"""View to serve an audio sample for connection test."""

requires_auth = False
url = CONNECTION_TEST_URL_BASE + "/{connection_id}"
synesthesiam marked this conversation as resolved.
Show resolved Hide resolved
name = "api:assist_satellite_connection_test"

async def get(self, request: web.Request, connection_id: str) -> web.Response:
"""Start a get request."""
_LOGGER.debug("Request for connection test with id %s", connection_id)
_LOGGER.warning("Request for connection test with id %s", connection_id)
synesthesiam marked this conversation as resolved.
Show resolved Hide resolved

hass = request.app[KEY_HASS]
audio_path = Path(__file__).parent / CONNECTION_TEST_FILENAME
audio_data = await hass.async_add_executor_job(audio_path.read_bytes)

async_dispatcher_send(hass, CONNECTION_TEST_SIGNAL.format(connection_id))
return web.Response(body=audio_data, content_type=CONNECTION_TEST_CONTENT_TYPE)
2 changes: 1 addition & 1 deletion homeassistant/components/assist_satellite/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"domain": "assist_satellite",
"name": "Assist Satellite",
"codeowners": ["@home-assistant/core", "@synesthesiam"],
"dependencies": ["assist_pipeline", "stt", "tts"],
"dependencies": ["assist_pipeline", "http", "stt", "tts"],
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
"integration_type": "entity",
"quality_scale": "internal"
Expand Down
67 changes: 66 additions & 1 deletion homeassistant/components/assist_satellite/websocket_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Assist satellite Websocket API."""

import asyncio
from dataclasses import asdict, replace
from typing import Any

Expand All @@ -9,18 +10,24 @@
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.util import uuid as uuid_util

from .const import DOMAIN
from .connection_test import CONNECTION_TEST_SIGNAL, CONNECTION_TEST_URL_BASE
from .const import DOMAIN, AssistSatelliteEntityFeature
from .entity import AssistSatelliteEntity

CONNECTION_TEST_TIMEOUT = 30


@callback
def async_register_websocket_api(hass: HomeAssistant) -> None:
"""Register the websocket API."""
websocket_api.async_register_command(hass, websocket_intercept_wake_word)
websocket_api.async_register_command(hass, websocket_get_configuration)
websocket_api.async_register_command(hass, websocket_set_wake_words)
websocket_api.async_register_command(hass, websocket_test_connection)


@callback
Expand Down Expand Up @@ -143,3 +150,61 @@ async def websocket_set_wake_words(
replace(config, active_wake_words=actual_ids)
)
connection.send_result(msg["id"])


@callback
synesthesiam marked this conversation as resolved.
Show resolved Hide resolved
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_satellite/test_connection",
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
}
)
@websocket_api.async_response
async def websocket_test_connection(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Intercept the next wake word from a satellite."""
synesthesiam marked this conversation as resolved.
Show resolved Hide resolved
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
satellite = component.get_entity(msg["entity_id"])
if satellite is None:
connection.send_error(
msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
)
return
if not (satellite.supported_features or 0) & AssistSatelliteEntityFeature.ANNOUNCE:
connection.send_error(
msg["id"],
websocket_api.ERR_NOT_SUPPORTED,
"Entity does not support announce",
)
return

# Send response indicating the command is accepted
connection.send_result(msg["id"])

# Announce and wait for event
connection_id = uuid_util.random_uuid_hex()
connection_test_event = asyncio.Event()

@callback
def on_connection_test_signal() -> None:
"""Wrap set method in function decorated with callback."""
connection_test_event.set()

connection.subscriptions[msg["id"]] = async_dispatcher_connect(
hass,
CONNECTION_TEST_SIGNAL.format(connection_id),
on_connection_test_signal,
)
await satellite.async_internal_announce(
media_id=f"{CONNECTION_TEST_URL_BASE}/{connection_id}"
)

try:
async with asyncio.timeout(CONNECTION_TEST_TIMEOUT):
await connection_test_event.wait()
connection.send_result(msg["id"], {"status": "connection_test_successful"})
except TimeoutError:
connection.send_result(msg["id"], {"status": "connection_test_timed_out"})
2 changes: 1 addition & 1 deletion tests/components/assist_satellite/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class MockAssistSatellite(AssistSatelliteEntity):
def __init__(self) -> None:
"""Initialize the mock entity."""
self.events = []
self.announcements = []
self.announcements: list[tuple[str, str]] = []
self.config = AssistSatelliteConfiguration(
available_wake_words=[
AssistSatelliteWakeWord(
Expand Down
141 changes: 140 additions & 1 deletion tests/components/assist_satellite/test_websocket_api.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
"""Test WebSocket API."""

import asyncio
from http import HTTPStatus
from unittest.mock import patch

from freezegun.api import FrozenDateTimeFactory
import pytest

from homeassistant.components.assist_pipeline import PipelineStage
from homeassistant.components.assist_satellite.websocket_api import (
CONNECTION_TEST_TIMEOUT,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant

from . import ENTITY_ID
from .conftest import MockAssistSatellite

from tests.common import MockUser
from tests.typing import WebSocketGenerator
from tests.typing import ClientSessionGenerator, WebSocketGenerator


async def test_intercept_wake_word(
Expand Down Expand Up @@ -385,3 +390,137 @@ async def test_set_wake_words_bad_id(
"code": "not_supported",
"message": "Wake word id is not supported: abcd",
}


async def test_connection_test(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
hass_client: ClientSessionGenerator,
) -> None:
"""Test connection test."""
ws_client = await hass_ws_client(hass)

await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": ENTITY_ID,
}
)

response = await ws_client.receive_json()
assert response["success"]
assert response["result"] is None

for _ in range(3):
await asyncio.sleep(0)

assert len(entity.announcements) == 1
assert entity.announcements[0][0] == ""
announcement_media_id = entity.announcements[0][1]
hass_url = "http://10.10.10.10:8123"
assert announcement_media_id.startswith(
f"{hass_url}/api/assist_satellite/connection_test/"
)

# Fake satellite fetches the URL
client = await hass_client()
resp = await client.get(announcement_media_id[len(hass_url) :])
assert resp.status == HTTPStatus.OK

response = await ws_client.receive_json()
assert response["success"]
assert response["result"] == {"status": "connection_test_successful"}


async def test_connection_test_timeout(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
hass_client: ClientSessionGenerator,
freezer: FrozenDateTimeFactory,
) -> None:
"""Test connection test timeout."""
ws_client = await hass_ws_client(hass)

await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": ENTITY_ID,
}
)

response = await ws_client.receive_json()
assert response["success"]
assert response["result"] is None

for _ in range(3):
await asyncio.sleep(0)

assert len(entity.announcements) == 1
assert entity.announcements[0][0] == ""
announcement_media_id = entity.announcements[0][1]
hass_url = "http://10.10.10.10:8123"
assert announcement_media_id.startswith(
f"{hass_url}/api/assist_satellite/connection_test/"
)

freezer.tick(CONNECTION_TEST_TIMEOUT + 1)

# Timeout
response = await ws_client.receive_json()
assert response["success"]
assert response["result"] == {"status": "connection_test_timed_out"}


async def test_connection_test_invalid_satellite(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test connection test with unknown entity id."""
ws_client = await hass_ws_client(hass)

await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": "assist_satellite.invalid",
}
)
response = await ws_client.receive_json()

assert not response["success"]
assert response["error"] == {
"code": "not_found",
"message": "Entity not found",
}


async def test_connection_test_timeout_announcement_unsupported(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test connection test entity which does not support announce."""
ws_client = await hass_ws_client(hass)

# Disable announce support
entity.supported_features = 0

await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": ENTITY_ID,
}
)
response = await ws_client.receive_json()

assert not response["success"]
assert response["error"] == {
"code": "not_supported",
"message": "Entity does not support announce",
}
Loading