From 5c300d883f98a1f2c8748ed849bff97b9e52a13e Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 22 Jun 2026 20:11:17 -0600 Subject: [PATCH 1/2] Rename plot_clip_detections to plot_clip_evaluation and add plot_detection --- src/batdetect2/evaluate/plots/detection.py | 4 +- src/batdetect2/plotting/detections.py | 104 ++++++++++++++++++++- 2 files changed, 104 insertions(+), 4 deletions(-) diff --git a/src/batdetect2/evaluate/plots/detection.py b/src/batdetect2/evaluate/plots/detection.py index a99f8c40..51a8a26c 100644 --- a/src/batdetect2/evaluate/plots/detection.py +++ b/src/batdetect2/evaluate/plots/detection.py @@ -21,7 +21,7 @@ from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.detection import ClipEval from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig -from batdetect2.plotting.detections import plot_clip_detections +from batdetect2.plotting.detections import plot_clip_evaluation from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess.types import PreprocessorProtocol @@ -276,7 +276,7 @@ def __call__( fig = self.create_figure() ax = fig.subplots() - plot_clip_detections( + plot_clip_evaluation( clip_eval, ax=ax, audio_loader=self.audio_loader, diff --git a/src/batdetect2/plotting/detections.py b/src/batdetect2/plotting/detections.py index 4bbaf274..f25a1112 100644 --- a/src/batdetect2/plotting/detections.py +++ b/src/batdetect2/plotting/detections.py @@ -1,4 +1,6 @@ +import numpy as np from matplotlib import axes, patches +from soundevent.geometry import compute_bounds from soundevent.plot import plot_geometry from batdetect2.evaluate.metrics.detection import ClipEval @@ -8,13 +10,111 @@ plot_clip, ) from batdetect2.plotting.common import create_ax +from batdetect2.postprocess import ClipDetections, Detection __all__ = [ - "plot_clip_detections", + "plot_clip_evaluation", + 
"plot_detection", ] -def plot_clip_detections( +def plot_detection( + detection: Detection, + figsize: tuple[int, int] = (10, 10), + ax: axes.Axes | None = None, + fill: bool = False, + linewidth: float = 1.0, + linestyle: str = "--", + color: str = "red", + show_class: bool = True, + class_names: list[str] | None = None, + fontsize: float | str = "small", +): + ax = create_ax(figsize=figsize, ax=ax) + + plot_geometry( + detection.geometry, + ax=ax, + add_points=False, + facecolor="none" if not fill else color, + alpha=detection.detection_score, + linewidth=linewidth, + linestyle=linestyle, + color=color, + ) + + if not show_class: + return ax + + start_time, low_freq, _, _ = compute_bounds(detection.geometry) + + top_class = np.argmax(detection.class_scores) + score = detection.class_scores[top_class] + + if class_names is not None: + class_name = class_names[top_class] + else: + class_name = f"class {top_class}" + + ax.text( + start_time, + low_freq, + f"{class_name}={score:.2f}", + va="top", + ha="left", + color=color, + fontsize=fontsize, + alpha=detection.detection_score, + ) + return ax + + +def plot_clip_detection( + clip_detections: ClipDetections, + figsize: tuple[int, int] = (10, 10), + ax: axes.Axes | None = None, + audio_loader: AudioLoader | None = None, + preprocessor: PreprocessorProtocol | None = None, + threshold: float | None = None, + spec_cmap: str = "gray", + fill: bool = False, + linewidth: float = 1.0, + linestyle: str = "--", + color: str = "red", + show_class: bool = True, + class_names: list[str] | None = None, + fontsize: float | str = "small", +): + ax = create_ax(figsize=figsize, ax=ax) + + plot_clip( + clip_detections.clip, + audio_loader=audio_loader, + preprocessor=preprocessor, + ax=ax, + spec_cmap=spec_cmap, + ) + + for detection in clip_detections.detections: + if threshold and detection.detection_score < threshold: + continue + + ax = plot_detection( + detection, + ax=ax, + class_names=class_names, + fontsize=fontsize, + 
fill=fill, + linewidth=linewidth, + linestyle=linestyle, + color=color, + show_class=show_class, + ) + + return ax + + +def plot_clip_evaluation( clip_eval: ClipEval, figsize: tuple[int, int] = (10, 10), ax: axes.Axes | None = None, From 3b34f467c623da5e83517e680a025496135a47bb Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 22 Jun 2026 20:29:25 -0600 Subject: [PATCH 2/2] fix: merge clip outputs for batdetect2 format --- src/batdetect2/outputs/formats/batdetect2.py | 51 ++++++++++- tests/test_cli/test_process.py | 90 ++++++++++++++++++++ tests/test_inference/test_batch.py | 55 ++++++++++++ tests/utils.py | 57 +++++++++++++ 4 files changed, 252 insertions(+), 1 deletion(-) create mode 100644 tests/test_cli/test_process.py create mode 100644 tests/test_inference/test_batch.py create mode 100644 tests/utils.py diff --git a/src/batdetect2/outputs/formats/batdetect2.py b/src/batdetect2/outputs/formats/batdetect2.py index d913ea62..09795e1f 100644 --- a/src/batdetect2/outputs/formats/batdetect2.py +++ b/src/batdetect2/outputs/formats/batdetect2.py @@ -1,4 +1,5 @@ import json +from collections import defaultdict from pathlib import Path from typing import List, Literal, Sequence, TypedDict, cast @@ -93,8 +94,11 @@ def __init__( def format( self, predictions: Sequence[ClipDetections] ) -> List[FileAnnotation]: + merged_predictions = merge_clip_detections(predictions) + return [ - self.format_prediction(prediction) for prediction in predictions + self.format_prediction(prediction) + for prediction in merged_predictions ] def save( @@ -349,3 +353,48 @@ def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol): preserve_audio_tree=config.preserve_audio_tree, include_file_path=config.include_file_path, ) + + +def merge_clip_detections( + predictions: Sequence[ClipDetections], +) -> List[ClipDetections]: + """Merge clip predictions into one recording-level prediction. 
+ + This intentionally discards the original clip boundaries because the + legacy BatDetect2 file format only stores recording-level detections. + """ + rec_to_clips = defaultdict(list) + rec_mapping = {} + + for prediction in predictions: + recording = prediction.clip.recording + key = recording.path + rec_to_clips[key].append(prediction) + rec_mapping[key] = recording + + merged_predictions = [] + for rec_path, clips in rec_to_clips.items(): + recording = rec_mapping[rec_path] + merged_predictions.append( + ClipDetections( + clip=data.Clip( + recording=recording, + start_time=0, + end_time=recording.duration, + ), + detections=sorted( + [ + detection + for clip_detections in clips + for detection in clip_detections.detections + ], + key=lambda detection: ( + detection.detection_score, + *compute_bounds(detection.geometry), + ), + reverse=True, + ), + ) + ) + + return merged_predictions diff --git a/tests/test_cli/test_process.py b/tests/test_cli/test_process.py new file mode 100644 index 00000000..99bb88ee --- /dev/null +++ b/tests/test_cli/test_process.py @@ -0,0 +1,90 @@ +import json +import shutil +from collections import Counter +from pathlib import Path + +from click.testing import CliRunner +from soundevent.geometry import compute_bounds + +from batdetect2 import BatDetect2API +from batdetect2.cli import cli + + +def test_cli_process_directory_merges_clip_outputs_per_recording( + tmp_path: Path, + contrib_dir: Path, +) -> None: + recording_path = contrib_dir / "jeff37" / "0166_20240531_223911.wav" + + source_folder = tmp_path / "audio" + source_folder.mkdir() + shutil.copy2( + recording_path, + source_folder / "example_audio.wav", + ) + + destination_folder = tmp_path / "results" + destination_folder.mkdir() + + api = BatDetect2API.from_checkpoint() + + api_outputs = api.process_directory( + source_folder, + detection_threshold=0.3, + ) + + # Get all detections regardless of clip + detections = [ + detection + for clip_detections in api_outputs + for 
detection in clip_detections.detections + ] + + result = CliRunner().invoke( + cli, + args=[ + "process", + "directory", + str(source_folder), + str(destination_folder), + "--detection-threshold", + "0.3", + ], + ) + + assert result.exit_code == 0 + assert destination_folder.exists() + + output_json = destination_folder / "example_audio.wav.json" + assert output_json.exists() + + saved_detections = json.loads(output_json.read_text()) + + expected_annotations = Counter( + ( + round(float(start_time), 4), + round(float(end_time), 4), + int(low_freq), + int(high_freq), + round(float(detection.class_scores.max()), 3), + round(float(detection.detection_score), 3), + ) + for detection in detections + for start_time, low_freq, end_time, high_freq in [ + compute_bounds(detection.geometry) + ] + ) + + actual_annotations = Counter( + ( + annotation["start_time"], + annotation["end_time"], + annotation["low_freq"], + annotation["high_freq"], + annotation["class_prob"], + annotation["det_prob"], + ) + for annotation in saved_detections["annotation"] + ) + + assert actual_annotations == expected_annotations diff --git a/tests/test_inference/test_batch.py b/tests/test_inference/test_batch.py new file mode 100644 index 00000000..f835f5ed --- /dev/null +++ b/tests/test_inference/test_batch.py @@ -0,0 +1,55 @@ +from pathlib import Path + +import pytest +from soundevent import data + +from batdetect2.inference.batch import run_batch_inference +from batdetect2.targets import build_roi_mapping, build_targets +from batdetect2.train import load_model_from_checkpoint +from tests.utils import assert_clip_detections_equal + +pytestmark = pytest.mark.slow + + +def test_run_batch_inference_matches_single_clip_inference( + contrib_dir: Path, +) -> None: + recording = data.Recording.from_file( + contrib_dir / "jeff37" / "0166_20240531_223911.wav" + ) + clips = [ + data.Clip(recording=recording, start_time=start, end_time=start + 1.0) + for start in (0.0, 1.0, 2.0) + ] + model, configs = 
load_model_from_checkpoint() + targets = build_targets(configs.targets) + roi_mapper = build_roi_mapping(configs.targets.roi) + + batched_predictions = run_batch_inference( + model, + clips, + targets=targets, + roi_mapper=roi_mapper, + batch_size=3, + num_workers=0, + ) + single_predictions = [ + run_batch_inference( + model, + [clip], + targets=targets, + roi_mapper=roi_mapper, + batch_size=1, + num_workers=0, + )[0] + for clip in clips + ] + + assert len(batched_predictions) == len(single_predictions) + + for batched, single in zip( + batched_predictions, + single_predictions, + strict=True, + ): + assert_clip_detections_equal(batched, single) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..d4e8854f --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,57 @@ +import numpy as np +from soundevent.geometry import compute_bounds + +from batdetect2.postprocess.types import ClipDetections + + +def assert_clip_detections_equal( + detections: ClipDetections, + other: ClipDetections, +) -> None: + """Assert two clip-detection objects are numerically equivalent.""" + assert detections.clip.recording.path == other.clip.recording.path + assert detections.clip.start_time == other.clip.start_time + assert detections.clip.end_time == other.clip.end_time + assert len(detections.detections) == len(other.detections) + + sorted_detections = sorted( + detections.detections, + key=lambda det: ( + compute_bounds(det.geometry)[0], + compute_bounds(det.geometry)[1], + ), + ) + + sorted_other = sorted( + other.detections, + key=lambda det: ( + compute_bounds(det.geometry)[0], + compute_bounds(det.geometry)[1], + ), + ) + + for det, other_det in zip( + sorted_detections, + sorted_other, + strict=True, + ): + np.testing.assert_allclose( + np.array(compute_bounds(det.geometry)), + np.array(compute_bounds(other_det.geometry)), + atol=2e-2, + ) + assert np.isclose( + det.detection_score, + other_det.detection_score, + atol=1e-6, + ) + np.testing.assert_allclose( + 
det.class_scores, + other_det.class_scores, + atol=1e-6, + ) + np.testing.assert_allclose( + det.features, + other_det.features, + atol=2e-6, + )