Skip to content

Commit

Permalink
Add support for Audio Events
Browse files Browse the repository at this point in the history
  • Loading branch information
HennerM committed Feb 9, 2024
1 parent 7792081 commit ab1ebad
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 11 deletions.
24 changes: 22 additions & 2 deletions speechmatics/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import sys
from dataclasses import dataclass
from socket import gaierror
from typing import List
from typing import Any, Dict, List

import httpx
import toml
Expand All @@ -26,6 +26,7 @@
from speechmatics.exceptions import JobNotFoundException, TranscriptionError
from speechmatics.helpers import _process_status_errors
from speechmatics.models import (
AudioEventsConfig,
AudioSettings,
AutoChaptersConfig,
BatchLanguageIdentificationConfig,
Expand Down Expand Up @@ -198,7 +199,7 @@ def get_transcription_config(
config = json.load(config_file)
else:
# Default to "en" so that existing API behavior is not broken.
config = {"language": "en"}
config: Dict[str, Any] = {"language": "en"}

# transcription_config is flattened in the BatchTranscriptionConfig,
# so the config entry from JSON must be flattened here, otherwise the JSON entry would be ignored
Expand Down Expand Up @@ -341,6 +342,15 @@ def get_transcription_config(
if args_auto_chapters or auto_chapters_config is not None:
config["auto_chapters_config"] = AutoChaptersConfig()

# Enable audio events when the JSON config carries an "audio_events_config"
# entry or when the --audio-events flag was passed. The flag is declared with
# action="store_true", so its parsed value is False (never None) when omitted
# and has to be tested for truthiness; an "is not None" check on it would turn
# audio events on for every run whose parser defines the flag.
audio_events_config = config.get("audio_events_config")
arg_audio_events = args.get("audio_events")
if audio_events_config is not None or arg_audio_events:
    # Forward an explicit, non-empty "types" list; otherwise pass None.
    types = (audio_events_config or {}).get("types") or None
    config["audio_events_config"] = AudioEventsConfig(types)


if args["mode"] == "rt":
# pylint: disable=unexpected-keyword-arg
return TranscriptionConfig(**config)
Expand Down Expand Up @@ -448,6 +458,14 @@ def transcript_handler(message):
sys.stdout.write(f"{escape_seq}{plaintext}\n")
transcripts.text += plaintext

def audio_event_handler(message):
    """Report an audio event message on stdout.

    When ``print_json`` (defined outside this block) is truthy, the message
    is dumped as JSON and nothing else happens. Otherwise the upper-cased
    event type is written in square brackets and also appended to
    ``transcripts.text``.

    :param message: Message to report; ``message["event"]`` must support
        ``.get("type", "")``.
    """
    if not print_json:
        label = message["event"].get("type", "").upper()
        sys.stdout.write(f"{escape_seq}[{label}]\n")
        transcripts.text += f"[{label}] "
        return
    print(json.dumps(message))

def partial_translation_handler(message):
if print_json:
print(json.dumps(message))
Expand Down Expand Up @@ -480,6 +498,8 @@ def end_of_transcript_handler(_):
# print both transcription and translation messages (if json was requested)
# print translation (if text was requested then)
# print transcription (if text was requested without translation)

api.add_event_handler(ServerMessageType.AudioEventStarted, audio_event_handler)
if print_json:
if enable_partials or enable_translation_partials:
api.add_event_handler(
Expand Down
4 changes: 4 additions & 0 deletions speechmatics/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,10 @@ def get_arg_parser():
help="Which type of diarization to use.",
)

# Opt-in switch: with action="store_true" the parsed value is False unless
# --audio-events is given on the command line.
rt_transcribe_command_parser.add_argument(
    "--audio-events",
    action="store_true",
    help=(
        "Enable audio event detection and print events in square brackets "
        "to the console, e.g. [MUSIC]"
    ),
)

# Build our actual parsers.
mode_subparsers = parser.add_subparsers(title="Mode", dest="mode")

Expand Down
24 changes: 18 additions & 6 deletions speechmatics/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def __init__(
self.connection_settings.set_missing_values_from_config(UsageMode.RealTime)
self.websocket = None
self.transcription_config = None
self.translation_config = None

self.event_handlers = {x: [] for x in ServerMessageType}
self.middlewares = {x: [] for x in ClientMessageType}
Expand Down Expand Up @@ -135,12 +134,19 @@ def _set_recognition_config(self):
:py:attr:`speechmatics.models.ClientMessageType.SetRecognitionConfig`
message.
"""
assert self.transcription_config is not None
msg = {
"message": ClientMessageType.SetRecognitionConfig,
"transcription_config": self.transcription_config.as_config(),
}
if self.translation_config is not None:
msg["translation_config"] = self.translation_config.asdict()
if self.transcription_config.translation_config is not None:
msg[
"translation_config"
] = self.transcription_config.translation_config.asdict()
if self.transcription_config.audio_events_config is not None:
msg[
"audio_events_config"
] = self.transcription_config.audio_events_config.asdict()
self._call_middleware(ClientMessageType.SetRecognitionConfig, msg, False)
return msg

Expand All @@ -155,13 +161,20 @@ def _start_recognition(self, audio_settings):
:param audio_settings: Audio settings to use.
:type audio_settings: speechmatics.models.AudioSettings
"""
assert self.transcription_config is not None
msg = {
"message": ClientMessageType.StartRecognition,
"audio_format": audio_settings.asdict(),
"transcription_config": self.transcription_config.as_config(),
}
if self.translation_config is not None:
msg["translation_config"] = self.translation_config.asdict()
if self.transcription_config.translation_config is not None:
msg[
"translation_config"
] = self.transcription_config.translation_config.asdict()
if self.transcription_config.audio_events_config is not None:
msg[
"audio_events_config"
] = self.transcription_config.audio_events_config.asdict()
self.session_running = True
self._call_middleware(ClientMessageType.StartRecognition, msg, False)
LOGGER.debug(msg)
Expand Down Expand Up @@ -435,7 +448,6 @@ async def run(
consumer/producer tasks.
"""
self.transcription_config = transcription_config
self.translation_config = transcription_config.translation_config
self.seq_no = 0
self._language_pack_info = None
await self._init_synchronization_primitives()
Expand Down
28 changes: 25 additions & 3 deletions speechmatics/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class BatchTranslationConfig(TranslationConfig):
class BatchLanguageIdentificationConfig:
"""Batch mode: Language identification config."""

expected_languages: List[str] = None
expected_languages: Optional[List[str]] = None
"""Expected languages for language identification"""


Expand All @@ -203,7 +203,7 @@ class SentimentAnalysisConfig:
class TopicDetectionConfig:
"""Defines topic detection parameters."""

topics: List[str] = None
topics: Optional[List[str]] = None
"""Optional list of topics for topic detection."""


Expand All @@ -212,6 +212,18 @@ class AutoChaptersConfig:
"""Auto Chapters config."""


@dataclass
class AudioEventsConfig:
    """Audio events config."""

    types: Optional[List[str]] = None
    """Optional list of audio event types to detect."""

    def asdict(self):
        """Return the config as a dictionary.

        An unset ``types`` (``None``) is reported as an empty list. The
        substitution is applied to the returned dictionary only, so
        serialising never modifies the instance.

        :return: Dictionary with a ``types`` key holding a list.
        :rtype: dict
        """
        # Inside a method body the bare name resolves to the module-level
        # dataclasses.asdict, not to this method.
        config = asdict(self)
        if config["types"] is None:
            config["types"] = []
        return config


@dataclass(init=False)
class TranscriptionConfig(_TranscriptionConfig):
# pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -254,12 +266,16 @@ class TranscriptionConfig(_TranscriptionConfig):
"""Indicates if partial translation, where words are produced
immediately, is enabled."""

translation_config: TranslationConfig = None
translation_config: Optional[TranslationConfig] = None
"""Optional configuration for translation."""

audio_events_config: Optional[AudioEventsConfig] = None
"""Optional configuration for audio events"""

def as_config(self):
dictionary = self.asdict()
dictionary.pop("translation_config", None)
dictionary.pop("audio_events_config", None)
dictionary.pop("enable_translation_partials", None)
enable_transcription_partials = dictionary.pop(
"enable_transcription_partials", False
Expand Down Expand Up @@ -504,6 +520,12 @@ class ServerMessageType(str, Enum):
AddTranscript = "AddTranscript"
"""Indicates the final transcript of a part of the audio."""

AudioEventStarted = "AudioEventStarted"
"""Indicates the start of an audio event."""

AudioEventEnded = "AudioEventEnded"
"""Indicates the end of an audio event."""

AddPartialTranslation = "AddPartialTranslation"
"""Indicates a partial translation, which is an incomplete translation that
is immediately produced and may change as more context becomes available.
Expand Down

0 comments on commit ab1ebad

Please sign in to comment.