Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ gstreamer-base = { version = "0.20", default-features = false }
once_cell = { version = "1", default-features = false, features = ["std"] }
ringbuf = { version = "0.3", default-features = false, features = ["std"] }
webrtc-vad = { version = "0.4", default-features = false }
whisper-rs = { version = "0.8", default-features = false }
whisper-rs = { version = "0.11.1", default-features = false }

[dev-dependencies]
gstreamer-check = { version = "0.20", default-features = false }
Expand Down
26 changes: 23 additions & 3 deletions src/filter/imp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use gstreamer_base::{
use once_cell::sync::Lazy;
use webrtc_vad::{Vad, VadMode};
use whisper_rs::{
convert_integer_to_float_audio, FullParams, SamplingStrategy, WhisperContext, WhisperState,
convert_integer_to_float_audio, FullParams, SamplingStrategy, WhisperContext, WhisperState, WhisperContextParameters,
};

const SAMPLE_RATE: usize = 16_000;
Expand All @@ -42,10 +42,11 @@ const DEFAULT_MIN_VOICE_ACTIVITY_MS: u64 = 200;
const DEFAULT_LANGUAGE: &str = "en";
const DEFAULT_TRANSLATE: bool = false;
const DEFAULT_CONTEXT: bool = true;
const DEFAULT_NUM_THREAD: i32 = 1;

static WHISPER_CONTEXT: Lazy<WhisperContext> = Lazy::new(|| {
let path = env::var("WHISPER_MODEL_PATH").unwrap();
WhisperContext::new(&path).unwrap()
WhisperContext::new_with_params(&path, WhisperContextParameters::default()).unwrap()
});

static CAT: Lazy<DebugCategory> = Lazy::new(|| {
Expand Down Expand Up @@ -74,6 +75,7 @@ struct Settings {
language: String,
translate: bool,
context: bool,
num_thread: i32,
}

struct State {
Expand Down Expand Up @@ -114,12 +116,18 @@ impl WhisperFilter {
}
params.set_translate(settings.translate);
params.set_no_context(!settings.context);
params.set_n_threads(settings.num_thread);
}
params
}

fn run_model(&self, state: &mut State, chunk: Chunk) -> Result<Option<Buffer>, FlowError> {
let samples = convert_integer_to_float_audio(&chunk.buffer);
let mut samples = vec![0.0f32; chunk.buffer.len()];

if let Err(err) = convert_integer_to_float_audio(&chunk.buffer, &mut samples) {
gstreamer::debug!(CAT, "err {:?}", err);
return Ok(None);
}

let start = Instant::now();
state
Expand Down Expand Up @@ -195,6 +203,7 @@ impl ObjectSubclass for WhisperFilter {
language: DEFAULT_LANGUAGE.into(),
translate: DEFAULT_TRANSLATE,
context: DEFAULT_CONTEXT,
num_thread: DEFAULT_NUM_THREAD,
}),
state: Mutex::new(None),
}
Expand Down Expand Up @@ -240,6 +249,13 @@ impl ObjectImpl for WhisperFilter {
.mutable_paused()
.mutable_playing()
.build(),
glib::ParamSpecInt::builder("num-thread")
.nick("Number of thread")
.blurb(&format!("Set the number of threads for inference. Defaults to {}.", DEFAULT_NUM_THREAD))
.mutable_ready()
.mutable_paused()
.mutable_playing()
.build(),
]
});
PROPERTIES.as_ref()
Expand All @@ -263,6 +279,9 @@ impl ObjectImpl for WhisperFilter {
"context" => {
settings.context = value.get().unwrap();
},
"num-thread" => {
settings.num_thread = value.get().unwrap();
},
other => panic!("no such property: {}", other),
}
}
Expand All @@ -275,6 +294,7 @@ impl ObjectImpl for WhisperFilter {
"language" => settings.language.to_value(),
"translate" => settings.translate.to_value(),
"context" => settings.context.to_value(),
"num-thread" => settings.num_thread.to_value(),
other => panic!("no such property: {}", other),
}
}
Expand Down