Skip to content

Commit 01fd69e

Browse files
committed
feat(realtime): add realtime audio transcription support
This commit adds support for realtime audio transcription using WebSocket connections. The implementation includes:

1. A new realtime transcription client in the extra module
2. Examples for microphone and file-based transcription
3. Support for audio format negotiation
4. Proper error handling and connection management

The realtime transcription feature requires the websockets package (>=13.0), which is now added as an optional dependency. This implementation allows streaming audio data to the Mistral API and receiving transcription results in realtime. The changes include new models for realtime events and connection management, as well as an updated audio.py that exposes the realtime functionality.
1 parent a7783e3 commit 01fd69e

16 files changed

+1134
-30
lines changed
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
#!/usr/bin/env python
2+
# /// script
3+
# requires-python = ">=3.9"
4+
# dependencies = [
5+
# "mistralai[realtime]",
6+
# "pyaudio",
7+
# "rich",
8+
# ]
9+
# [tool.uv.sources]
10+
# mistralai = { path = "../../..", editable = true }
11+
# ///
12+
13+
import argparse
import asyncio
import os
import sys
from typing import AsyncIterator, Optional

from rich.align import Align
from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from rich.panel import Panel
from rich.text import Text

from mistralai import Mistral
from mistralai.extra.realtime import UnknownRealtimeEvent
from mistralai.models import (
    AudioFormat,
    RealtimeTranscriptionError,
    RealtimeTranscriptionSessionCreated,
    TranscriptionStreamDone,
    TranscriptionStreamTextDelta,
)
36+
console = Console()
37+
38+
39+
class TranscriptDisplay:
40+
"""Manages the live transcript display."""
41+
42+
def __init__(self, model: str) -> None:
43+
self.model = model
44+
self.transcript = ""
45+
self.status = "🔌 Connecting..."
46+
self.error: str | None = None
47+
48+
def set_listening(self) -> None:
49+
self.status = "🎤 Listening..."
50+
51+
def add_text(self, text: str) -> None:
52+
self.transcript += text
53+
54+
def set_done(self) -> None:
55+
self.status = "✅ Done"
56+
57+
def set_error(self, error: str) -> None:
58+
self.status = "❌ Error"
59+
self.error = error
60+
61+
def render(self) -> Layout:
62+
layout = Layout()
63+
64+
# Create minimal header
65+
header_text = Text()
66+
header_text.append("│ ", style="dim")
67+
header_text.append(self.model, style="dim")
68+
header_text.append(" │ ", style="dim")
69+
70+
if "Listening" in self.status:
71+
status_style = "green"
72+
elif "Connecting" in self.status:
73+
status_style = "yellow dim"
74+
elif "Done" in self.status or "Stopped" in self.status:
75+
status_style = "dim"
76+
else:
77+
status_style = "red"
78+
header_text.append(self.status, style=status_style)
79+
80+
header = Align.left(header_text, vertical="middle", pad=False)
81+
82+
# Create main transcript area - no title, minimal border
83+
transcript_text = Text(
84+
self.transcript or "...", style="white" if self.transcript else "dim"
85+
)
86+
transcript = Panel(
87+
Align.left(transcript_text, vertical="top"),
88+
border_style="dim",
89+
padding=(1, 2),
90+
)
91+
92+
# Minimal footer
93+
footer_text = Text()
94+
footer_text.append("ctrl+c", style="dim")
95+
footer_text.append(" quit", style="dim italic")
96+
footer = Align.left(footer_text, vertical="middle", pad=False)
97+
98+
# Handle error display
99+
if self.error:
100+
layout.split_column(
101+
Layout(header, name="header", size=1),
102+
Layout(transcript, name="body"),
103+
Layout(
104+
Panel(Text(self.error, style="red"), border_style="red"),
105+
name="error",
106+
size=4,
107+
),
108+
Layout(footer, name="footer", size=1),
109+
)
110+
else:
111+
layout.split_column(
112+
Layout(header, name="header", size=1),
113+
Layout(transcript, name="body"),
114+
Layout(footer, name="footer", size=1),
115+
)
116+
117+
return layout
118+
119+
120+
async def iter_microphone(
121+
*,
122+
sample_rate: int,
123+
chunk_duration_ms: int,
124+
) -> AsyncIterator[bytes]:
125+
"""
126+
Yield microphone PCM chunks using PyAudio (16-bit mono).
127+
Encoding is always pcm_s16le.
128+
"""
129+
import pyaudio
130+
131+
p = pyaudio.PyAudio()
132+
chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
133+
134+
stream = p.open(
135+
format=pyaudio.paInt16,
136+
channels=1,
137+
rate=sample_rate,
138+
input=True,
139+
frames_per_buffer=chunk_samples,
140+
)
141+
142+
loop = asyncio.get_running_loop()
143+
try:
144+
while True:
145+
# stream.read is blocking; run it off-thread
146+
data = await loop.run_in_executor(None, stream.read, chunk_samples, False)
147+
yield data
148+
finally:
149+
stream.stop_stream()
150+
stream.close()
151+
p.terminate()
152+
153+
154+
def parse_args() -> argparse.Namespace:
155+
parser = argparse.ArgumentParser(description="Real-time microphone transcription.")
156+
parser.add_argument("--model", default="voxtral-mini-transcribe-realtime-2602", help="Model ID")
157+
parser.add_argument(
158+
"--sample-rate",
159+
type=int,
160+
default=16000,
161+
choices=[8000, 16000, 22050, 44100, 48000],
162+
help="Sample rate in Hz",
163+
)
164+
parser.add_argument(
165+
"--chunk-duration", type=int, default=10, help="Chunk duration in ms"
166+
)
167+
parser.add_argument(
168+
"--api-key", default=os.environ.get("MISTRAL_API_KEY"), help="Mistral API key"
169+
)
170+
parser.add_argument(
171+
"--base-url",
172+
default=os.environ.get("MISTRAL_BASE_URL", "wss://api.mistral.ai"),
173+
)
174+
return parser.parse_args()
175+
176+
177+
async def main() -> int:
178+
args = parse_args()
179+
api_key = args.api_key or os.environ["MISTRAL_API_KEY"]
180+
181+
client = Mistral(api_key=api_key, server_url=args.base_url)
182+
183+
# microphone is always pcm_s16le here
184+
audio_format = AudioFormat(encoding="pcm_s16le", sample_rate=args.sample_rate)
185+
186+
mic_stream = iter_microphone(
187+
sample_rate=args.sample_rate, chunk_duration_ms=args.chunk_duration
188+
)
189+
190+
display = TranscriptDisplay(model=args.model)
191+
192+
with Live(
193+
display.render(), console=console, refresh_per_second=10, screen=True
194+
) as live:
195+
try:
196+
async for event in client.audio.realtime.transcribe_stream(
197+
audio_stream=mic_stream,
198+
model=args.model,
199+
audio_format=audio_format,
200+
):
201+
if isinstance(event, RealtimeTranscriptionSessionCreated):
202+
display.set_listening()
203+
live.update(display.render())
204+
elif isinstance(event, TranscriptionStreamTextDelta):
205+
display.add_text(event.text)
206+
live.update(display.render())
207+
elif isinstance(event, TranscriptionStreamDone):
208+
display.set_done()
209+
live.update(display.render())
210+
break
211+
elif isinstance(event, RealtimeTranscriptionError):
212+
display.set_error(str(event.error))
213+
live.update(display.render())
214+
return 1
215+
elif isinstance(event, UnknownRealtimeEvent):
216+
continue
217+
except KeyboardInterrupt:
218+
display.status = "⏹️ Stopped"
219+
live.update(display.render())
220+
221+
return 0
222+
223+
224+
if __name__ == "__main__":
225+
sys.exit(asyncio.run(main()))
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
#!/usr/bin/env python
2+
3+
import argparse
4+
import asyncio
5+
import os
6+
import subprocess
7+
import sys
8+
import tempfile
9+
from pathlib import Path
10+
from typing import AsyncIterator
11+
12+
from mistralai import Mistral
13+
from mistralai.extra.realtime.connection import UnknownRealtimeEvent
14+
from mistralai.models import (
15+
AudioFormat,
16+
RealtimeTranscriptionError,
17+
TranscriptionStreamDone,
18+
TranscriptionStreamTextDelta,
19+
)
20+
21+
22+
def convert_audio_to_pcm(
23+
input_path: Path,
24+
) -> Path:
25+
temp_file = tempfile.NamedTemporaryFile(suffix=".pcm", delete=False)
26+
temp_path = Path(temp_file.name)
27+
temp_file.close()
28+
29+
cmd = [
30+
"ffmpeg",
31+
"-y",
32+
"-i",
33+
str(input_path),
34+
"-f",
35+
"s16le",
36+
"-ar",
37+
str(16000),
38+
"-ac",
39+
"1",
40+
str(temp_path),
41+
]
42+
43+
try:
44+
subprocess.run(cmd, check=True, capture_output=True, text=True)
45+
except subprocess.CalledProcessError as exc:
46+
temp_path.unlink(missing_ok=True)
47+
raise RuntimeError(f"ffmpeg conversion failed: {exc.stderr}") from exc
48+
49+
return temp_path
50+
51+
52+
async def aiter_audio_file(
53+
path: Path,
54+
*,
55+
chunk_size: int = 4096,
56+
chunk_delay: float = 0.0,
57+
) -> AsyncIterator[bytes]:
58+
with open(path, "rb") as f:
59+
while True:
60+
chunk = f.read(chunk_size)
61+
if not chunk:
62+
break
63+
yield chunk
64+
if chunk_delay > 0:
65+
await asyncio.sleep(chunk_delay)
66+
67+
68+
def parse_args() -> argparse.Namespace:
69+
parser = argparse.ArgumentParser(
70+
description="Real-time audio transcription via WebSocket (iterator-based)."
71+
)
72+
parser.add_argument("file", type=Path, help="Path to the audio file")
73+
parser.add_argument("--model", default="voxtral-mini-2601", help="Model ID")
74+
parser.add_argument(
75+
"--api-key",
76+
default=os.environ.get("MISTRAL_API_KEY"),
77+
help="Mistral API key",
78+
)
79+
parser.add_argument(
80+
"--base-url",
81+
default=os.environ.get("MISTRAL_BASE_URL", "https://api.mistral.ai"),
82+
help="API base URL (http/https/ws/wss)",
83+
)
84+
parser.add_argument(
85+
"--chunk-size", type=int, default=4096, help="Audio chunk size in bytes"
86+
)
87+
parser.add_argument(
88+
"--chunk-delay",
89+
type=float,
90+
default=0.01,
91+
help="Delay between chunks in seconds",
92+
)
93+
parser.add_argument(
94+
"--no-convert",
95+
action="store_true",
96+
help="Skip ffmpeg conversion (input must be raw PCM)",
97+
)
98+
return parser.parse_args()
99+
100+
101+
async def main() -> int:
102+
args = parse_args()
103+
api_key = args.api_key or os.environ["MISTRAL_API_KEY"]
104+
105+
pcm_path = args.file
106+
temp_path = None
107+
108+
if not args.no_convert and args.file.suffix.lower() not in (".pcm", ".raw"):
109+
pcm_path = convert_audio_to_pcm(args.file)
110+
temp_path = pcm_path
111+
112+
client = Mistral(api_key=api_key, server_url=args.base_url)
113+
114+
try:
115+
async for event in client.audio.realtime.transcribe_stream(
116+
audio_stream=aiter_audio_file(
117+
pcm_path,
118+
chunk_size=args.chunk_size,
119+
chunk_delay=args.chunk_delay,
120+
),
121+
model=args.model,
122+
audio_format=AudioFormat(encoding="pcm_s16le", sample_rate=16000),
123+
):
124+
if isinstance(event, TranscriptionStreamTextDelta):
125+
print(event.text, end="", flush=True)
126+
elif isinstance(event, TranscriptionStreamDone):
127+
print()
128+
break
129+
elif isinstance(event, RealtimeTranscriptionError):
130+
print(f"\nError: {event.error}", file=sys.stderr)
131+
break
132+
elif isinstance(event, UnknownRealtimeEvent):
133+
# ignore future / unknown events; keep going
134+
continue
135+
136+
finally:
137+
if temp_path is not None:
138+
temp_path.unlink(missing_ok=True)
139+
140+
return 0
141+
142+
143+
if __name__ == "__main__":
144+
sys.exit(asyncio.run(main()))

0 commit comments

Comments
 (0)