diff --git a/CLAUDE.md b/CLAUDE.md index c169005..8b2d0f7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -12,10 +12,11 @@ Claudio is a messaging-to-Claude Code bridge. It supports both Telegram and What - `lib/cli.py` — CLI dispatch logic. Uses `sys.argv` (not argparse) with lazy imports per command for fast startup. - `lib/config.py` — `ClaudioConfig` class for global config (`~/.claudio/service.env`), `BotConfig` class for per-bot config (`~/.claudio/bots//bot.env`). Functions: `parse_env_file()`, `save_bot_env()`, `save_model()`. Auto-migrates single-bot to multi-bot layout. - `lib/server.py` — Python HTTP server (stdlib `http.server`), listens on port 8421, routes POST `/telegram/webhook` and POST/GET `/whatsapp/webhook`. Multi-bot dispatch: matches Telegram webhooks via secret-token header, WhatsApp webhooks via HMAC-SHA256 signature verification. Supports dual-platform bots (same bot_id serving both Telegram and WhatsApp). Loads bot registry from `~/.claudio/bots/*/bot.env`. SIGHUP handler for hot-reload. Composite queue keys (`bot_id:chat_id` for Telegram, `bot_id:phone_number` for WhatsApp) for per-bot, per-user message isolation. `/reload` endpoint (requires `MANAGEMENT_SECRET` authentication). Webhook processing delegates to `lib/handlers.py`. -- `lib/handlers.py` — Webhook orchestrator: parses webhooks, runs unified message pipeline (media download, voice transcription, Claude invocation, response delivery). Entry point: `process_webhook()`. +- `lib/handlers.py` — Webhook orchestrator: parses webhooks, runs unified message pipeline (media download, voice transcription, Claude invocation, response delivery). Speech provider dispatch (`_stt_transcribe()`, `_tts_convert()`) selects ElevenLabs or Speechmatics based on `SPEECH_PROVIDER` config. Entry point: `process_webhook()`. 
- `lib/telegram_api.py` — `TelegramClient` class: send messages (4096-char chunking with Markdown fallback), send voice, typing indicator, reactions, file downloads with magic byte validation. Retry on 429/5xx. - `lib/whatsapp_api.py` — `WhatsAppClient` class: send messages (4096-char chunking), send audio, mark read, media downloads (two-step URL resolution). Retry on 429/5xx. - `lib/elevenlabs.py` — ElevenLabs TTS (`tts_convert()`) and STT (`stt_transcribe()`). Stdlib only. +- `lib/speechmatics.py` — Speechmatics TTS (`tts_convert()`) and STT (`stt_transcribe()`). TTS returns WAV audio; STT uses async batch API (submit job, poll, fetch transcript). Stdlib only. - `lib/claude_runner.py` — Claude CLI invocation with `start_new_session=True`, MCP config, JSON output parsing, token usage persistence. Returns `ClaudeResult` namedtuple. - `lib/setup.py` — Interactive setup wizards: `telegram_setup()`, `whatsapp_setup()`, `bot_setup()`. Validates credentials via API calls, polls for Telegram `/start`, generates secrets, saves config. - `lib/service.py` — Service management: systemd/launchd unit generation, symlink install, cloudflared tunnel setup, webhook registration with retry, cron health-check install, Claude hooks install, `service_status()`, `service_restart()`, `service_update()`, `service_install()`, `service_uninstall()`. @@ -32,4 +33,4 @@ Claudio is a messaging-to-Claude Code bridge. It supports both Telegram and What Run locally with `./claudio start`. Requires `python3`, `sqlite3`, `cloudflared`, and `claude` CLI. The memory system optionally requires the `fastembed` Python package (degrades gracefully without it). -**Tests:** `python3 -m pytest tests/` — 640 tests covering all modules (config, util, setup, service, backup, health_check, server, handlers, telegram_api, whatsapp_api, elevenlabs, claude_runner, cli). 
+**Tests:** `python3 -m pytest tests/` — 673 tests covering all modules (config, util, setup, service, backup, health_check, server, handlers, telegram_api, whatsapp_api, elevenlabs, speechmatics, claude_runner, cli). diff --git a/lib/config.py b/lib/config.py index dbcc30e..874b94a 100644 --- a/lib/config.py +++ b/lib/config.py @@ -69,9 +69,14 @@ class BotConfig: 'whatsapp_app_secret', 'whatsapp_verify_token', 'whatsapp_phone_number', # Common 'model', 'max_history_lines', + # Speech provider selection (from service.env) + 'speech_provider', # ElevenLabs (from service.env) 'elevenlabs_api_key', 'elevenlabs_voice_id', 'elevenlabs_model', 'elevenlabs_stt_model', + # Speechmatics (from service.env) + 'speechmatics_api_key', 'speechmatics_voice_id', + 'speechmatics_stt_region', # Memory (from service.env) 'memory_enabled', # Database @@ -84,9 +89,12 @@ def __init__(self, bot_id, bot_dir=None, whatsapp_app_secret='', whatsapp_verify_token='', whatsapp_phone_number='', model='haiku', max_history_lines=100, + speech_provider='elevenlabs', elevenlabs_api_key='', elevenlabs_voice_id='iP95p4xoKVk53GoZ742B', elevenlabs_model='eleven_multilingual_v2', elevenlabs_stt_model='scribe_v1', + speechmatics_api_key='', speechmatics_voice_id='sarah', + speechmatics_stt_region='eu1', memory_enabled=True, db_file=''): self.bot_id = bot_id self.bot_dir = bot_dir or '' @@ -100,10 +108,14 @@ def __init__(self, bot_id, bot_dir=None, self.whatsapp_phone_number = whatsapp_phone_number self.model = model self.max_history_lines = int(max_history_lines) + self.speech_provider = speech_provider self.elevenlabs_api_key = elevenlabs_api_key self.elevenlabs_voice_id = elevenlabs_voice_id self.elevenlabs_model = elevenlabs_model self.elevenlabs_stt_model = elevenlabs_stt_model + self.speechmatics_api_key = speechmatics_api_key + self.speechmatics_voice_id = speechmatics_voice_id + self.speechmatics_stt_region = speechmatics_stt_region self.memory_enabled = memory_enabled self.db_file = db_file or 
(os.path.join(bot_dir, 'history.db') if bot_dir else '') @@ -135,11 +147,17 @@ def from_bot_config(cls, bot_id, bot_config, service_env=None): # Common model=bot_config.get('model', 'haiku'), max_history_lines=bot_config.get('max_history_lines', '100'), + # Speech provider (from service.env) + speech_provider=svc.get('SPEECH_PROVIDER', 'elevenlabs'), # ElevenLabs (from service.env) elevenlabs_api_key=svc.get('ELEVENLABS_API_KEY', ''), elevenlabs_voice_id=svc.get('ELEVENLABS_VOICE_ID', 'iP95p4xoKVk53GoZ742B'), elevenlabs_model=svc.get('ELEVENLABS_MODEL', 'eleven_multilingual_v2'), elevenlabs_stt_model=svc.get('ELEVENLABS_STT_MODEL', 'scribe_v1'), + # Speechmatics (from service.env) + speechmatics_api_key=svc.get('SPEECHMATICS_API_KEY', ''), + speechmatics_voice_id=svc.get('SPEECHMATICS_VOICE_ID', 'sarah'), + speechmatics_stt_region=svc.get('SPEECHMATICS_STT_REGION', 'eu1'), # Memory memory_enabled=svc.get('MEMORY_ENABLED', '1') == '1', db_file=os.path.join(bot_dir, 'history.db') if bot_dir else '', @@ -181,11 +199,17 @@ def from_env_files(cls, bot_id, claudio_path=None): # Common model=bot_env.get('MODEL', 'haiku'), max_history_lines=bot_env.get('MAX_HISTORY_LINES', '100'), + # Speech provider (from service.env) + speech_provider=svc.get('SPEECH_PROVIDER', 'elevenlabs'), # ElevenLabs (from service.env) elevenlabs_api_key=svc.get('ELEVENLABS_API_KEY', ''), elevenlabs_voice_id=svc.get('ELEVENLABS_VOICE_ID', 'iP95p4xoKVk53GoZ742B'), elevenlabs_model=svc.get('ELEVENLABS_MODEL', 'eleven_multilingual_v2'), elevenlabs_stt_model=svc.get('ELEVENLABS_STT_MODEL', 'scribe_v1'), + # Speechmatics (from service.env) + speechmatics_api_key=svc.get('SPEECHMATICS_API_KEY', ''), + speechmatics_voice_id=svc.get('SPEECHMATICS_VOICE_ID', 'sarah'), + speechmatics_stt_region=svc.get('SPEECHMATICS_STT_REGION', 'eu1'), # Memory memory_enabled=svc.get('MEMORY_ENABLED', '1') == '1', ) @@ -277,8 +301,12 @@ class ClaudioConfig: # Keys managed in service.env (global, not per-bot) _MANAGED_KEYS = 
[ 'PORT', 'WEBHOOK_URL', 'TUNNEL_NAME', 'TUNNEL_HOSTNAME', - 'WEBHOOK_RETRY_DELAY', 'ELEVENLABS_API_KEY', 'ELEVENLABS_VOICE_ID', - 'ELEVENLABS_MODEL', 'ELEVENLABS_STT_MODEL', 'MEMORY_ENABLED', + 'WEBHOOK_RETRY_DELAY', 'SPEECH_PROVIDER', + 'ELEVENLABS_API_KEY', 'ELEVENLABS_VOICE_ID', + 'ELEVENLABS_MODEL', 'ELEVENLABS_STT_MODEL', + 'SPEECHMATICS_API_KEY', 'SPEECHMATICS_VOICE_ID', + 'SPEECHMATICS_STT_REGION', + 'MEMORY_ENABLED', 'MEMORY_EMBEDDING_MODEL', 'MEMORY_CONSOLIDATION_MODEL', ] @@ -295,10 +323,14 @@ class ClaudioConfig: 'TUNNEL_NAME': '', 'TUNNEL_HOSTNAME': '', 'WEBHOOK_RETRY_DELAY': '60', + 'SPEECH_PROVIDER': 'elevenlabs', 'ELEVENLABS_API_KEY': '', 'ELEVENLABS_VOICE_ID': 'iP95p4xoKVk53GoZ742B', 'ELEVENLABS_MODEL': 'eleven_multilingual_v2', 'ELEVENLABS_STT_MODEL': 'scribe_v1', + 'SPEECHMATICS_API_KEY': '', + 'SPEECHMATICS_VOICE_ID': 'sarah', + 'SPEECHMATICS_STT_REGION': 'eu1', 'MEMORY_ENABLED': '1', 'MEMORY_EMBEDDING_MODEL': 'sentence-transformers/all-MiniLM-L6-v2', 'MEMORY_CONSOLIDATION_MODEL': 'haiku', diff --git a/lib/handlers.py b/lib/handlers.py index de2ed5d..39997c5 100644 --- a/lib/handlers.py +++ b/lib/handlers.py @@ -25,7 +25,8 @@ ) from lib.telegram_api import TelegramClient from lib.whatsapp_api import WhatsAppClient -from lib.elevenlabs import tts_convert, stt_transcribe +from lib.elevenlabs import tts_convert as elevenlabs_tts, stt_transcribe as elevenlabs_stt +from lib.speechmatics import tts_convert as speechmatics_tts, stt_transcribe as speechmatics_stt from lib.claude_runner import run_claude # -- Constants -- @@ -384,6 +385,46 @@ def _memory_consolidate(): pass +# -- Speech provider dispatch -- + +def _get_speech_api_key(config): + """Return the API key for the configured speech provider, or ''.""" + if config.speech_provider == 'speechmatics': + return config.speechmatics_api_key + return config.elevenlabs_api_key + + +def _stt_transcribe(audio_path, config): + """Transcribe audio using the configured speech provider.""" + if 
config.speech_provider == 'speechmatics': + return speechmatics_stt( + audio_path, + config.speechmatics_api_key, + region=config.speechmatics_stt_region, + ) + return elevenlabs_stt( + audio_path, + config.elevenlabs_api_key, + model=config.elevenlabs_stt_model, + ) + + +def _tts_convert(text, output_path, config): + """Convert text to speech using the configured speech provider.""" + if config.speech_provider == 'speechmatics': + return speechmatics_tts( + text, output_path, + config.speechmatics_api_key, + config.speechmatics_voice_id, + ) + return elevenlabs_tts( + text, output_path, + config.elevenlabs_api_key, + config.elevenlabs_voice_id, + config.elevenlabs_model, + ) + + # -- Main entry point -- def process_webhook(body, bot_id, platform, bot_config_dict): @@ -562,10 +603,14 @@ def _process_message(msg, text, config, client, platform, bot_id): # -- Voice transcription -- if msg.has_voice: - if not config.elevenlabs_api_key: + stt_api_key = _get_speech_api_key(config) + if not stt_api_key: + provider = config.speech_provider or 'elevenlabs' + key_name = ('SPEECHMATICS_API_KEY' if provider == 'speechmatics' + else 'ELEVENLABS_API_KEY') client.send_message( msg.chat_id, - f"_{voice_label.capitalize()} messages require ELEVENLABS_API_KEY " + f"_{voice_label.capitalize()} messages require {key_name} " f"to be configured._", reply_to=msg.message_id, ) @@ -593,11 +638,7 @@ def _process_message(msg, text, config, client, platform, bot_id): ) return - transcription = stt_transcribe( - voice_file, - config.elevenlabs_api_key, - model=config.elevenlabs_stt_model, - ) + transcription = _stt_transcribe(voice_file, config) if not transcription: client.send_message( msg.chat_id, @@ -743,7 +784,7 @@ def _typing_loop(): # -- Response delivery -- if response: - if has_voice and config.elevenlabs_api_key: + if has_voice and _get_speech_api_key(config): _deliver_voice_response( response, config, client, msg, platform, tmp_dir, tmp_files, bot_id, @@ -787,15 +828,15 @@ def 
_typing_loop(): def _deliver_voice_response(response, config, client, msg, platform, tmp_dir, tmp_files, bot_id): """Convert response to voice/audio and send, falling back to text.""" + tts_ext = '.wav' if config.speech_provider == 'speechmatics' else '.mp3' fd, tts_file = tempfile.mkstemp( - prefix='claudio-tts-', suffix='.mp3', dir=tmp_dir, + prefix='claudio-tts-', suffix=tts_ext, dir=tmp_dir, ) os.close(fd) os.chmod(tts_file, 0o600) tmp_files.append(tts_file) - if tts_convert(response, tts_file, config.elevenlabs_api_key, - config.elevenlabs_voice_id, config.elevenlabs_model): + if _tts_convert(response, tts_file, config): if platform == 'telegram': ok = client.send_voice(msg.chat_id, tts_file, reply_to=msg.message_id) else: diff --git a/lib/speechmatics.py b/lib/speechmatics.py new file mode 100644 index 0000000..fa12941 --- /dev/null +++ b/lib/speechmatics.py @@ -0,0 +1,324 @@ +"""Speechmatics TTS and STT integration for Claudio. + +Alternative speech provider to ElevenLabs. Stdlib only — no external dependencies. 
+ +TTS: https://preview.tts.speechmatics.com/generate/{voice_id} +STT: https://{region}.asr.api.speechmatics.com/v2/jobs/ (batch async) +""" + +import json +import os +import re +import time +import urllib.request +import urllib.error + +from lib.util import MultipartEncoder, log, log_error, strip_markdown + +# -- Constants -- + +TTS_MAX_CHARS = 5000 # Conservative limit +STT_MAX_SIZE = 20 * 1024 * 1024 # 20 MB + +TTS_API = "https://preview.tts.speechmatics.com/generate" +STT_API_TEMPLATE = "https://{region}.asr.api.speechmatics.com/v2" + +_VOICE_ID_RE = re.compile(r'^[a-z]{1,32}$') + +# WAV magic bytes: RIFF....WAVE +_WAV_MAGIC = b'RIFF' +_WAV_FORMAT = b'WAVE' + +# Polling config for batch STT +STT_POLL_INTERVAL = 2 # seconds between polls +STT_POLL_MAX_WAIT = 120 # max seconds to wait for job completion + + +def _validate_wav_magic(path): + """Validate that a file starts with WAV (RIFF/WAVE) magic bytes.""" + try: + with open(path, 'rb') as f: + header = f.read(12) + except OSError: + return False + + if len(header) < 12: + return False + + return header[:4] == _WAV_MAGIC and header[8:12] == _WAV_FORMAT + + +def tts_convert(text, output_path, api_key, voice_id='sarah'): + """Convert text to speech using Speechmatics TTS API. + + Args: + text: Text to convert to speech. + output_path: Path to write the WAV output file. + api_key: Speechmatics API key. + voice_id: Speechmatics voice ID (sarah, theo, megan, jack). + + Returns: + True on success, False on failure. 
+ """ + if not api_key: + log_error("tts", "api_key not provided") + return False + + if not voice_id: + log_error("tts", "voice_id not provided") + return False + + if not _VOICE_ID_RE.match(voice_id): + log_error("tts", "Invalid voice_id format") + return False + + # Strip markdown formatting for cleaner speech + text = strip_markdown(text) + + if not text or not text.strip(): + log_error("tts", "No text to convert after stripping markdown") + return False + + # Truncate if over limit + if len(text) > TTS_MAX_CHARS: + text = text[:TTS_MAX_CHARS] + log("tts", f"Text truncated to {TTS_MAX_CHARS} characters") + + url = f"{TTS_API}/{voice_id}?output_format=wav_16000" + payload = json.dumps({"text": text}).encode('utf-8') + + req = urllib.request.Request( + url, + data=payload, + method='POST', + headers={ + 'Authorization': f'Bearer {api_key}', + 'Content-Type': 'application/json', + }, + ) + + try: + with urllib.request.urlopen(req, timeout=120) as resp: + data = resp.read() + except urllib.error.HTTPError as e: + error_detail = f"HTTP {e.code}" + try: + raw = e.read(500).decode('utf-8', errors='replace') + parsed = json.loads(raw) + msg = parsed.get('detail', parsed.get('message', '')) + if isinstance(msg, str) and msg: + error_detail = f"HTTP {e.code}: {msg[:100]}" + except Exception: + pass + log_error("tts", f"Speechmatics TTS API error: {error_detail}") + _safe_delete(output_path) + return False + except (urllib.error.URLError, OSError) as e: + log_error("tts", f"Speechmatics TTS request failed: {type(e).__name__}") + _safe_delete(output_path) + return False + + # Write output file + try: + with open(output_path, 'wb') as f: + f.write(data) + except OSError as e: + log_error("tts", f"Failed to write output file: {e}") + return False + + # Validate output is actually WAV audio + if not _validate_wav_magic(output_path): + log_error("tts", "Speechmatics returned non-audio content") + _safe_delete(output_path) + return False + + file_size = 
os.path.getsize(output_path) + log("tts", f"Generated voice audio: {file_size} bytes") + return True + + +def stt_transcribe(audio_path, api_key, region='eu1', language='en'): + """Transcribe audio using Speechmatics batch STT API. + + Submits a transcription job, polls for completion, and retrieves + the plain-text transcript. + + Args: + audio_path: Path to the audio file to transcribe. + api_key: Speechmatics API key. + region: API region (eu1, us1, au1). + language: ISO language code for transcription. + + Returns: + Transcription text on success, None on failure. + """ + if not api_key: + log_error("stt", "api_key not provided") + return None + + if not os.path.isfile(audio_path): + log_error("stt", f"Audio file not found: {audio_path}") + return None + + try: + file_size = os.path.getsize(audio_path) + except OSError as e: + log_error("stt", f"Cannot stat audio file: {e}") + return None + + if file_size == 0: + log_error("stt", f"Audio file is empty: {audio_path}") + return None + + if file_size > STT_MAX_SIZE: + log_error("stt", + f"Audio file too large: {file_size} bytes " + f"(max {STT_MAX_SIZE})") + return None + + if not re.match(r'^[a-z]{2}[0-9]$', region): + log_error("stt", f"Invalid region format: {region}") + return None + + base_url = STT_API_TEMPLATE.format(region=region) + + # Step 1: Submit transcription job + job_id = _submit_job(audio_path, api_key, base_url, language) + if not job_id: + return None + + # Step 2: Poll for completion + if not _wait_for_job(job_id, api_key, base_url): + return None + + # Step 3: Get transcript + text = _get_transcript(job_id, api_key, base_url) + if not text: + log_error("stt", "Speechmatics STT returned empty transcription") + return None + + log("stt", + f"Transcribed {file_size} bytes of audio ({len(text)} chars)") + + return text + + +def _submit_job(audio_path, api_key, base_url, language): + """Submit a batch transcription job. 
Returns job_id or None.""" + config_json = json.dumps({ + "type": "transcription", + "transcription_config": {"language": language}, + }) + + enc = MultipartEncoder() + enc.add_file('data_file', audio_path) + enc.add_field('config', config_json) + body = enc.finish() + + url = f"{base_url}/jobs/" + req = urllib.request.Request( + url, + data=body, + method='POST', + headers={ + 'Authorization': f'Bearer {api_key}', + 'Content-Type': enc.content_type, + }, + ) + + try: + with urllib.request.urlopen(req, timeout=60) as resp: + resp_data = resp.read() + except urllib.error.HTTPError as e: + error_detail = f"HTTP {e.code}" + try: + raw = e.read(500).decode('utf-8', errors='replace') + parsed = json.loads(raw) + msg = parsed.get('detail', parsed.get('message', '')) + if isinstance(msg, str) and msg: + error_detail = f"HTTP {e.code}: {msg[:100]}" + except Exception: + pass + log_error("stt", f"Speechmatics job submit error: {error_detail}") + return None + except (urllib.error.URLError, OSError) as e: + log_error("stt", f"Speechmatics job submit failed: {type(e).__name__}") + return None + + try: + result = json.loads(resp_data) + except (json.JSONDecodeError, ValueError): + log_error("stt", "Failed to parse job submit response") + return None + + job_id = result.get('id', '') + if not job_id: + log_error("stt", "No job ID in submit response") + return None + + log("stt", f"Submitted transcription job: {job_id}") + return job_id + + +def _wait_for_job(job_id, api_key, base_url): + """Poll job status until done. 
Returns True on success, False on failure.""" + url = f"{base_url}/jobs/{job_id}" + deadline = time.monotonic() + STT_POLL_MAX_WAIT + + while time.monotonic() < deadline: + req = urllib.request.Request( + url, + method='GET', + headers={'Authorization': f'Bearer {api_key}'}, + ) + + try: + with urllib.request.urlopen(req, timeout=30) as resp: + data = json.loads(resp.read()) + except (urllib.error.HTTPError, urllib.error.URLError, OSError, + json.JSONDecodeError, ValueError) as e: + log_error("stt", f"Error polling job {job_id}: {e}") + return False + + status = data.get('job', {}).get('status', '') + + if status == 'done': + return True + if status in ('rejected', 'deleted', 'expired'): + log_error("stt", f"Job {job_id} failed with status: {status}") + return False + + time.sleep(STT_POLL_INTERVAL) + + log_error("stt", f"Job {job_id} timed out after {STT_POLL_MAX_WAIT}s") + return False + + +def _get_transcript(job_id, api_key, base_url): + """Retrieve plain-text transcript for a completed job.""" + url = f"{base_url}/jobs/{job_id}/transcript?format=txt" + req = urllib.request.Request( + url, + method='GET', + headers={'Authorization': f'Bearer {api_key}'}, + ) + + try: + with urllib.request.urlopen(req, timeout=30) as resp: + text = resp.read().decode('utf-8', errors='replace').strip() + except urllib.error.HTTPError as e: + log_error("stt", f"Speechmatics transcript fetch error: HTTP {e.code}") + return None + except (urllib.error.URLError, OSError) as e: + log_error("stt", f"Speechmatics transcript fetch failed: {type(e).__name__}") + return None + + return text + + +def _safe_delete(path): + """Delete a file, ignoring errors if it does not exist.""" + try: + os.unlink(path) + except OSError: + pass diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 720f5f6..9c8b97a 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -531,8 +531,8 @@ def setUp(self): self._add_patch("lib.handlers.run_claude", return_value=_mock_claude_result()) 
self._add_patch("lib.handlers._memory_retrieve", return_value='') self._add_patch("lib.handlers._memory_consolidate") - self._add_patch("lib.handlers.stt_transcribe", return_value="transcribed text") - self._add_patch("lib.handlers.tts_convert", return_value=True) + self._add_patch("lib.handlers._stt_transcribe", return_value="transcribed text") + self._add_patch("lib.handlers._tts_convert", return_value=True) # Start all patches for p in self.patches: @@ -672,8 +672,8 @@ def tearDown(self): @patch("lib.handlers._memory_consolidate") @patch("lib.handlers._memory_retrieve", return_value='') - @patch("lib.handlers.stt_transcribe", return_value="Hello from voice") - @patch("lib.handlers.tts_convert", return_value=True) + @patch("lib.handlers._stt_transcribe", return_value="Hello from voice") + @patch("lib.handlers._tts_convert", return_value=True) @patch("lib.handlers.run_claude", return_value=_mock_claude_result("Voice reply")) def test_voice_message_flow(self, mock_claude, mock_tts, mock_stt, mock_mem_r, mock_mem_c): @@ -828,7 +828,7 @@ def test_image_download_failure(self, mock_claude, mock_mem_r, mock_mem_c): @patch("lib.handlers._memory_consolidate") @patch("lib.handlers._memory_retrieve", return_value='') - @patch("lib.handlers.stt_transcribe", return_value='') + @patch("lib.handlers._stt_transcribe", return_value='') @patch("lib.handlers.run_claude") def test_voice_transcription_failure(self, mock_claude, mock_stt, mock_mem_r, mock_mem_c): @@ -860,8 +860,8 @@ def test_whatsapp_audio_flow(self, mock_claude, mock_mem_r, mock_mem_c): ) self.client.download_audio.return_value = True - with patch("lib.handlers.stt_transcribe", return_value="wa transcription"), \ - patch("lib.handlers.tts_convert", return_value=True): + with patch("lib.handlers._stt_transcribe", return_value="wa transcription"), \ + patch("lib.handlers._tts_convert", return_value=True): _process_message(msg, '', self.config, self.client, "whatsapp", "test-bot") diff --git a/tests/test_speechmatics.py 
b/tests/test_speechmatics.py new file mode 100644 index 0000000..edb4c1c --- /dev/null +++ b/tests/test_speechmatics.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python3 +"""Tests for lib/speechmatics.py — Speechmatics TTS and STT integration.""" + +import io +import json +import os +import sys +import urllib.error +import urllib.request +from unittest.mock import MagicMock, patch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from lib.speechmatics import ( + TTS_MAX_CHARS, + STT_MAX_SIZE, + STT_POLL_MAX_WAIT, + _validate_wav_magic, + tts_convert, + stt_transcribe, + _submit_job, + _wait_for_job, + _get_transcript, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _mock_response(body, code=200): + """Build a MagicMock that behaves like an HTTPResponse (context manager).""" + if isinstance(body, dict): + body = json.dumps(body).encode("utf-8") + elif isinstance(body, str): + body = body.encode("utf-8") + resp = MagicMock() + resp.read.return_value = body + resp.code = code + resp.__enter__ = MagicMock(return_value=resp) + resp.__exit__ = MagicMock(return_value=False) + return resp + + +def _mock_http_error(code, body=""): + fp = io.BytesIO(body.encode("utf-8") if isinstance(body, str) else body) + return urllib.error.HTTPError( + url="https://preview.tts.speechmatics.com/test", + code=code, + msg=f"HTTP {code}", + hdrs={}, + fp=fp, + ) + + +# Valid WAV data: RIFF....WAVE header +_VALID_WAV = b'RIFF' + b'\x00\x00\x00\x00' + b'WAVE' + b'\x00' * 200 + + +# --------------------------------------------------------------------------- +# tts_convert +# --------------------------------------------------------------------------- + +class TestTtsConvert: + + @patch("urllib.request.urlopen") + def test_success_writes_wav(self, mock_urlopen, tmp_path): + out = str(tmp_path / "output.wav") + mock_urlopen.return_value = 
_mock_response(_VALID_WAV) + + result = tts_convert("Hello world", out, api_key="key123", voice_id="sarah") + assert result is True + assert os.path.isfile(out) + with open(out, "rb") as f: + assert f.read() == _VALID_WAV + + @patch("urllib.request.urlopen") + def test_api_error_returns_false(self, mock_urlopen, tmp_path): + out = str(tmp_path / "output.wav") + mock_urlopen.side_effect = _mock_http_error(500, "server error") + + result = tts_convert("Hello", out, api_key="key123", voice_id="sarah") + assert result is False + assert not os.path.exists(out) + + @patch("urllib.request.urlopen") + def test_url_error_returns_false(self, mock_urlopen, tmp_path): + out = str(tmp_path / "output.wav") + mock_urlopen.side_effect = urllib.error.URLError("connection refused") + + result = tts_convert("Hello", out, api_key="key123", voice_id="sarah") + assert result is False + + def test_empty_text_after_markdown_stripping(self, tmp_path): + out = str(tmp_path / "output.wav") + result = tts_convert("```\ncode only\n```", out, api_key="key123", voice_id="sarah") + assert result is False + + @patch("urllib.request.urlopen") + def test_truncation_at_max_chars(self, mock_urlopen, tmp_path): + out = str(tmp_path / "output.wav") + mock_urlopen.return_value = _mock_response(_VALID_WAV) + long_text = "A" * 8000 + + result = tts_convert(long_text, out, api_key="key123", voice_id="sarah") + assert result is True + + req = mock_urlopen.call_args[0][0] + payload = json.loads(req.data.decode("utf-8")) + assert len(payload["text"]) == TTS_MAX_CHARS + + def test_invalid_voice_id_rejected(self, tmp_path): + out = str(tmp_path / "output.wav") + assert tts_convert("Hello", out, api_key="key", voice_id="Bad/ID!") is False + assert tts_convert("Hello", out, api_key="key", voice_id="") is False + assert tts_convert("Hello", out, api_key="key", voice_id="UPPER") is False + + def test_missing_api_key_rejected(self, tmp_path): + out = str(tmp_path / "output.wav") + assert tts_convert("Hello", out, 
api_key="", voice_id="sarah") is False + + @patch("urllib.request.urlopen") + def test_non_audio_response_rejected(self, mock_urlopen, tmp_path): + out = str(tmp_path / "output.wav") + mock_urlopen.return_value = _mock_response(b"Error") + + result = tts_convert("Hello", out, api_key="key123", voice_id="sarah") + assert result is False + assert not os.path.exists(out) + + @patch("urllib.request.urlopen") + def test_correct_api_url_and_headers(self, mock_urlopen, tmp_path): + out = str(tmp_path / "output.wav") + mock_urlopen.return_value = _mock_response(_VALID_WAV) + + tts_convert("Hello", out, api_key="mykey", voice_id="theo") + + req = mock_urlopen.call_args[0][0] + assert "theo" in req.full_url + assert "output_format=wav_16000" in req.full_url + assert req.get_header("Authorization") == "Bearer mykey" + assert req.get_header("Content-type") == "application/json" + + payload = json.loads(req.data.decode("utf-8")) + assert payload["text"] == "Hello" + + @patch("urllib.request.urlopen") + def test_valid_voice_ids(self, mock_urlopen, tmp_path): + """All documented voice IDs should be accepted.""" + out = str(tmp_path / "output.wav") + mock_urlopen.return_value = _mock_response(_VALID_WAV) + + for voice in ("sarah", "theo", "megan", "jack"): + result = tts_convert("Hello", out, api_key="key", voice_id=voice) + assert result is True + + +# --------------------------------------------------------------------------- +# stt_transcribe +# --------------------------------------------------------------------------- + +class TestSttTranscribe: + + @patch("lib.speechmatics.time.sleep") + @patch("urllib.request.urlopen") + def test_success_returns_text(self, mock_urlopen, mock_sleep, tmp_path): + audio = tmp_path / "audio.ogg" + audio.write_bytes(b"OggS" + b"\x00" * 100) + + # Sequence: submit job -> poll (done) -> get transcript + mock_urlopen.side_effect = [ + _mock_response({"id": "job123"}), + _mock_response({"job": {"status": "done"}}), + _mock_response("Hello world"), + 
] + + result = stt_transcribe(str(audio), api_key="key123") + assert result == "Hello world" + + @patch("urllib.request.urlopen") + def test_submit_error_returns_none(self, mock_urlopen, tmp_path): + audio = tmp_path / "audio.ogg" + audio.write_bytes(b"OggS" + b"\x00" * 100) + + mock_urlopen.side_effect = _mock_http_error(500, "server error") + + result = stt_transcribe(str(audio), api_key="key123") + assert result is None + + def test_empty_file_returns_none(self, tmp_path): + audio = tmp_path / "empty.ogg" + audio.write_bytes(b"") + assert stt_transcribe(str(audio), api_key="key123") is None + + def test_file_too_large_returns_none(self, tmp_path): + audio = tmp_path / "huge.ogg" + audio.write_bytes(b"\x00" * (STT_MAX_SIZE + 1)) + assert stt_transcribe(str(audio), api_key="key123") is None + + def test_missing_file_returns_none(self): + assert stt_transcribe("/nonexistent/audio.ogg", api_key="key123") is None + + def test_missing_api_key_returns_none(self, tmp_path): + audio = tmp_path / "audio.ogg" + audio.write_bytes(b"OggS" + b"\x00" * 100) + assert stt_transcribe(str(audio), api_key="") is None + + def test_invalid_region_returns_none(self, tmp_path): + audio = tmp_path / "audio.ogg" + audio.write_bytes(b"OggS" + b"\x00" * 100) + assert stt_transcribe(str(audio), api_key="key123", region="bad") is None + assert stt_transcribe(str(audio), api_key="key123", region="EU1") is None + + @patch("lib.speechmatics.time.sleep") + @patch("urllib.request.urlopen") + def test_empty_transcription_returns_none(self, mock_urlopen, mock_sleep, tmp_path): + audio = tmp_path / "audio.ogg" + audio.write_bytes(b"OggS" + b"\x00" * 100) + + mock_urlopen.side_effect = [ + _mock_response({"id": "job123"}), + _mock_response({"job": {"status": "done"}}), + _mock_response(""), + ] + + result = stt_transcribe(str(audio), api_key="key123") + assert result is None + + @patch("lib.speechmatics.time.sleep") + @patch("urllib.request.urlopen") + def test_polling_waits_for_done(self, 
mock_urlopen, mock_sleep, tmp_path): + """Job that is 'running' then 'done' should succeed.""" + audio = tmp_path / "audio.ogg" + audio.write_bytes(b"OggS" + b"\x00" * 100) + + mock_urlopen.side_effect = [ + _mock_response({"id": "job123"}), # submit + _mock_response({"job": {"status": "running"}}), # poll 1 + _mock_response({"job": {"status": "running"}}), # poll 2 + _mock_response({"job": {"status": "done"}}), # poll 3 + _mock_response("Transcribed text"), # transcript + ] + + result = stt_transcribe(str(audio), api_key="key123") + assert result == "Transcribed text" + assert mock_sleep.call_count == 2 # slept during 'running' polls + + @patch("lib.speechmatics.time.sleep") + @patch("urllib.request.urlopen") + def test_rejected_job_returns_none(self, mock_urlopen, mock_sleep, tmp_path): + audio = tmp_path / "audio.ogg" + audio.write_bytes(b"OggS" + b"\x00" * 100) + + mock_urlopen.side_effect = [ + _mock_response({"id": "job123"}), + _mock_response({"job": {"status": "rejected"}}), + ] + + result = stt_transcribe(str(audio), api_key="key123") + assert result is None + + @patch("urllib.request.urlopen") + def test_correct_multipart_request(self, mock_urlopen, tmp_path): + """Verify the multipart submit request carries the file and config.""" + audio = tmp_path / "audio.ogg" + audio.write_bytes(b"OggS" + b"\x00" * 50) + + # Just test the submit step — make it fail after so we don't need all steps + mock_urlopen.return_value = _mock_response({"id": "job123"}) + + # Manually call _submit_job to inspect the request + job_id = _submit_job(str(audio), "key123", + "https://eu1.asr.api.speechmatics.com/v2", "en") + assert job_id == "job123" + + req = mock_urlopen.call_args[0][0] + assert "/jobs/" in req.full_url + assert req.get_header("Authorization") == "Bearer key123" + assert "multipart/form-data" in req.get_header("Content-type") + body = req.data + assert b"OggS" in body + assert b"transcription" in body + + @patch("urllib.request.urlopen") + def 
test_url_error_returns_none(self, mock_urlopen, tmp_path): + audio = tmp_path / "audio.ogg" + audio.write_bytes(b"OggS" + b"\x00" * 100) + mock_urlopen.side_effect = urllib.error.URLError("network error") + assert stt_transcribe(str(audio), api_key="key123") is None + + +# --------------------------------------------------------------------------- +# _validate_wav_magic +# --------------------------------------------------------------------------- + +class TestValidateWavMagic: + + def test_valid_wav(self, tmp_path): + f = tmp_path / "test.wav" + f.write_bytes(_VALID_WAV) + assert _validate_wav_magic(str(f)) is True + + def test_mp3_rejected(self, tmp_path): + f = tmp_path / "test.mp3" + f.write_bytes(b"ID3\x04\x00" + b"\x00" * 100) + assert _validate_wav_magic(str(f)) is False + + def test_empty_file_rejected(self, tmp_path): + f = tmp_path / "empty.bin" + f.write_bytes(b"") + assert _validate_wav_magic(str(f)) is False + + def test_too_short_file_rejected(self, tmp_path): + f = tmp_path / "tiny.bin" + f.write_bytes(b"RIFF\x00\x00") + assert _validate_wav_magic(str(f)) is False + + def test_nonexistent_file_rejected(self): + assert _validate_wav_magic("/nonexistent/file.wav") is False + + def test_riff_without_wave_rejected(self, tmp_path): + """RIFF file that is not WAVE (e.g. 
AVI) should be rejected.""" + f = tmp_path / "test.avi" + f.write_bytes(b'RIFF' + b'\x00\x00\x00\x00' + b'AVI ' + b'\x00' * 100) + assert _validate_wav_magic(str(f)) is False + + def test_plain_text_rejected(self, tmp_path): + f = tmp_path / "text.txt" + f.write_bytes(b"This is just plain text, not audio.") + assert _validate_wav_magic(str(f)) is False + + +# --------------------------------------------------------------------------- +# _wait_for_job +# --------------------------------------------------------------------------- + +class TestWaitForJob: + + @patch("lib.speechmatics.time.sleep") + @patch("lib.speechmatics.time.monotonic") + @patch("urllib.request.urlopen") + def test_timeout(self, mock_urlopen, mock_mono, mock_sleep): + """Job that never completes should time out.""" + # Simulate time progressing past deadline + mock_mono.side_effect = [0, 0, STT_POLL_MAX_WAIT + 1] + mock_urlopen.return_value = _mock_response({"job": {"status": "running"}}) + + result = _wait_for_job("job123", "key", + "https://eu1.asr.api.speechmatics.com/v2") + assert result is False + + @patch("lib.speechmatics.time.sleep") + @patch("urllib.request.urlopen") + def test_poll_error_returns_false(self, mock_urlopen, mock_sleep): + mock_urlopen.side_effect = urllib.error.URLError("network error") + + result = _wait_for_job("job123", "key", + "https://eu1.asr.api.speechmatics.com/v2") + assert result is False + + +# --------------------------------------------------------------------------- +# _get_transcript +# --------------------------------------------------------------------------- + +class TestGetTranscript: + + @patch("urllib.request.urlopen") + def test_success(self, mock_urlopen): + mock_urlopen.return_value = _mock_response("Hello world") + result = _get_transcript("job123", "key", + "https://eu1.asr.api.speechmatics.com/v2") + assert result == "Hello world" + + req = mock_urlopen.call_args[0][0] + assert "format=txt" in req.full_url + + @patch("urllib.request.urlopen") + def 
test_http_error(self, mock_urlopen): + mock_urlopen.side_effect = _mock_http_error(404, "not found") + result = _get_transcript("job123", "key", + "https://eu1.asr.api.speechmatics.com/v2") + assert result is None