From 09fe819a9df4a13686c189c9fe47146d8be4fefc Mon Sep 17 00:00:00 2001 From: Daniel Nurkowski Date: Wed, 2 Oct 2024 20:52:42 +0000 Subject: [PATCH] Add GetSpeakers and SpeakersResult client / server messages --- VERSION | 2 +- speechmatics/client.py | 22 +++++++++++++++++++++- speechmatics/models.py | 7 +++++++ tests/test_client.py | 29 +++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 2 deletions(-) diff --git a/VERSION b/VERSION index 38f77a6..e9307ca 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.0.1 +2.0.2 diff --git a/speechmatics/client.py b/speechmatics/client.py index 1c886dd..49fa07a 100644 --- a/speechmatics/client.py +++ b/speechmatics/client.py @@ -9,7 +9,7 @@ import json import logging import os -from typing import Dict, Union +from typing import Any, Dict, Optional, Union from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse import httpx @@ -519,6 +519,26 @@ def run_synchronously(self, *args, timeout=None, **kwargs): # pylint: disable=no-value-for-parameter asyncio.run(asyncio.wait_for(self.run(*args, **kwargs), timeout=timeout)) + async def send_message(self, message_type: str, data: Optional[Any] = None): + """ + Sends a message to the server. + """ + if not self.session_running: + raise RuntimeError("Recognition session not running. Cannot send the message.") + + assert self.websocket, "WebSocket not connected" + + data_ = data if data is not None else {} + serialized_data = json.dumps({"message": message_type, **data_}) + try: + await self.websocket.send(serialized_data) + except websockets.exceptions.ConnectionClosedOK as exc: + LOGGER.error("WebSocket connection is closed. Cannot send the message.") + raise exc + except websockets.exceptions.ConnectionClosedError as exc: + LOGGER.error("WebSocket connection closed unexpectedly while sending the message.") + raise exc + async def _get_temp_token(api_key): """ diff --git a/speechmatics/models.py b/speechmatics/models.py index 07cf7ac..71384a8 100644 --- a/speechmatics/models.py +++ b/speechmatics/models.py @@ -507,6 +507,9 @@ class ClientMessageType(str, Enum): SetRecognitionConfig = "SetRecognitionConfig" """Allows the client to re-configure the recognition session.""" + GetSpeakers = "GetSpeakers" + """Allows the client to request the speakers data.""" + class ServerMessageType(str, Enum): # pylint: disable=invalid-name @@ -547,6 +550,10 @@ class ServerMessageType(str, Enum): after the server has finished sending all :py:attr:`AddTranscript` messages.""" + SpeakersResult = "SpeakersResult" + """Server response to :py:attr:`ClientMessageType.GetSpeakers`, containing + the speakers data.""" + Info = "Info" """Indicates a generic info message.""" diff --git a/tests/test_client.py b/tests/test_client.py index bcd5f64..83ae761 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -5,11 +5,13 @@ import json from collections import Counter from unittest.mock import patch, MagicMock +from typing import Any import asynctest import pytest from pytest_httpx import HTTPXMock +import websockets from speechmatics import client from speechmatics.batch_client import BatchClient from speechmatics.exceptions import ForceEndSession @@ -196,6 +198,33 @@ def test_run_synchronously_with_timeout(mock_server): ) +@pytest.mark.asyncio +@pytest.mark.parametrize( + "message_type, message_data", + [ + pytest.param(ClientMessageType.GetSpeakers, None, id="Sending pure string"), + pytest.param("custom_message_type", None, id="Sending random number"), + pytest.param("custom_message_type", {"data": "some_data"}, id="Sending random number"), + ], +) +async def test_send_message(mock_server, message_type: str, message_data: Any): + """ + Tests that the client.send_message method correctly sends message to the server. + """ + ws_client, _, _ = default_ws_client_setup(mock_server.url) + ws_client.session_running = True + + async with websockets.connect( + mock_server.url, + ssl=ws_client.connection_settings.ssl_context, + ping_timeout=ws_client.connection_settings.ping_timeout_seconds, + max_size=None, + extra_headers=None, + ) as ws_client.websocket: + await ws_client.send_message(message_type, message_data) + assert message_type in [msg_types["message"] for msg_types in mock_server.messages_received] + + @pytest.mark.parametrize( "client_message_type, expect_received_count, expect_sent_count", [