diff --git a/decart/realtime/audio_stream_manager.py b/decart/realtime/audio_stream_manager.py
new file mode 100644
index 0000000..19d58e0
--- /dev/null
+++ b/decart/realtime/audio_stream_manager.py
@@ -0,0 +1,150 @@
+"""
+Audio stream manager for live_avatar mode.
+
+Mirrors the JS SDK's AudioStreamManager — ensures WebRTC always has
+audio frames to send even when no user mic/audio is provided.
+"""
+
+import asyncio
+import fractions
+import io
+import logging
+from collections import deque
+from pathlib import Path
+from typing import Optional, Union
+
+import av
+from aiortc import MediaStreamTrack
+
+logger = logging.getLogger(__name__)
+
+SAMPLE_RATE = 48000
+SAMPLES_PER_FRAME = 960  # 20ms at 48kHz
+BYTES_PER_SAMPLE = 2  # s16 format
+BYTES_PER_FRAME = SAMPLES_PER_FRAME * BYTES_PER_SAMPLE
+
+
+def _make_silence_frame() -> av.AudioFrame:
+    frame = av.AudioFrame(samples=SAMPLES_PER_FRAME, layout="mono", format="s16")
+    for plane in frame.planes:
+        plane.update(bytes(BYTES_PER_FRAME))
+    return frame
+
+
+class _AudioTrack(MediaStreamTrack):
+    kind = "audio"
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._queue: deque[av.AudioFrame] = deque()
+        self._pts = 0
+        self._start: Optional[float] = None
+        self._done_event: Optional[asyncio.Event] = None
+
+    async def recv(self) -> av.AudioFrame:
+        if self._start is None:
+            self._start = asyncio.get_event_loop().time()
+
+        target = self._start + (self._pts / SAMPLE_RATE)
+        delay = target - asyncio.get_event_loop().time()
+        if delay > 0:
+            await asyncio.sleep(delay)
+
+        if self._queue:
+            frame = self._queue.popleft()
+            if not self._queue and self._done_event:
+                self._done_event.set()
+                self._done_event = None
+        else:
+            frame = _make_silence_frame()
+
+        frame.pts = self._pts
+        frame.sample_rate = SAMPLE_RATE
+        frame.time_base = fractions.Fraction(1, SAMPLE_RATE)
+        self._pts += SAMPLES_PER_FRAME
+
+        return frame
+
+    def enqueue(self, frames: list[av.AudioFrame], done: asyncio.Event) -> None:
+        self._queue.extend(frames)
+        self._done_event = done
+
+    def clear(self) -> None:
+        self._queue.clear()
+        if self._done_event:
+            self._done_event.set()
+            self._done_event = None
+
+
+class AudioStreamManager:
+    """Manages audio for live_avatar mode.
+
+    Provides a continuous audio track that outputs silence by default
+    and allows playing audio data through it via play_audio().
+    """
+
+    def __init__(self) -> None:
+        self._track = _AudioTrack()
+        self._playing = False
+
+    def get_track(self) -> MediaStreamTrack:
+        return self._track
+
+    @property
+    def is_playing(self) -> bool:
+        return self._playing
+
+    async def play_audio(self, audio: Union[bytes, str, Path]) -> None:
+        """Play audio through the stream. Resolves when audio finishes playing.
+
+        Args:
+            audio: Audio data as bytes, file path string, or Path object.
+        """
+        if self._playing:
+            self.stop_audio()
+
+        if isinstance(audio, bytes):
+            container: av.InputContainer = av.open(io.BytesIO(audio))  # type: ignore[assignment]
+        else:
+            container: av.InputContainer = av.open(str(audio))  # type: ignore[assignment]
+
+        try:
+            resampler = av.AudioResampler(format="s16", layout="mono", rate=SAMPLE_RATE)
+            raw = bytearray()
+
+            for frame in container.decode(audio=0):
+                for resampled in resampler.resample(frame):
+                    raw.extend(bytes(resampled.planes[0]))
+
+            for resampled in resampler.resample(None):
+                raw.extend(bytes(resampled.planes[0]))
+        finally:
+            container.close()
+
+        if not raw:
+            return
+
+        frames = []
+        for i in range(0, len(raw), BYTES_PER_FRAME):
+            chunk = raw[i : i + BYTES_PER_FRAME]
+            if len(chunk) < BYTES_PER_FRAME:
+                chunk.extend(bytes(BYTES_PER_FRAME - len(chunk)))
+
+            frame = av.AudioFrame(samples=SAMPLES_PER_FRAME, layout="mono", format="s16")
+            frame.planes[0].update(bytes(chunk))
+            frames.append(frame)
+
+        done = asyncio.Event()
+        self._playing = True
+        self._track.enqueue(frames, done)
+
+        await done.wait()
+        self._playing = False
+
+    def stop_audio(self) -> None:
+        self._track.clear()
+        self._playing = False
+
+    def cleanup(self) -> None:
+        self.stop_audio()
+        self._track.stop()
diff --git a/decart/realtime/client.py b/decart/realtime/client.py
index 36876ab..3250c69 100644
--- a/decart/realtime/client.py
+++ b/decart/realtime/client.py
@@ -8,6 +8,7 @@
 from aiortc import MediaStreamTrack
 from pydantic import BaseModel
 
+from .audio_stream_manager import AudioStreamManager
 from .webrtc_manager import WebRTCManager, WebRTCConfiguration
 from .messages import PromptMessage, SessionIdMessage, GenerationTickMessage
 from .subscribe import (
@@ -75,6 +76,7 @@ def __init__(
         self._manager = manager
         self._http_session = http_session
         self._model_name = model_name
+        self._audio_stream_manager: Optional[AudioStreamManager] = None
         self._connection_callbacks: list[Callable[[ConnectionState], None]] = []
         self._error_callbacks: list[Callable[[DecartSDKError], None]] = []
         self._generation_tick_callbacks: list[Callable[[GenerationTickMessage], None]] = []
@@ -111,6 +113,13 @@ async def connect(
 
         model_name: RealTimeModels = options.model.name  # type: ignore[assignment]
 
+        is_avatar_live = model_name == "live_avatar"
+        audio_stream_manager: Optional[AudioStreamManager] = None
+
+        if is_avatar_live and local_track is None:
+            audio_stream_manager = AudioStreamManager()
+            local_track = audio_stream_manager.get_track()
+
         config = WebRTCConfiguration(
             webrtc_url=ws_url,
             api_key=api_key,
@@ -126,7 +135,6 @@
             model_name=model_name,
         )
 
-        # Create HTTP session for file conversions
         http_session = aiohttp.ClientSession()
         manager = WebRTCManager(config)
 
@@ -135,6 +143,7 @@
             http_session=http_session,
             model_name=model_name,
         )
+        client._audio_stream_manager = audio_stream_manager
 
         config.on_connection_state_change = client._emit_connection_change
         config.on_error = lambda error: client._emit_error(WebRTCError(str(error), cause=error))
@@ -159,6 +168,8 @@
                 initial_prompt=initial_prompt,
             )
         except Exception as e:
+            if audio_stream_manager:
+                audio_stream_manager.cleanup()
             await manager.cleanup()
             await http_session.close()
             raise WebRTCError(str(e), cause=e)
@@ -319,6 +330,17 @@
         finally:
             self._manager.unregister_prompt_wait(prompt)
 
+    async def play_audio(self, audio: Union[bytes, str, Path]) -> None:
+        """Play audio through the avatar stream. Resolves when audio finishes.
+
+        Only available for live_avatar connections without a user-provided audio track.
+        """
+        if self._audio_stream_manager is None:
+            raise InvalidInputError(
+                "play_audio() is only available for live_avatar without a user-provided audio track"
+            )
+        await self._audio_stream_manager.play_audio(audio)
+
     async def set_image(
         self,
         image: Optional[FileInput],
@@ -349,6 +371,9 @@ def get_connection_state(self) -> ConnectionState:
     async def disconnect(self) -> None:
         self._buffering = False
         self._buffer.clear()
+        if self._audio_stream_manager:
+            self._audio_stream_manager.cleanup()
+            self._audio_stream_manager = None
         await self._manager.cleanup()
         if self._http_session and not self._http_session.closed:
             await self._http_session.close()
diff --git a/examples/avatar_live.py b/examples/avatar_live.py
index 94d4cb4..55c3dd2 100644
--- a/examples/avatar_live.py
+++ b/examples/avatar_live.py
@@ -63,8 +63,8 @@ async def main():
     if audio_file:
         print(f"🔊 Audio file: {audio_file}")
 
-    # Load audio if provided
     audio_track = None
+
    if audio_file:
         print("Loading audio file...")
         player = MediaPlayer(audio_file)
@@ -131,12 +131,15 @@ def on_error(error):
 
     print("✓ Connected!")
     print(f"Session ID: {realtime_client.session_id}")
 
-    if audio_file:
-        print("\nPlaying audio through avatar...")
-        print("(The avatar will animate based on the audio)")
+    if audio_file and not audio_track:
+        print("\nPlaying audio via play_audio()...")
+        await realtime_client.play_audio(audio_file)
+        print("✓ Audio playback complete")
+    elif audio_file:
+        print("\nStreaming audio through avatar via MediaStreamTrack...")
     else:
-        print("\nNo audio provided - avatar will be static")
-        print("You can update the avatar image dynamically using set_image()")
+        print("\nNo audio provided - avatar will be idle")
+        print("You can play audio dynamically using play_audio()")
 
     print("\nPress Ctrl+C to stop and save the recording...")