Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions decart/realtime/audio_stream_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""
Audio stream manager for live_avatar mode.

Mirrors the JS SDK's AudioStreamManager — ensures WebRTC always has
audio frames to send even when no user mic/audio is provided.
"""

import asyncio
import fractions
import io
import logging
from collections import deque
from pathlib import Path
from typing import Optional, Union

import av
from aiortc import MediaStreamTrack

logger = logging.getLogger(__name__)

SAMPLE_RATE = 48000
SAMPLES_PER_FRAME = 960 # 20ms at 48kHz
BYTES_PER_SAMPLE = 2 # s16 format
BYTES_PER_FRAME = SAMPLES_PER_FRAME * BYTES_PER_SAMPLE


def _make_silence_frame() -> av.AudioFrame:
frame = av.AudioFrame(samples=SAMPLES_PER_FRAME, layout="mono", format="s16")
for plane in frame.planes:
plane.update(bytes(BYTES_PER_FRAME))
return frame


class _AudioTrack(MediaStreamTrack):
kind = "audio"

def __init__(self) -> None:
super().__init__()
self._queue: deque[av.AudioFrame] = deque()
self._pts = 0
self._start: Optional[float] = None
self._done_event: Optional[asyncio.Event] = None

async def recv(self) -> av.AudioFrame:
if self._start is None:
self._start = asyncio.get_event_loop().time()

target = self._start + (self._pts / SAMPLE_RATE)
delay = target - asyncio.get_event_loop().time()
if delay > 0:
await asyncio.sleep(delay)

if self._queue:
frame = self._queue.popleft()
if not self._queue and self._done_event:
self._done_event.set()
self._done_event = None
else:
frame = _make_silence_frame()

frame.pts = self._pts
frame.sample_rate = SAMPLE_RATE
frame.time_base = fractions.Fraction(1, SAMPLE_RATE)
self._pts += SAMPLES_PER_FRAME

return frame

def enqueue(self, frames: list[av.AudioFrame], done: asyncio.Event) -> None:
self._queue.extend(frames)
self._done_event = done

def clear(self) -> None:
self._queue.clear()
if self._done_event:
self._done_event.set()
self._done_event = None


class AudioStreamManager:
"""Manages audio for live_avatar mode.

Provides a continuous audio track that outputs silence by default
and allows playing audio data through it via play_audio().
"""

def __init__(self) -> None:
self._track = _AudioTrack()
self._playing = False

def get_track(self) -> MediaStreamTrack:
return self._track

@property
def is_playing(self) -> bool:
return self._playing

async def play_audio(self, audio: Union[bytes, str, Path]) -> None:
"""Play audio through the stream. Resolves when audio finishes playing.

Args:
audio: Audio data as bytes, file path string, or Path object.
"""
if self._playing:
self.stop_audio()

if isinstance(audio, bytes):
container: av.InputContainer = av.open(io.BytesIO(audio)) # type: ignore[assignment]
else:
container: av.InputContainer = av.open(str(audio)) # type: ignore[assignment]

try:
resampler = av.AudioResampler(format="s16", layout="mono", rate=SAMPLE_RATE)
raw = bytearray()

for frame in container.decode(audio=0):
for resampled in resampler.resample(frame):
raw.extend(bytes(resampled.planes[0]))

for resampled in resampler.resample(None):
raw.extend(bytes(resampled.planes[0]))
finally:
container.close()

if not raw:
return

frames = []
for i in range(0, len(raw), BYTES_PER_FRAME):
chunk = raw[i : i + BYTES_PER_FRAME]
if len(chunk) < BYTES_PER_FRAME:
chunk.extend(bytes(BYTES_PER_FRAME - len(chunk)))

frame = av.AudioFrame(samples=SAMPLES_PER_FRAME, layout="mono", format="s16")
frame.planes[0].update(bytes(chunk))
frames.append(frame)

done = asyncio.Event()
self._playing = True
self._track.enqueue(frames, done)

await done.wait()
self._playing = False

def stop_audio(self) -> None:
self._track.clear()
self._playing = False

def cleanup(self) -> None:
self.stop_audio()
self._track.stop()
27 changes: 26 additions & 1 deletion decart/realtime/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aiortc import MediaStreamTrack
from pydantic import BaseModel

from .audio_stream_manager import AudioStreamManager
from .webrtc_manager import WebRTCManager, WebRTCConfiguration
from .messages import PromptMessage, SessionIdMessage, GenerationTickMessage
from .subscribe import (
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
self._manager = manager
self._http_session = http_session
self._model_name = model_name
self._audio_stream_manager: Optional[AudioStreamManager] = None
self._connection_callbacks: list[Callable[[ConnectionState], None]] = []
self._error_callbacks: list[Callable[[DecartSDKError], None]] = []
self._generation_tick_callbacks: list[Callable[[GenerationTickMessage], None]] = []
Expand Down Expand Up @@ -111,6 +113,13 @@ async def connect(

model_name: RealTimeModels = options.model.name # type: ignore[assignment]

is_avatar_live = model_name == "live_avatar"
audio_stream_manager: Optional[AudioStreamManager] = None

if is_avatar_live and local_track is None:
audio_stream_manager = AudioStreamManager()
local_track = audio_stream_manager.get_track()

config = WebRTCConfiguration(
webrtc_url=ws_url,
api_key=api_key,
Expand All @@ -126,7 +135,6 @@ async def connect(
model_name=model_name,
)

# Create HTTP session for file conversions
http_session = aiohttp.ClientSession()

manager = WebRTCManager(config)
Expand All @@ -135,6 +143,7 @@ async def connect(
http_session=http_session,
model_name=model_name,
)
client._audio_stream_manager = audio_stream_manager

config.on_connection_state_change = client._emit_connection_change
config.on_error = lambda error: client._emit_error(WebRTCError(str(error), cause=error))
Expand All @@ -159,6 +168,8 @@ async def connect(
initial_prompt=initial_prompt,
)
except Exception as e:
if audio_stream_manager:
audio_stream_manager.cleanup()
await manager.cleanup()
await http_session.close()
raise WebRTCError(str(e), cause=e)
Expand Down Expand Up @@ -319,6 +330,17 @@ async def set_prompt(
finally:
self._manager.unregister_prompt_wait(prompt)

async def play_audio(self, audio: Union[bytes, str, Path]) -> None:
"""Play audio through the avatar stream. Resolves when audio finishes.

Only available for live_avatar connections without a user-provided audio track.
"""
if self._audio_stream_manager is None:
raise InvalidInputError(
"play_audio() is only available for live_avatar without a user-provided audio track"
)
await self._audio_stream_manager.play_audio(audio)

async def set_image(
self,
image: Optional[FileInput],
Expand Down Expand Up @@ -349,6 +371,9 @@ def get_connection_state(self) -> ConnectionState:
async def disconnect(self) -> None:
self._buffering = False
self._buffer.clear()
if self._audio_stream_manager:
self._audio_stream_manager.cleanup()
self._audio_stream_manager = None
await self._manager.cleanup()
if self._http_session and not self._http_session.closed:
await self._http_session.close()
Expand Down
15 changes: 9 additions & 6 deletions examples/avatar_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ async def main():
if audio_file:
print(f"🔊 Audio file: {audio_file}")

# Load audio if provided
audio_track = None

if audio_file:
print("Loading audio file...")
player = MediaPlayer(audio_file)
Expand Down Expand Up @@ -131,12 +131,15 @@ def on_error(error):
print("✓ Connected!")
print(f"Session ID: {realtime_client.session_id}")

if audio_file:
print("\nPlaying audio through avatar...")
print("(The avatar will animate based on the audio)")
if audio_file and not audio_track:
print("\nPlaying audio via play_audio()...")
await realtime_client.play_audio(audio_file)
print("✓ Audio playback complete")
elif audio_file:
print("\nStreaming audio through avatar via MediaStreamTrack...")
else:
print("\nNo audio provided - avatar will be static")
print("You can update the avatar image dynamically using set_image()")
print("\nNo audio provided - avatar will be idle")
print("You can play audio dynamically using play_audio()")

print("\nPress Ctrl+C to stop and save the recording...")

Expand Down
Loading