diff --git a/Cargo.toml b/Cargo.toml index 542b75a..6fbfdd3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ gstreamer-base = { version = "0.20", default-features = false } once_cell = { version = "1", default-features = false, features = ["std"] } ringbuf = { version = "0.3", default-features = false, features = ["std"] } webrtc-vad = { version = "0.4", default-features = false } -whisper-rs = { version = "0.8", default-features = false } +whisper-rs = { version = "0.11.1", default-features = false } [dev-dependencies] gstreamer-check = { version = "0.20", default-features = false } diff --git a/src/filter/imp.rs b/src/filter/imp.rs index 02eff96..2b32b99 100644 --- a/src/filter/imp.rs +++ b/src/filter/imp.rs @@ -32,7 +32,7 @@ use gstreamer_base::{ use once_cell::sync::Lazy; use webrtc_vad::{Vad, VadMode}; use whisper_rs::{ - convert_integer_to_float_audio, FullParams, SamplingStrategy, WhisperContext, WhisperState, + convert_integer_to_float_audio, FullParams, SamplingStrategy, WhisperContext, WhisperState, WhisperContextParameters, }; const SAMPLE_RATE: usize = 16_000; @@ -42,10 +42,11 @@ const DEFAULT_MIN_VOICE_ACTIVITY_MS: u64 = 200; const DEFAULT_LANGUAGE: &str = "en"; const DEFAULT_TRANSLATE: bool = false; const DEFAULT_CONTEXT: bool = true; +const DEFAULT_NUM_THREAD: i32 = 1; static WHISPER_CONTEXT: Lazy = Lazy::new(|| { let path = env::var("WHISPER_MODEL_PATH").unwrap(); - WhisperContext::new(&path).unwrap() + WhisperContext::new_with_params(&path, WhisperContextParameters::default()).unwrap() }); static CAT: Lazy = Lazy::new(|| { @@ -74,6 +75,7 @@ struct Settings { language: String, translate: bool, context: bool, + num_thread: i32, } struct State { @@ -114,12 +116,18 @@ impl WhisperFilter { } params.set_translate(settings.translate); params.set_no_context(!settings.context); + params.set_n_threads(settings.num_thread); } params } fn run_model(&self, state: &mut State, chunk: Chunk) -> Result, FlowError> { - let samples = convert_integer_to_float_audio(&chunk.buffer); + let mut samples = vec![0.0f32; chunk.buffer.len()]; + + if let Err(err) = convert_integer_to_float_audio(&chunk.buffer, &mut samples) { + gstreamer::debug!(CAT, "err {:?}", err); + return Ok(None); + } let start = Instant::now(); state @@ -195,6 +203,7 @@ impl ObjectSubclass for WhisperFilter { language: DEFAULT_LANGUAGE.into(), translate: DEFAULT_TRANSLATE, context: DEFAULT_CONTEXT, + num_thread: DEFAULT_NUM_THREAD, }), state: Mutex::new(None), } @@ -240,6 +249,13 @@ impl ObjectImpl for WhisperFilter { .mutable_paused() .mutable_playing() .build(), + glib::ParamSpecInt::builder("num-thread") + .nick("Number of thread") + .blurb(&format!("Set the number of threads for inference. Defaults to {}.", DEFAULT_NUM_THREAD)) + .mutable_ready() + .mutable_paused() + .mutable_playing() + .build(), ] }); PROPERTIES.as_ref() @@ -263,6 +279,9 @@ impl ObjectImpl for WhisperFilter { "context" => { settings.context = value.get().unwrap(); }, + "num-thread" => { + settings.num_thread = value.get().unwrap(); + }, other => panic!("no such property: {}", other), } } @@ -275,6 +294,7 @@ impl ObjectImpl for WhisperFilter { "language" => settings.language.to_value(), "translate" => settings.translate.to_value(), "context" => settings.context.to_value(), + "num-thread" => settings.num_thread.to_value(), other => panic!("no such property: {}", other), } }