Skip to content
16 changes: 16 additions & 0 deletions config/canary_streamatt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
type: "simulstream.server.speech_processors.canary_streamatt.CanaryStreamAtt"
Comment thread
azziko marked this conversation as resolved.
model_name: "nvidia/canary-1b-v2"
text_history:
type: "simulstream.server.speech_processors.base_streamatt.FixedWordsTextHistory"
history_words: 10
speech_chunk_size: 0.960 # seconds
detokenizer_type: "canary"
cross_attn_layer: -2
cutoff_frame_num: 8
num_beams: 5
audio_subsampling_factor: 8
audio_history_max_duration: 160 # Maximum length for the audio buffer, in seconds
mel_hop_samples: 160 # Number of audio samples between adjacent mel frames
Comment thread
mgaido91 marked this conversation as resolved.
text_history_max_len: 128
word_level_postprocess: True # Disable if character-level language
use_raw_audio_history: True
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ hf = [

canary = [
"Cython",
"nemo_toolkit[asr]==2.4.0",
"nemo_toolkit[asr]==2.8.0",
]

vad = [
Expand Down
13 changes: 11 additions & 2 deletions simulstream/server/speech_processors/base_streamatt.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ class BaseStreamAtt(BaseSpeechProcessor):
context for next predictions.
- **audio_subsampling_factor (int)**: Subsampling factor of the model, if any.
Defaults to 1.
- **mel_hop_samples (int)**: Number of raw waveform samples per mel frame.
Defaults to 160, i.e. 10ms at 16kHz.
- **use_raw_audio_history (bool)**: Returns whether ``audio_history`` stores raw
waveform samples rather than processed frames. Defaults to False.
- **text_history_max_len (int)**: The maximum length of the textual history after which
the current content is cut. Defaults to 128.
- **cross_attention_layer (int)**: Layer from which to extract the cross-attention from.
Expand All @@ -77,6 +81,11 @@ def __init__(self, config: SimpleNamespace):
text_history_cls = class_load(text_history_config.type)
self.text_history_method = text_history_cls(text_history_config)
self.audio_subsampling_factor = getattr(self.config, "audio_subsampling_factor", 1)
self.mel_hop_samples = getattr(self.config, "mel_hop_samples", 160)
self.use_raw_audio_history = getattr(self.config, "use_raw_audio_history", False)
self.frames_to_audio_history = self.audio_subsampling_factor
if self.use_raw_audio_history:
self.frames_to_audio_history *= self.mel_hop_samples
self.text_history_max_len = getattr(self.config, "text_history_max_len", 128)
self.cross_attn_layer = getattr(self.config, "cross_attention_layer", 3)
self.cutoff_frame_num = getattr(self.config, "cutoff_frame_num", 2)
Expand Down Expand Up @@ -173,8 +182,8 @@ def _update_speech_history(self, discarded_text: int, cross_attn: torch.Tensor)
# Only one token: use the unique most attended frame
earliest_attended_idx = most_attended_idxs[0]

# Multiply by the subsampling factor to recover the original number of frames
frames_to_cut = earliest_attended_idx * self.audio_subsampling_factor
# Multiply by the number of frames/samples corresponding to the audio history
frames_to_cut = earliest_attended_idx * self.frames_to_audio_history

# Cut the unattended audio features
self.audio_history = self.audio_history[frames_to_cut:]
Expand Down
158 changes: 158 additions & 0 deletions simulstream/server/speech_processors/canary_streamatt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright 2025 FBK

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import torch
import numpy as np

from types import SimpleNamespace
from typing import List, Tuple

import copy

from simulstream.server.speech_processors import SAMPLE_RATE
from simulstream.server.speech_processors.base_streamatt import BaseStreamAtt

from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.submodules.multitask_decoding import (
MultiTaskDecodingConfig,
)
from nemo.collections.asr.models.aed_multitask_models import (
MultiTaskTranscriptionConfig,
)


class CanaryStreamAtt(BaseStreamAtt):
"""
StreamAtt policy implementation for NVIDIA's Canary-v2 model.

Args:
config (SimpleNamespace): Configuration object.
Comment thread
azziko marked this conversation as resolved.
Supported attributes:
- **audio_history_max_duration (int)**: Maximum audio history in seconds.
Defaults to ``30``.
- **num_beams (int)**: Number of beams to use for beam search decoding.
Defaults to ``5``.
"""

def __init__(self, config: SimpleNamespace):
super().__init__(config)
self._audio_history_max_duration = getattr(self.config, "audio_history_max_duration", 30)

expected_mel_hop_samples = (
self.model.cfg.preprocessor.window_stride * self.model.cfg.preprocessor.sample_rate
)

assert self.mel_hop_samples == expected_mel_hop_samples, (
f"mel_hop_samples is set to {self.mel_hop_samples} in the config, but the loaded "
f"model's preprocessor uses {expected_mel_hop_samples} samples per mel frame"
)

# Build the transcription config, which is reused for every transcribe() call.
self.transcription_cfg = MultiTaskTranscriptionConfig(
batch_size=1,
return_hypotheses=True,
enable_chunking=False,
verbose=False,
)

@property
def audio_max_len(self) -> int:
"""Maximum audio history length in raw waveform samples."""
return self._audio_history_max_duration * SAMPLE_RATE

def set_source_language(self, language: str) -> None:
self.src_lang = language

def set_target_language(self, language: str) -> None:
self.tgt_lang = language

@classmethod
def load_model(cls, config: SimpleNamespace):
if not hasattr(cls, "model") or cls.model is None:
cls.model = ASRModel.from_pretrained(model_name=config.model_name)

# Configure decoding strategy
multitask_decoding = MultiTaskDecodingConfig()
multitask_decoding.strategy = "beam"
multitask_decoding.return_xattn_scores = True
multitask_decoding.beam.beam_size = getattr(config, "num_beams", 5)
cls.model.change_decoding_strategy(multitask_decoding)

cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert cls.model.cfg.preprocessor.sample_rate == SAMPLE_RATE
cls.model.to(cls.device)

def _build_transcription_config(self):
"""
Return a ``MultiTaskTranscriptionConfig`` whose prompt encodes the current source/target
languages, task, PNC preference, and forced decoder prefix.
"""

default_turns = self.model.prompt.get_default_dialog_slots()
default_slots = copy.deepcopy(default_turns[0]["slots"])
default_slots["source_lang"] = self.src_lang
default_slots["target_lang"] = self.tgt_lang

turns = [
{
"role": "user", "slots": default_slots
},
{
"role": "user_prefix",
"slots": {
"prefix": self.model.tokenizer.tokens_to_text(self.text_history)
Comment thread
mgaido91 marked this conversation as resolved.
},
},
]

cfg_copy = copy.deepcopy(self.transcription_cfg)
cfg_copy.prompt = turns

return cfg_copy

def _preprocess(self, waveform: np.ndarray) -> np.ndarray:
"""
Append the incoming waveform chunk to the raw audio history and return it.

Returns:
np.ndarray: Accumulated raw audio history.
"""
waveform = waveform.astype(np.float32)
if self.audio_history is None:
self.audio_history = waveform
else:
self.audio_history = np.concatenate(
[self.audio_history, waveform])

return self.audio_history

def _generate(self, speech: np.ndarray) -> Tuple[List[str], torch.Tensor]:
override_config = self._build_transcription_config()

with torch.inference_mode():
output = self.model.transcribe(audio=speech, override_config=override_config)

hypothesis = output[0]

token_ids = hypothesis.y_sequence.detach().cpu().tolist()
tokens = self.model.tokenizer.ids_to_tokens(token_ids)

xatt_raw = hypothesis.xatt_scores[self.cross_attn_layer]
xatt = xatt_raw.mean(dim=0).cpu() # we average over heads
xatt = self.normalize_attn(xatt)

return tokens, xatt

def tokens_to_string(self, tokens: List[str]) -> str:
return self.model.tokenizer.tokens_to_text(tokens)
68 changes: 67 additions & 1 deletion uts/speech_processors/test_streamatt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,14 @@

import unittest
from types import SimpleNamespace
import torch
import numpy as np
from typing import Dict, List, Tuple, Union

from simulstream.server.speech_processors.base_streamatt import PunctuationTextHistory
from simulstream.server.speech_processors.base_streamatt import (
BaseStreamAtt,
PunctuationTextHistory,
)


class TestPunctuationTextHistory(unittest.TestCase):
Expand Down Expand Up @@ -60,5 +66,65 @@ def test_no_strong_punctuation(self):
self.assertEqual(selected_history, ['回', '到', '纽', '约', '后', ',', '我'])


class FakeStreamAtt(BaseStreamAtt):

def _preprocess(self, waveform: np.float32) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
raise NotImplementedError("_preprocess not implemented in FakeStreamAtt")

@classmethod
def load_model(cls, config: SimpleNamespace):
raise NotImplementedError("load_model not implemented in FakeStreamAtt")

def set_source_language(self, language: str) -> None:
pass

def set_target_language(self, language: str) -> None:
pass

def tokens_to_string(self, tokens: List[str]) -> str:
return " ".join(tokens)

def _generate(self, speech: torch.Tensor) -> Tuple[List[str], torch.Tensor]:
raise NotImplementedError("_generate not implemented in FakeStreamAtt")

@property
def audio_max_len(self) -> float:
return 10000


class TestUpdateSpeechHistory(unittest.TestCase):
def _run_update_speech_history(self, use_raw_audio_history):
config = SimpleNamespace(
use_raw_audio_history=use_raw_audio_history,
audio_subsampling_factor=2,
mel_hop_samples=2,
text_history=SimpleNamespace(
type="simulstream.server.speech_processors.base_streamatt.FixedWordsTextHistory",
)

)
audio = np.arange(40, dtype=np.float32)
proc = FakeStreamAtt(config)
proc.text_history = ["▁hello"]
proc.audio_history = audio.copy()

attn = torch.zeros(2, 10)
attn[1, 2] = 1.0

proc._update_speech_history(discarded_text=1, cross_attn=attn)
return proc.audio_history.tolist()

def test_update_speech_history_trims_audio_with_raw_audio(self):
audio_hist = self._run_update_speech_history(use_raw_audio_history=True)
# 2 audio token discarded, subsampling factor is 2,
# num mel hop is 2, so 2*2*2=8 samples removed
self.assertListEqual(audio_hist, list(np.arange(8, 40, dtype=np.float32)))

def test_update_speech_history_trims_audio(self):
audio_hist = self._run_update_speech_history(use_raw_audio_history=False)
# 2 audio token discarded, subsampling factor is 2, so 2*2=4 samples removed
self.assertListEqual(audio_hist, list(np.arange(4, 40, dtype=np.float32)))


if __name__ == "__main__":
unittest.main()
Loading