Add MIGraphX execution mode for AMD GPUs #2
Conversation
VBx EM iterations now use f64 throughout (gamma, pi, inv_l, alpha, log_p) matching pyannote's numpy default precision. Previously f32 accumulated errors over 20 iterations, producing slightly different gamma weights and therefore different centroids and speaker assignment boundaries. Also replaces NEG_INFINITY masking for inactive speakers with min_score - 1, matching pyannote's constrained_argmax behavior.
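For illustration, the masking change might look like this minimal sketch, assuming a per-speaker score row and a boolean activity mask (names are illustrative, not the crate's actual signatures):

```rust
/// Pick the best speaker index while steering inactive speakers below every
/// active score, instead of poisoning them with NEG_INFINITY. Using
/// `min_score - 1.0` mirrors pyannote's constrained_argmax: inactive entries
/// stay finite, so downstream f64 arithmetic never sees inf/NaN.
fn constrained_argmax(scores: &[f64], active: &[bool]) -> Option<usize> {
    let min_score = scores.iter().cloned().fold(f64::INFINITY, f64::min);
    scores
        .iter()
        .zip(active)
        .map(|(&s, &a)| if a { s } else { min_score - 1.0 })
        .enumerate()
        .max_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap())
        .map(|(i, _)| i)
}
```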
PLDA transform now computes eigendecomposition and all matrix ops in f64, matching pyannote's numpy precision. Parameters stored as f64 internally, output cast to f32 for downstream. Adds tracing crate with debug/trace level logging for clustering pipeline (AHC, VBx, centroids, assignments). Investigation found speakrs CPU matches pyannote CPU exactly on all VoxConverse files. The 14.9% vs 14.0% DER gap is from the CoreML native embedding model (FP16) producing slightly different vectors on 1/10 files, causing VBx to find 2 speakers instead of 3 on that file.
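As a sketch of the precision pattern, assuming ndarray / ndarray-linalg (which the crate's later BLAS notes reference; the function and its shape are illustrative, not the actual PLDA code):

```rust
use ndarray::Array2;
use ndarray_linalg::{Eigh, UPLO};

/// Compute a whitening-style transform entirely in f64 and cast the result
/// to f32 only at the boundary, so the eigendecomposition matches numpy's
/// default double precision instead of accumulating f32 error.
fn plda_transform_f32(
    cov: &Array2<f64>,
) -> Result<Array2<f32>, ndarray_linalg::error::LinalgError> {
    let (eigvals, eigvecs) = cov.eigh(UPLO::Lower)?; // f64 throughout
    // Scale each eigenvector column by 1/sqrt(lambda); still f64 here.
    let scaled = &eigvecs / &eigvals.mapv(f64::sqrt);
    Ok(scaled.mapv(|x| x as f32)) // single downcast for downstream consumers
}
```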
Remove the SPEAKRS_PROFILE env-var-gated Instant timing from pipeline.rs in favor of structured tracing::info! logs. Drop the overstated "exact parity" / "matches pyannote exactly" claims from inline comments: the algorithm matches, but GPU floating-point is non-deterministic.
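The replacement pattern presumably reduces to something like this (a hypothetical helper, not the crate's actual API):

```rust
use std::time::Instant;
use tracing::info;

/// Run a pipeline stage and emit a structured timing event, replacing the
/// old SPEAKRS_PROFILE-gated Instant printing. Fields are queryable by any
/// tracing subscriber instead of being bare stderr text.
fn timed_stage<T>(name: &'static str, f: impl FnOnce() -> T) -> T {
    let start = Instant::now();
    let out = f();
    info!(
        stage = name,
        elapsed_ms = start.elapsed().as_millis() as u64,
        "stage finished"
    );
    out
}
```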
Clarify bit-exact status: CPU/CUDA match pyannote exactly, CoreML GPU diverges slightly due to floating-point non-determinism. Add DER table from VoxConverse dev set evaluation (14.0% CPU, 14.9% CoreML) with explanation of the CoreML gap.
Benchmarked pyannote-rs on all 216 VoxConverse dev files:
- speakrs CoreML: 11.5% DER on the 33 files where pyannote-rs produces output
- pyannote-rs: 80.2% DER on those same files (34.9% missed, 37.9% confusion)
- pyannote-rs produces 0 segments on 183/216 files

Expanded the architecture comparison table with verified details from reading the pyannote-rs source: raw argmax segmentation, no aggregation, no binarization, cosine 0.5 threshold clustering, CAM++ embeddings. Documented that the pyannote-rs README claims ResNet34-LM, but their build docs and GitHub release only ship CAM++ (no ONNX export of ResNet34-LM exists on HuggingFace).
Rewrite and clarify README copy, add a Table of Contents, and surface contributing instructions. Key changes:
- Rephrased project intro and pipeline description for clarity.
- Added a Table of Contents and minor wording/grammar fixes throughout.
- Clarified execution-mode differences (coreml-lite tradeoffs, benchmark phrasing).
- Explained GPU non-determinism and CoreML vs MPS differences more clearly.
- Clarified model download/dev instructions and library usage examples.
- Reworded the "Why Not pyannote-rs?" section and expanded notes about pyannote-rs benchmarks and model packaging.
- Added a pointer to CONTRIBUTING.md for local setup and tests.
Consolidate ~310 lines of duplicated bash (audio prep, feature mapping, model downloads, benchmarking) into a Rust xtask package. Justfile becomes thin wrappers calling cargo xtask. Absorb compare_rttm binary into xtask.
…for DER
Replace the O(n!) brute-force speaker mapping in metrics.rs with an O(n³) Hungarian algorithm to handle VoxConverse files with 20+ speakers. Port benchmark_diarization, benchmark_comparison, benchmark_der, and fluidaudio_json_to_rttm from Python to the Rust xtask. Rename download_models → export_models.
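A sketch of the assignment step using the pathfinding crate's Kuhn-Munkres implementation (an assumption for illustration; the xtask may hand-roll the algorithm instead):

```rust
use pathfinding::prelude::{kuhn_munkres, Matrix};

/// Map each reference speaker to the hypothesis speaker that maximizes
/// total overlap duration. O(n^3) instead of trying all n! permutations,
/// which matters for VoxConverse files with 20+ speakers.
fn best_mapping(overlap_ms: &[Vec<i64>]) -> Vec<usize> {
    let rows = overlap_ms.len();
    let cols = overlap_ms.iter().map(Vec::len).max().unwrap_or(0);
    let n = rows.max(cols);
    if n == 0 {
        return Vec::new();
    }
    // Pad to a square matrix; padded cells contribute zero overlap.
    let weights = Matrix::from_fn(n, n, |(r, c)| {
        overlap_ms.get(r).and_then(|row| row.get(c)).copied().unwrap_or(0)
    });
    let (_total_overlap, assignment) = kuhn_munkres(&weights);
    assignment.into_iter().take(rows).collect()
}
```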
Split the old CoreMlLite mode into CoreMlFast (FP32 + 2s step) and CoreMlFastLite (FP16 + 2s step). Use MLComputeUnits::All instead of CPUAndGPU/CPUAndNeuralEngine. Remove coreml_model_path_f16 in favor of uniform path resolution. Add multi-file batch support to the diarize binary with per-file timing output.
Run all 39 VoxConverse dev files through each execution mode and assert per-file and average DER thresholds. FP32 must stay under 10% avg, Fast under 12%, FastLite under 30%.
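Hypothetically, the per-mode gate reduces to something like this (function and argument names are illustrative; the thresholds are the ones stated above):

```rust
/// Assert both per-file and average DER stay under a mode's budget,
/// e.g. avg_max = 10.0 for FP32, 12.0 for Fast, 30.0 for FastLite.
fn assert_der_budget(per_file_der: &[f64], per_file_max: f64, avg_max: f64) {
    assert!(!per_file_der.is_empty(), "no files evaluated");
    for (i, &der) in per_file_der.iter().enumerate() {
        assert!(der <= per_file_max, "file {i}: DER {der:.2}% > {per_file_max:.2}%");
    }
    let avg = per_file_der.iter().sum::<f64>() / per_file_der.len() as f64;
    assert!(avg <= avg_max, "avg DER {avg:.2}% > {avg_max:.2}%");
}
```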
Document FP32 vs FP16 divergence on harder files and the decision to support CoreMl (6.8% DER) and CoreMlFastLite (27% DER, 3x faster), dropping the dominated FP16+1s combination.
Thin Swift wrapper around FluidAudio that accepts multiple WAV files and outputs RTTM to stdout, used by the xtask DER benchmark.
Subcommands: setup (rent + provision), benchmark (rsync + run), ssh (interactive session), destroy (teardown). Also converts all mod.rs files to module_name.rs format.
Skip embedding inference for inactive speakers (activity < 10.0 frames), cutting CoreML Fast time from 42s to 27s on VoxConverse, matching FluidAudio's speed while maintaining better DER (9.4% vs 16.1%). Remove the CoreMlFastLite execution mode (identical to CoreMlFast after the DER fix). Add a pluggable dataset system for DER benchmarks with support for VoxConverse dev/test, AMI, AISHELL-4, Earnings-21, and AliMeeting, including a TextGrid-to-RTTM converter for datasets that ship Praat annotations instead of RTTM.
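A sketch of the gating idea, assuming per-speaker activity measured in frames (the 10.0-frame threshold is from the commit; everything else is illustrative):

```rust
/// Only run embedding inference for speakers with enough active frames;
/// speakers below the threshold cannot yield a usable embedding anyway.
const MIN_ACTIVE_FRAMES: f32 = 10.0;

fn speakers_to_embed(activity_frames: &[f32]) -> Vec<usize> {
    activity_frames
        .iter()
        .enumerate()
        .filter(|&(_, &frames)| frames >= MIN_ACTIVE_FRAMES)
        .map(|(idx, _)| idx)
        .collect()
}
```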
Writes a README.md in the benchmark run directory and includes the description in results.json and results.txt for tracking what each benchmark run was testing.
Soft decode produced worse DER (+2.1% CoreML, +0.8% CoreML Fast) because soft probability masks degrade WeSpeaker embeddings. Hard decode is the correct approach for this pipeline.
Remove filter_short_embeddings(), reassign_filtered_embeddings(), and the FilteredEmbeddings type alias; all three proved DER-neutral in ablation. Extract accumulate_activations(), now shared by reconstruct() and reconstruct_smoothed().
Replace hand-rolled --mode parsing with clap derive, adding a DiarizeMode enum that encapsulates execution mode, step size, and pyannote device selection.
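The derive-based CLI might look roughly like this (a sketch; the variant set follows modes discussed in this thread and is not the crate's exact enum):

```rust
use clap::{Parser, ValueEnum};
use std::path::PathBuf;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum DiarizeMode {
    Cpu,
    Cuda,
    CoreMl,
    CoreMlFast,
}

#[derive(Parser)]
struct Args {
    /// Execution mode; each variant implies a step size and device choice.
    #[arg(long, value_enum, default_value = "cpu")]
    mode: DiarizeMode,
    /// One or more input WAV files.
    wav_files: Vec<PathBuf>,
}

fn main() {
    let args = Args::parse();
    println!("{:?} on {} file(s)", args.mode, args.wav_files.len());
}
```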
Accept multiple wav_files positional args and extract the file_id from the filename for RTTM output instead of the hardcoded "file1".
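The file_id derivation presumably reduces to the filename stem, along these lines (sketch, with the old hardcoded value kept as a fallback):

```rust
use std::path::Path;

/// RTTM file_id from the WAV path: "audio/abjxc.wav" -> "abjxc".
fn file_id(path: &Path) -> &str {
    path.file_stem().and_then(|s| s.to_str()).unwrap_or("file1")
}
```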
Replace outdated 10-file benchmarks with full VoxConverse dev/test results showing 6.4% DER matching pyannote MPS
Add VoxConverse dev/test benchmark results with comparison table, support --impl flag to run specific implementations, rename benchmarks output dir to _benchmarks to separate from tracked results
Temporal smoothing is now included in the default CoreML mode, so remove the separate SpeakRs row from benchmark results and update CoreML numbers accordingly. Also fix AMI IHM dataset download URLs to use the working mirror endpoint.
Replace the monolithic QueuedDiarizationPipeline with a sender/receiver pair. QueueSender is Clone, enabling multi-threaded push. QueueReceiver joins the worker thread on drain, distinguishing clean shutdown (Closed) from panics (WorkerPanicked). Removes push_batch and finish in favor of drop-based signaling.
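A condensed sketch of the split, assuming crossbeam-channel (which the crate depends on); the Job/Output types and queue depth are simplified stand-ins:

```rust
use crossbeam_channel::{bounded, Sender};
use std::thread::{self, JoinHandle};

pub struct Job;
pub struct Output;

/// Cloneable handle: any number of threads can push jobs concurrently.
#[derive(Clone)]
pub struct QueueSender(Sender<Job>);

/// Owns the worker thread; joining it on drain surfaces panics.
pub struct QueueReceiver {
    worker: Option<JoinHandle<Vec<Output>>>,
}

#[derive(Debug)]
pub enum DrainError {
    WorkerPanicked,
}

pub fn pair() -> (QueueSender, QueueReceiver) {
    let (tx, rx) = bounded::<Job>(8);
    let worker = thread::spawn(move || {
        let mut out = Vec::new();
        // The loop ends when every QueueSender clone has been dropped:
        // that drop is the shutdown signal, replacing push_batch()/finish().
        for _job in rx {
            out.push(Output);
        }
        out
    });
    (QueueSender(tx), QueueReceiver { worker: Some(worker) })
}

impl QueueSender {
    pub fn push(&self, job: Job) {
        let _ = self.0.send(job);
    }
}

impl QueueReceiver {
    /// Clean shutdown (Closed) returns the outputs; a worker panic is
    /// reported as WorkerPanicked instead of being swallowed.
    pub fn drain(mut self) -> Result<Vec<Output>, DrainError> {
        self.worker
            .take()
            .expect("drain consumes the receiver")
            .join()
            .map_err(|_| DrainError::WorkerPanicked)
    }
}
```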
Consolidate documentation in lib.rs doc comments so docs.rs and the README stay in sync via cargo-rdme. Add benchmarks, pipeline diagram, execution mode guidance, and pyannote-rs comparison to crate docs. Move make_exclusive from a free function to a method on DiscreteDiarization. Remove the CLI section (xtask is dev tooling). Fix private doc link warning for PldaTransform.
Bump crate version to 0.3.2 and update dependencies: libloading -> 0.9.0. Update benchmark manifest to use ort = 2.0.0-rc.12 and ndarray = 0.17. Regenerate Cargo.lock to pin upgraded dependency versions across the workspace.
Release 0.4.0: update changelog with new default BLAS behavior (ndarray-linalg defaults to Intel MKL on x86_64, OpenBLAS elsewhere), add explicit feature flags (intel-mkl, openblas-static, openblas-system) for users who disable default features, route PLDA linear algebra through an internal backend shim, update generated docs for backend options, and remove the stale OpenBLAS override from the GPU Docker build. Also bump package version in Cargo.toml and Cargo.lock to 0.4.0.
The `[target.'cfg(target_arch = "x86_64")'.dependencies]` and `[target.'cfg(not(target_arch = "x86_64"))'.dependencies]` tables were declared between `ndarray-npy` and `ort`, so everything after them in the `[dependencies]` section (`ort`, `libloading`, `tracing`, `thiserror`, `crossbeam-channel`, `rayon`) was silently re-scoped into the second target-specific table. On x86_64 those deps became unreachable and `cargo check` failed with "unresolved module or unlinked crate `ort`". Moving the two target tables below the untargeted deps restores the intended scoping without changing which backend is selected on either arch.
Adds an `ExecutionMode::MiGraphX` variant gated behind a new `migraphx` Cargo feature, forwarding to ONNX Runtime's MIGraphX execution provider. Users on AMD GPUs can now select an ORT-accelerated path without touching the existing CUDA or CoreML code paths.

* New `migraphx = ["ort/migraphx"]` feature.
* `ExecutionMode::MiGraphX` variant plus an `is_migraphx()` helper that mirrors `is_cuda()` / `is_coreml()`.
* `validate()` returns the same feature-gated error pattern used by `coreml` and `cuda` when the feature is off.
* `with_execution_mode()` attaches the MIGraphX execution provider with device 0 and `SameAsRequested` arena growth. Users who need compiled graph caching can set the ORT-standard env vars (`ORT_MIGRAPHX_LOAD_COMPILED_MODEL`, `ORT_MIGRAPHX_SAVE_COMPILED_MODEL`, `ORT_MIGRAPHX_SAVE_COMPILE_PATH`); no programmatic cache configuration is added here.
* `required_files(MiGraphX)` reuses the CPU file set since MIGraphX loads stock ONNX models directly (no split-backend assets).
* `segmentation_step_seconds(MiGraphX)` mirrors the CUDA step.
* Added a `migraphx_mode_requires_feature` unit test following the existing `coreml_modes_require_feature` / `cuda_modes_require_feature` pattern.

Verified against VoxConverse (232 files) on an RX 9070 (RDNA 4, gfx1201) using a patched onnxruntime build: 10.65% strict DER at 15.47x realtime. Background and patch set for the RDNA 4 ORT+MIGraphX stack: https://maherr.dev/rdna4-missing-rung/
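In ort 2.0-rc terms the attachment presumably looks something like the sketch below; the exact builder method names are assumptions, and the `SameAsRequested` arena option is elided since its MIGraphX builder hook is not confirmed here:

```rust
use ort::execution_providers::MIGraphXExecutionProvider;
use ort::session::Session;

/// Build a session with the MIGraphX EP on device 0. If the provider fails
/// to register, ort falls back to CPU unless failure is made a hard error.
fn migraphx_session(model: &std::path::Path) -> ort::Result<Session> {
    Session::builder()?
        .with_execution_providers([MIGraphXExecutionProvider::default()
            .with_device_id(0)
            .build()])?
        .commit_from_file(model)
}
```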
The `inference_path()` selector previously routed MIGraphX to Sequential, so segmentation ran to completion before embedding began. The existing `run_concurrent_inference` machinery (streaming segmentation windows over a bounded crossbeam channel into `ConcurrentEmbeddingRunner::run_masked`) already handles the MIGraphX `EmbeddingPath::Masked` case with no MIGraphX-specific gaps. This one-line change adds `MiGraphX` to the same match arm used for CoreML and CUDA.

Measured on an RX 9070 (gfx1201) with the MIGraphX provider built against ORT 1.24.2:
- 3-min call, speakrs alone: 12.34s -> 9.1s (-26%)
- 20-min VoxConverse file, speakrs alone: 61.87s -> 44.06s (-28.8%)
- Gain scales with audio length: segmentation fully overlaps with embedding, so the CPU-side segmentation prelude is absorbed.
- Inside a parallel Whisper + speakrs wrapper the end-to-end saving is smaller (~9% on 3-min) because GPU contention on the shared device partially offsets the overlap, but it remains positive.
- Segment counts and batching are unchanged (10x32 + 1x11 on the 3-min file, before and after). This is a scheduling change, not a modeling change.
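The shape of the overlap, as a sketch of the bounded-channel handoff the commit describes (the `Window` type and function names are illustrative, not the crate's internals):

```rust
use crossbeam_channel::bounded;
use std::thread;

struct Window; // one segmentation window's worth of frames + masks

/// Stream segmentation windows into the embedding side as they are produced,
/// so GPU embedding work overlaps the CPU segmentation prelude. The bounded(4)
/// channel applies backpressure instead of buffering every window up front.
fn run_concurrent(windows: Vec<Window>) -> usize {
    let (tx, rx) = bounded::<Window>(4);
    let producer = thread::spawn(move || {
        for w in windows {
            if tx.send(w).is_err() {
                break; // embedding side hung up
            }
        }
        // tx drops here, closing the channel; the consumer loop ends cleanly.
    });
    let mut embedded = 0;
    for _window in rx {
        embedded += 1; // stand-in for ConcurrentEmbeddingRunner::run_masked work
    }
    producer.join().expect("segmentation thread panicked");
    embedded
}
```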
Follow-up commit: enable concurrent inference path for MIGraphX

Added a one-line change on top of the MIGraphX feature commit that routes MIGraphX into the same concurrent inference path CoreML and CUDA already use (`inference_path()`).

Why

Profiling on an RX 9070 (gfx1201) showed that before this change, segmentation ran to completion as a pure-CPU prelude before the first embedding dispatch, leaving the GPU idle for the full segmentation duration. With concurrent routing, the first embedding call is dispatched while segmentation windows are still streaming in.

Measured impact at scale (full VoxConverse TEST, 232 files, 43.5 h audio)

Baseline is the same binary with MIGraphX routed to Sequential.

Per-file isolated profile data points:
- 3-min call, speakrs alone: 12.34s -> 9.1s (-26%)
- 20-min VoxConverse file, speakrs alone: 61.87s -> 44.06s (-28.8%)

Gain scales with audio length: longer files have more segmentation time for the concurrent path to absorb behind the embedding phase.

On the DER drift

The +0.11 pp aggregate drift is not diffuse: 226/232 files land bit-identical or within +/-0.1 pp between the two modes. Only 4 files drift more than 0.1 pp worse; 2 of them account for most of the aggregate delta. This is concurrent-path scheduling non-determinism hitting a handful of borderline clustering decisions, not a systemic accuracy regression. Worth surfacing for transparency; well within any reasonable parity threshold.

Correctness

Segment counts and batching are unchanged (10x32 + 1x11 on the 3-min file, before and after). This is a scheduling change, not a modeling change.

Why together with the MIGraphX feature commit

The concurrent path is arguably part of the "MIGraphX works well enough to ship" minimum, and keeping it in the same PR means reviewers only evaluate one change.
Superseded by #3.
Adds an `ExecutionMode::MiGraphX` variant gated behind a new `migraphx` Cargo feature. The new path forwards to ONNX Runtime's MIGraphX execution provider so users on AMD GPUs can get an ORT-accelerated path without touching CUDA or CoreML code.

This is purely additive: existing modes, file layouts, and feature sets are unchanged.
What changed
Cargo.toml
- New `migraphx = ["ort/migraphx"]` feature.
- Moves the `[target.'cfg(...)'.dependencies]` tables for `ndarray-linalg-default` below the core `[dependencies]` block. In the current ordering, `ort`, `libloading`, `tracing`, `thiserror`, `crossbeam-channel`, and `rayon` all fall inside `[target.'cfg(not(target_arch = "x86_64"))'.dependencies]`, which makes `cargo check` fail on x86_64 with `unresolved module or unlinked crate 'ort'`. This is a pre-existing issue unrelated to the MIGraphX feature, but the MIGraphX feature can't be verified without it, so it's bundled as a separate commit. Happy to split it into a standalone PR if preferred.

src/inference.rs
- `ExecutionMode::MiGraphX` variant plus an `is_migraphx()` helper that mirrors `is_cuda()` / `is_coreml()`.
- `validate()` returns the same feature-gated error pattern used by `coreml` and `cuda` when the feature is off.
- `with_execution_mode()` attaches the MIGraphX EP with device 0 and `SameAsRequested` arena growth.
- `migraphx_mode_requires_feature` unit test following the existing pattern.

src/models.rs / src/pipeline/config.rs
- `required_files(MiGraphX)` reuses the CPU file set since MIGraphX loads stock ONNX models directly (no split-backend assets needed).
- `segmentation_step_seconds(MiGraphX)` mirrors the CUDA step.

Both are in match arms the compiler requires to be exhaustive, so this is the minimum needed to keep the crate compiling with the new variant.
Compiled graph caching
Not added programmatically. Users who want MIGraphX's on-disk graph cache (important on first run because MIGraphX's compiler is slow) can set the ORT-standard env vars:
`ORT_MIGRAPHX_LOAD_COMPILED_MODEL`, `ORT_MIGRAPHX_SAVE_COMPILED_MODEL`, `ORT_MIGRAPHX_SAVE_COMPILE_PATH`. Happy to add a programmatic `with_migraphx_cache_dir()`-style helper in a follow-up if that fits the crate's direction better.
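For callers who want to wire this up in code today, a minimal sketch (assuming the native EP reads these vars at session initialization; the helper name and the exact value semantics of the flags are assumptions):

```rust
/// Opt into MIGraphX's on-disk compiled-model cache via the ORT-standard
/// env vars. Must run before the session (and EP) is created, since the
/// native library reads the environment at initialization.
fn enable_migraphx_cache(cache_path: &str) {
    std::env::set_var("ORT_MIGRAPHX_SAVE_COMPILED_MODEL", "1");
    std::env::set_var("ORT_MIGRAPHX_LOAD_COMPILED_MODEL", "1");
    std::env::set_var("ORT_MIGRAPHX_SAVE_COMPILE_PATH", cache_path);
}
```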
Test plan

- `cargo check --no-default-features --features online,openblas-system,load-dynamic` passes.
- `cargo check --no-default-features --features online,openblas-system,load-dynamic,migraphx` passes.
- `cargo check --no-default-features --features online,openblas-system,load-dynamic,cuda,migraphx` passes.
- `cargo test --no-default-features --features online,openblas-system,load-dynamic --lib inference::tests` passes (4 tests: `coreml_modes_require_feature`, `cuda_modes_require_feature`, `migraphx_mode_requires_feature`, `dynamic_runtime_preflight_fails_instead_of_hanging`).
- On an RX 9070 (RDNA 4, gfx1201) with a patched `onnxruntime` build, speakrs with this feature ran the VoxConverse test set (232 files) at 10.65% strict DER (7.85% under the VoxConverse paper convention with collar=0.25) at 15.47x realtime. Background and patch set for the RDNA 4 ORT build: https://maherr.dev/rdna4-missing-rung/

Intentionally not in this PR
Several performance-oriented MIGraphX tweaks from the downstream build were deliberately left out to keep the change small and avoid coupling upstream to AMD-specific tuning. Happy to send any of these as follow-ups if they're wanted:
- Routing MiGraphX into the concurrent inference path (`inference_path`); this was later added by the follow-up commit above.
Opened as a draft so you can review before it's marked ready. The author email on both commits is the GitHub noreply address, to respect the account's email privacy setting.