Merged
9 changes: 8 additions & 1 deletion README.md
@@ -284,10 +284,17 @@ Or run stages individually:
livekit-wakeword generate configs/prod.yaml # TTS synthesis + adversarial negatives
livekit-wakeword augment configs/prod.yaml # Augment + extract features
livekit-wakeword train configs/prod.yaml # 3-phase adaptive training
livekit-wakeword export configs/prod.yaml # Export to ONNX
livekit-wakeword export configs/prod.yaml # Export to ONNX (default)
livekit-wakeword eval configs/prod.yaml # Evaluate model (DET curve, AUT, FPPH)
```

The export format defaults to ONNX. Pass `--format tflite` (or set `output_format: tflite` in the config) to also emit an [openWakeWord](https://github.com/dscripka/openWakeWord)-compatible TFLite model — this requires the `tflite` extra and currently supports the `dnn` head only. See [Export & Inference](docs/export-and-inference.md#tflite-export-openwakeword-compatible) for details.

```bash
pip install 'livekit-wakeword[tflite]'
livekit-wakeword export configs/prod.yaml --format tflite
```

You can also evaluate any compatible ONNX model (e.g., one trained with openWakeWord):

```bash
48 changes: 43 additions & 5 deletions docs/export-and-inference.md
@@ -1,15 +1,17 @@
# Export & Inference

The export stage converts the trained PyTorch classifier to ONNX for deployment. The inference API provides `WakeWordModel` for prediction and `WakeWordListener` for async microphone detection.
The export stage converts the trained PyTorch classifier to ONNX or TFLite for deployment. The inference API provides `WakeWordModel` for prediction and `WakeWordListener` for async microphone detection.

**Source:** `src/livekit/wakeword/export/onnx.py`, `src/livekit/wakeword/inference/model.py`, `src/livekit/wakeword/inference/listener.py`
**CLI:** `livekit-wakeword export <config>`
**Source:** `src/livekit/wakeword/export/onnx.py`, `src/livekit/wakeword/export/tflite.py`, `src/livekit/wakeword/inference/model.py`, `src/livekit/wakeword/inference/listener.py`
**CLI:** `livekit-wakeword export <config> [--format onnx|tflite]`

The output format is resolved in priority order: the `--format` flag if given, otherwise the `output_format` field in the config (default: `onnx`).

## ONNX Export

### Classifier Export

`export_classifier()` exports the trained PyTorch classifier head to ONNX format.
`export_onnx()` exports the trained PyTorch classifier head to ONNX format.

| Property | Value |
|----------|-------|
@@ -34,7 +36,43 @@ livekit-wakeword export configs/hey_jarvis.yaml --quantize

### Export Entry Point

`run_export()` loads the trained model from `output/<model_name>/<model_name>.pt`, exports it to ONNX, and optionally quantizes it. Raises `FileNotFoundError` if the trained model doesn't exist.
`run_export(config, quantize=False, format=None)` loads the trained model from `output/<model_name>/<model_name>.pt` and exports it to ONNX. `format` defaults to `config.output_format`. ONNX is always produced (it is the conversion source for TFLite); when `format="tflite"`, the TFLite artifact is produced as well and its path is returned. With `quantize=True`, the ONNX format applies INT8 dynamic quantization to the ONNX model, while the TFLite format forwards the flag to `export_tflite()` so that quantization happens in the TF converter. Raises `FileNotFoundError` if the trained model doesn't exist.

## TFLite Export (openWakeWord-compatible)

`export_tflite()` converts an exported ONNX classifier to TFLite via `onnx2tf` (ONNX → TF SavedModel → TFLite), producing an artifact that [openWakeWord](https://github.com/dscripka/openWakeWord) can load directly.

Requires the optional extra:

```bash
uv sync --extra tflite # or: pip install 'livekit-wakeword[tflite]'
```

```bash
livekit-wakeword export configs/hey_jarvis.yaml --format tflite
```

### openWakeWord contract

openWakeWord loads classifier models with `ai_edge_litert.interpreter` and runs them without resizing tensors, so the artifact must satisfy:

| Requirement | Detail |
|-------------|--------|
| Input shape | **Static** `(1, 16, 96)` float32 (no dynamic batch — openWakeWord never calls `resize_tensor_input`) |
| Output shape | `(1, 1)` float32 sigmoid score |
| Ops | **Builtin TFLite ops only** — the LiteRT interpreter has no Flex/SELECT_TF delegate |

We pin the input shape with onnx2tf's `overwrite_input_shape` + `keep_shape_absolutely_input_names` (without the latter, onnx2tf's NCHW→NHWC pass transposes the input to `(1, 96, 16)`) and restrict the converter to `TFLITE_BUILTINS`.
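
The contract in the table above can be checked against an exported artifact before handing it to openWakeWord. The following is a minimal sketch and not part of this PR: `check_oww_contract` and `load_details` are hypothetical helper names, and the sketch assumes the `ai-edge-litert` package exposes `Interpreter` with the usual `get_input_details()` / `get_output_details()` accessors. It inspects shapes and dtypes only; it does not inspect which ops the model uses.

```python
from typing import Any, Mapping, Sequence

# Shapes taken from the contract table; everything else here is illustrative.
EXPECTED_INPUT = (1, 16, 96)
EXPECTED_OUTPUT = (1, 1)


def _dtype_name(dtype: Any) -> str:
    # numpy scalar types (e.g. np.float32) expose __name__; otherwise use str().
    return getattr(dtype, "__name__", str(dtype))


def check_oww_contract(
    inputs: Sequence[Mapping[str, Any]],
    outputs: Sequence[Mapping[str, Any]],
) -> list[str]:
    """Return a list of contract violations (an empty list means it passes)."""
    if len(inputs) != 1 or len(outputs) != 1:
        return [f"expected 1 input and 1 output, got {len(inputs)} and {len(outputs)}"]

    inp, out = inputs[0], outputs[0]
    in_shape = tuple(int(d) for d in inp["shape"])
    out_shape = tuple(int(d) for d in out["shape"])
    # shape_signature marks dynamic dimensions with -1.
    signature = tuple(int(d) for d in inp.get("shape_signature", in_shape))

    problems: list[str] = []
    if in_shape != EXPECTED_INPUT:
        problems.append(f"input shape {in_shape} != {EXPECTED_INPUT}")
    if -1 in signature:
        problems.append(f"input has a dynamic dimension: {signature}")
    if out_shape != EXPECTED_OUTPUT:
        problems.append(f"output shape {out_shape} != {EXPECTED_OUTPUT}")
    for label, details in (("input", inp), ("output", out)):
        if _dtype_name(details["dtype"]) != "float32":
            problems.append(f"{label} dtype is not float32")
    return problems


def load_details(tflite_path: str):
    """Read tensor details from a .tflite file (needs the `tflite` extra)."""
    # Imported lazily so check_oww_contract has no hard dependency on LiteRT.
    from ai_edge_litert.interpreter import Interpreter

    interpreter = Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()
    return interpreter.get_input_details(), interpreter.get_output_details()
```

A call such as `check_oww_contract(*load_details("output/hey_jarvis/hey_jarvis.tflite"))` (path shown for illustration only) would then return an empty list for a conforming model.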

### Head support

| Head | TFLite export | Notes |
|------|---------------|-------|
| `dnn` | Supported | Bit-exact vs ONNX/PyTorch (verified, maxdiff `0.0`) |
| `conv_attention` | Not supported | onnx2tf emits an unsupported constant for the attention block |
| `rnn` | Not supported | LSTM lowers to `TensorList` ops requiring the Flex delegate (which openWakeWord can't load) |

Use `dnn` for openWakeWord-compatible TFLite; deploy `conv_attention`/`rnn` via ONNX. Requesting TFLite for an unsupported head raises `NotImplementedError` before any export work begins.
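
The fail-fast behaviour described above can be pictured with a small guard. This is a sketch under stated assumptions, not the PR's `ensure_tflite_supported()`: the helper name `check_tflite_head` and the constant `TFLITE_SUPPORTED_HEADS` are hypothetical, and the head names are taken from the table above.

```python
# Hypothetical stand-in for the PR's guard; only the head names and the
# NotImplementedError come from the documentation above.
TFLITE_SUPPORTED_HEADS = frozenset({"dnn"})


def check_tflite_head(model_type: object) -> None:
    """Raise NotImplementedError if the head cannot be exported to TFLite."""
    # Accept either a plain string or an enum member carrying a .value.
    head = str(getattr(model_type, "value", model_type))
    if head not in TFLITE_SUPPORTED_HEADS:
        supported = ", ".join(sorted(TFLITE_SUPPORTED_HEADS))
        raise NotImplementedError(
            f"TFLite export is not supported for head '{head}' "
            f"(supported: {supported}); export to ONNX instead."
        )
```

Running a check like this before any export or training work starts is what lets an unsupported configuration fail within seconds rather than after a long training run.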

## Inference API

11 changes: 10 additions & 1 deletion pyproject.toml
@@ -10,7 +10,7 @@ readme = "README.md"
requires-python = ">=3.11"
license = "Apache-2.0"
authors = [{ name = "Binh Pham", email = "binh.pham@livekit.io" }]
keywords = ["wake-word", "keyword-spotting", "voice", "speech", "livekit", "onnx"]
keywords = ["wake-word", "keyword-spotting", "voice", "speech", "livekit", "onnx", "tflite"]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
@@ -64,6 +64,15 @@ export = [
    "onnxruntime>=1.17",
    "onnxscript>=0.6.2",
]
tflite = [
    "onnx2tf>=1.22",
    "tensorflow>=2.15",
    "tf-keras>=2.15",
    "onnx-graphsurgeon>=0.3.27",
    "sng4onnx>=1.0.4",
    "psutil>=5.9",
    "ai-edge-litert>=1.0",
]
[project.urls]
Homepage = "https://github.com/livekit/livekit-wakeword"
Repository = "https://github.com/livekit/livekit-wakeword"
2 changes: 2 additions & 0 deletions src/livekit/wakeword/__init__.py
@@ -9,6 +9,7 @@
# works with only numpy + onnxruntime (no torch, pydantic, etc.).
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
    "WakeWordConfig": (".config", "WakeWordConfig"),
    "ExportFormat": (".config", "ExportFormat"),
    "load_config": (".config", "load_config"),
    "run_augment": (".data.augment", "run_augment"),
    "run_extraction": (".data.features", "run_extraction"),
@@ -31,6 +32,7 @@ def __getattr__(name: str) -> object:

__all__ = [
    "WakeWordConfig",
    "ExportFormat",
    "WakeWordListener",
    "WakeWordModel",
    "Detection",
31 changes: 25 additions & 6 deletions src/livekit/wakeword/cli.py
@@ -313,17 +313,24 @@ def train(
@app.command()
def export(
    config_path: str = typer.Argument(..., help="Path to wake word config YAML"),
    format: str = typer.Option(
        None,
        "--format",
        "-f",
        help="Export format: onnx or tflite (default: config.output_format)",
    ),
    quantize: bool = typer.Option(False, "--quantize", help="Apply INT8 quantization"),
) -> None:
    """Export trained model to ONNX (optionally quantize for embedded)."""
    """Export trained model (optionally quantize for embedded)."""
    config = load_config(config_path)

    logger.info(f"Exporting '{config.model_name}' to ONNX...")
    fmt = format or config.output_format
    logger.info(f"Exporting '{config.model_name}' to {fmt}...")

    from .export.onnx import run_export

    onnx_path = run_export(config, quantize=quantize)
    logger.info(f"Export complete! ONNX model at {onnx_path}")
    out_path = run_export(config, quantize=quantize, format=fmt)
    logger.info(f"Export complete! Model at {out_path}")


@app.command()
@@ -361,8 +368,17 @@ def run(
    config_path: str = typer.Argument(..., help="Path to wake word config YAML"),
) -> None:
    """Run entire pipeline end-to-end: generate → augment → train → export."""
    from .config import ExportFormat

    config = load_config(config_path)

    # Validate the export target up front so we don't train for hours and then
    # fail at the export step (e.g. tflite + an unsupported head).
    if config.output_format == ExportFormat.tflite:
        from .export.tflite import ensure_tflite_supported

        ensure_tflite_supported(config.model.model_type)

    logger.info(f"Running full pipeline for '{config.model_name}'...")

    from .data.augment import run_augment
@@ -384,8 +400,11 @@ def run(
    logger.info("Step 4/6: Train classifier")
    run_train(config)

    logger.info("Step 5/6: Export to ONNX")
    onnx_path = run_export(config)
    logger.info(f"Step 5/6: Export to {config.output_format}")
    run_export(config)

    # Eval runs on the ONNX artifact (run_export always produces it).
    onnx_path = config.model_output_dir / f"{config.model_name}.onnx"

    logger.info("Step 6/6: Evaluate model")
    results = run_eval(config, onnx_path)
10 changes: 10 additions & 0 deletions src/livekit/wakeword/config.py
@@ -36,6 +36,13 @@ class TtsBackend(StrEnum):
    voxcpm = "voxcpm"


class ExportFormat(StrEnum):
    """Artifact format produced by the export stage."""

    onnx = "onnx"
    tflite = "tflite"


# Preset mapping: size -> (layer_dim, n_blocks)
MODEL_SIZE_PRESETS: dict[ModelSize, tuple[int, int]] = {
    ModelSize.tiny: (16, 1),
@@ -132,6 +139,9 @@ class WakeWordConfig(BaseModel):
    data_dir: Annotated[str, Field(description="Root data directory")] = "./data"
    output_dir: str = "./output"

    # Export
    output_format: ExportFormat = ExportFormat.onnx

    # Augmentation
    augmentation: AugmentationConfig = Field(default_factory=AugmentationConfig)

8 changes: 5 additions & 3 deletions src/livekit/wakeword/export/__init__.py
@@ -1,9 +1,11 @@
"""ONNX export and quantization."""
"""Model export (ONNX, TFLite) and quantization."""

from .onnx import export_classifier, quantize_onnx, run_export
from .onnx import export_onnx, quantize_onnx, run_export
from .tflite import export_tflite

__all__ = [
    "export_classifier",
    "export_onnx",
    "export_tflite",
    "quantize_onnx",
    "run_export",
]
90 changes: 74 additions & 16 deletions src/livekit/wakeword/export/onnx.py
@@ -3,18 +3,19 @@
from __future__ import annotations

import logging
import tempfile
from pathlib import Path

import onnx
import torch

from ..config import WakeWordConfig
from ..config import ExportFormat, WakeWordConfig
from ..models.pipeline import WakeWordClassifier

logger = logging.getLogger(__name__)


def export_classifier(
def export_onnx(
    config: WakeWordConfig,
    model_path: Path,
    output_path: Path,
@@ -60,34 +61,91 @@ def export_classifier(

def quantize_onnx(input_path: Path, output_path: Path | None = None) -> Path:
    """Apply INT8 dynamic quantization to an ONNX model."""
    from onnxruntime.quantization import quantize_dynamic, QuantType
    from onnxruntime.quantization import QuantType, quantize_dynamic

    if output_path is None:
        output_path = input_path.with_suffix(".int8.onnx")

    quantize_dynamic(
        model_input=str(input_path),
        model_output=str(output_path),
        weight_type=QuantType.QInt8,
    )
    # The torch dynamo ONNX exporter emits value_info entries describing the
    # weight initializers (e.g. a Gemm B of shape [out, in]). When the dynamic
    # quantizer rewrites Gemm->MatMul it transposes those weights in place but
    # leaves the value_info stale, so its strict shape-inference pass then fails
    # with "Inferred shape and existing shape differ". Dropping initializer
    # value_info (which is redundant, since shapes are inferred from the
    # tensors) avoids the conflict. We do this on a temp copy so the input
    # model on disk is left untouched.
    model = onnx.load(str(input_path))
    init_names = {init.name for init in model.graph.initializer}
    kept = [vi for vi in model.graph.value_info if vi.name not in init_names]

    if len(kept) == len(model.graph.value_info):
        # Nothing to strip: quantize the input directly.
        quantize_dynamic(
            model_input=str(input_path),
            model_output=str(output_path),
            weight_type=QuantType.QInt8,
        )
    else:
        del model.graph.value_info[:]
        model.graph.value_info.extend(kept)
        with tempfile.TemporaryDirectory() as tmp_dir:
            cleaned = Path(tmp_dir) / "cleaned.onnx"
            onnx.save(model, str(cleaned))
            quantize_dynamic(
                model_input=str(cleaned),
                model_output=str(output_path),
                weight_type=QuantType.QInt8,
            )

    logger.info(f"Quantized ONNX model to {output_path}")
    return output_path


def run_export(config: WakeWordConfig, quantize: bool = False) -> Path:
    """Export trained model to ONNX."""
def run_export(
    config: WakeWordConfig,
    quantize: bool = False,
    format: ExportFormat | str | None = None,
) -> Path:
    """Export the trained classifier head.

    Args:
        config: Wake word config.
        quantize: Apply INT8 quantization to the exported artifact.
        format: Output format (``onnx`` or ``tflite``). Defaults to
            ``config.output_format`` when ``None``.

    Returns:
        Path to the primary exported artifact for the chosen format. ONNX is
        always produced as well, since TFLite is converted from it.
    """
    fmt = ExportFormat(format) if format is not None else config.output_format

    # Fail fast on unsupported (head, format) combinations before doing any work.
    if fmt == ExportFormat.tflite:
        from .tflite import ensure_tflite_supported

        ensure_tflite_supported(config.model.model_type)

    model_dir = config.model_output_dir
    model_path = model_dir / f"{config.model_name}.pt"

    if not model_path.exists():
        raise FileNotFoundError(f"Trained model not found: {model_path}")

    # Export classifier head
    # Export classifier head to ONNX (also the conversion source for TFLite).
    onnx_path = model_dir / f"{config.model_name}.onnx"
    export_classifier(config, model_path, onnx_path)
    export_onnx(config, model_path, onnx_path)

    if fmt == ExportFormat.onnx:
        if quantize:
            quantize_onnx(onnx_path)
        return onnx_path

    if fmt == ExportFormat.tflite:
        # TFLite quantization is applied by the TF converter, not the ONNX path.
        from .tflite import export_tflite

    # Optionally quantize
    if quantize:
        quantize_onnx(onnx_path)
        tflite_path = model_dir / f"{config.model_name}.tflite"
        return export_tflite(onnx_path, tflite_path, quantize=quantize)

    return onnx_path
    raise ValueError(f"Unsupported export format: {fmt}")