Skip to content

Commit

Permalink
Add GetSpeakers and SpeakersResult client / server messages
Browse files Browse the repository at this point in the history
  • Loading branch information
dln22 committed Oct 4, 2024
1 parent 78e60ac commit dfd1347
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 2 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.1
2.0.2
22 changes: 21 additions & 1 deletion speechmatics/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import json
import logging
import os
from typing import Dict, Union
from typing import Any, Dict, Optional, Union
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

import httpx
Expand Down Expand Up @@ -519,6 +519,26 @@ def run_synchronously(self, *args, timeout=None, **kwargs):
# pylint: disable=no-value-for-parameter
asyncio.run(asyncio.wait_for(self.run(*args, **kwargs), timeout=timeout))

async def send_message(self, message_type: str, data: Optional[Any] = None):
"""
Sends a message to the server.
"""
if not self.session_running:
raise RuntimeError(f"Recognition session not running - not sending the message: {data}")

assert self.websocket, "WebSocket not connected"

data_ = data if data is not None else {}
serialized_data = json.dumps({"message": message_type, **data_})
try:
await self.websocket.send(serialized_data)
except websockets.exceptions.ConnectionClosedOK as exc:
LOGGER.error("WebSocket connection is closed. Cannot send the message.")
raise exc
except websockets.exceptions.ConnectionClosedError as exc:
LOGGER.error("WebSocket connection closed unexpectedly while sending the message.")
raise exc


async def _get_temp_token(api_key):
"""
Expand Down
7 changes: 7 additions & 0 deletions speechmatics/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,9 @@ class ClientMessageType(str, Enum):
SetRecognitionConfig = "SetRecognitionConfig"
"""Allows the client to re-configure the recognition session."""

GetSpeakers = "GetSpeakers"
"""Allows the client to request the speakers data."""


class ServerMessageType(str, Enum):
# pylint: disable=invalid-name
Expand Down Expand Up @@ -547,6 +550,10 @@ class ServerMessageType(str, Enum):
after the server has finished sending all :py:attr:`AddTranscript`
messages."""

SpeakersResult = "SpeakersResult"
"""Server response to :py:attr:`ClientMessageType.GetSpeakers`, containing
the speakers data."""

Info = "Info"
"""Indicates a generic info message."""

Expand Down
29 changes: 29 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import json
from collections import Counter
from unittest.mock import patch, MagicMock
from typing import Any

import asynctest
import pytest

from pytest_httpx import HTTPXMock
import websockets
from speechmatics import client
from speechmatics.batch_client import BatchClient
from speechmatics.exceptions import ForceEndSession
Expand Down Expand Up @@ -196,6 +198,33 @@ def test_run_synchronously_with_timeout(mock_server):
)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"message_type, message_data",
[
pytest.param(ClientMessageType.GetSpeakers, None, id="Sending pure string"),
pytest.param("custom_message_type", None, id="Sending random number"),
pytest.param("custom_message_type", {"data": "some_data"}, id="Sending random number"),
],
)
async def test_send_message(mock_server, message_type: str, message_data: Any):
"""
Tests that the client.send_message method correctly sends message to the server.
"""
ws_client, _, _ = default_ws_client_setup(mock_server.url)
ws_client.session_running = True

async with websockets.connect(
mock_server.url,
ssl=ws_client.connection_settings.ssl_context,
ping_timeout=ws_client.connection_settings.ping_timeout_seconds,
max_size=None,
extra_headers=None,
) as ws_client.websocket:
await ws_client.send_message(message_type, message_data)
assert message_type in [msg_types["message"] for msg_types in mock_server.messages_received]


@pytest.mark.parametrize(
"client_message_type, expect_received_count, expect_sent_count",
[
Expand Down

0 comments on commit dfd1347

Please sign in to comment.