Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/batdetect2/evaluate/plots/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.detections import plot_clip_detections
from batdetect2.plotting.detections import plot_clip_evaluation
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
Expand Down Expand Up @@ -276,7 +276,7 @@ def __call__(
fig = self.create_figure()
ax = fig.subplots()

plot_clip_detections(
plot_clip_evaluation(
clip_eval,
ax=ax,
audio_loader=self.audio_loader,
Expand Down
51 changes: 50 additions & 1 deletion src/batdetect2/outputs/formats/batdetect2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from collections import defaultdict
from pathlib import Path
from typing import List, Literal, Sequence, TypedDict, cast

Expand Down Expand Up @@ -93,8 +94,11 @@ def __init__(
def format(
self, predictions: Sequence[ClipDetections]
) -> List[FileAnnotation]:
merged_predictions = merge_clip_detections(predictions)

return [
self.format_prediction(prediction) for prediction in predictions
self.format_prediction(prediction)
for prediction in merged_predictions
]

def save(
Expand Down Expand Up @@ -349,3 +353,48 @@ def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol):
preserve_audio_tree=config.preserve_audio_tree,
include_file_path=config.include_file_path,
)


def merge_clip_detections(
predictions: Sequence[ClipDetections],
) -> List[ClipDetections]:
"""Merge clip predictions into one recording-level prediction.

This intentionally discards the original clip boundaries because the
legacy BatDetect2 file format only stores recording-level detections.
"""
rec_to_clips = defaultdict(list)
rec_mapping = {}

for prediction in predictions:
recording = prediction.clip.recording
key = recording.path
rec_to_clips[key].append(prediction)
rec_mapping[key] = recording

merged_predictions = []
for rec_path, clips in rec_to_clips.items():
recording = rec_mapping[rec_path]
merged_predictions.append(
ClipDetections(
clip=data.Clip(
recording=recording,
start_time=0,
end_time=recording.duration,
),
detections=sorted(
[
detection
for clip_detections in clips
for detection in clip_detections.detections
],
key=lambda detection: (
detection.detection_score,
*compute_bounds(detection.geometry),
),
reverse=True,
),
)
)

return merged_predictions
104 changes: 102 additions & 2 deletions src/batdetect2/plotting/detections.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numpy as np
from matplotlib import axes, patches
from soundevent.geometry import compute_bounds
from soundevent.plot import plot_geometry

from batdetect2.evaluate.metrics.detection import ClipEval
Expand All @@ -8,13 +10,111 @@
plot_clip,
)
from batdetect2.plotting.common import create_ax
from batdetect2.postprocess import ClipDetections, Detection

__all__ = [
"plot_clip_detections",
"plot_clip_evaluation",
"plot_detection",
]


def plot_clip_detections(
def plot_detection(
detection: Detection,
figsize: tuple[int, int] = (10, 10),
ax: axes.Axes | None = None,
fill: bool = False,
linewidth: float = 1.0,
linestyle: str = "--",
color: str = "red",
show_class: bool = True,
class_names: list[str] | None = None,
fontsize: float | str = "small",
):
ax = create_ax(figsize=figsize, ax=ax)

plot_geometry(
detection.geometry,
ax=ax,
add_points=False,
facecolor="none" if not fill else color,
alpha=detection.detection_score,
linewidth=linewidth,
linestyle=linestyle,
color=color,
)

if not show_class:
return ax

start_time, low_freq, _, _ = compute_bounds(detection.geometry)

top_class = np.argmax(detection.class_scores)
score = detection.class_scores[top_class]

if class_names is not None:
class_name = class_names[top_class]
else:
class_name = f"class {top_class}"

ax.text(
start_time,
low_freq,
f"{class_name}={score:.2f}",
va="top",
ha="left",
color=color,
fontsize=fontsize,
alpha=detection.detection_score,
)
return ax


def plot_clip_detection(
clip_detections: ClipDetections,
figsize: tuple[int, int] = (10, 10),
ax: axes.Axes | None = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
threshold: float | None = None,
spec_cmap: str = "gray",
fill: bool = False,
linewidth: float = 1.0,
linestyle: str = "--",
color: str = "red",
show_class: bool = True,
class_names: list[str] | None = None,
fontsize: float | str = "small",
):
ax = create_ax(figsize=figsize, ax=ax)

plot_clip(
clip_detections.clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
ax=ax,
spec_cmap=spec_cmap,
)

for detection in clip_detections.detections:
if threshold and detection.detection_score < threshold:
continue

ax = plot_detection(
detection,
ax=ax,
class_names=class_names,
fontsize=fontsize,
fill=fill,
linewidth=linewidth,
linestyle=linestyle,
color=color,
show_class=show_class,
)

return ax


def plot_clip_evaluation(
clip_eval: ClipEval,
figsize: tuple[int, int] = (10, 10),
ax: axes.Axes | None = None,
Expand Down
90 changes: 90 additions & 0 deletions tests/test_cli/test_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import json
import shutil
from collections import Counter
from pathlib import Path

from click.testing import CliRunner
from soundevent.geometry import compute_bounds

from batdetect2 import BatDetect2API
from batdetect2.cli import cli


def test_cli_process_directory_merges_clip_outputs_per_recording(
tmp_path: Path,
contrib_dir: Path,
) -> None:
recording_path = contrib_dir / "jeff37" / "0166_20240531_223911.wav"

source_folder = tmp_path / "audio"
source_folder.mkdir()
shutil.copy2(
recording_path,
source_folder / "example_audio.wav",
)

destination_folder = tmp_path / "results"
destination_folder.mkdir()

api = BatDetect2API.from_checkpoint()

api_outputs = api.process_directory(
source_folder,
detection_threshold=0.3,
)

# Get all detections regardless of clip
detections = [
detection
for clip_detections in api_outputs
for detection in clip_detections.detections
]

result = CliRunner().invoke(
cli,
args=[
"process",
"directory",
str(source_folder),
str(destination_folder),
"--detection-threshold",
"0.3",
],
)

assert result.exit_code == 0
assert destination_folder.exists()

output_json = destination_folder / "example_audio.wav.json"
assert output_json.exists()

saved_detections = json.loads(output_json.read_text())

expected_annotations = Counter(
(
round(float(start_time), 4),
round(float(end_time), 4),
int(low_freq),
int(high_freq),
round(float(detection.class_scores.max()), 3),
round(float(detection.detection_score), 3),
)
for detection in detections
for start_time, low_freq, end_time, high_freq in [
compute_bounds(detection.geometry)
]
)

actual_annotations = Counter(
(
annotation["start_time"],
annotation["end_time"],
annotation["low_freq"],
annotation["high_freq"],
annotation["class_prob"],
annotation["det_prob"],
)
for annotation in saved_detections["annotation"]
)

assert actual_annotations == expected_annotations
55 changes: 55 additions & 0 deletions tests/test_inference/test_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from pathlib import Path

import pytest
from soundevent import data

from batdetect2.inference.batch import run_batch_inference
from batdetect2.targets import build_roi_mapping, build_targets
from batdetect2.train import load_model_from_checkpoint
from tests.utils import assert_clip_detections_equal

pytestmark = pytest.mark.slow


def test_run_batch_inference_matches_single_clip_inference(
contrib_dir: Path,
) -> None:
recording = data.Recording.from_file(
contrib_dir / "jeff37" / "0166_20240531_223911.wav"
)
clips = [
data.Clip(recording=recording, start_time=start, end_time=start + 1.0)
for start in (0.0, 1.0, 2.0)
]
model, configs = load_model_from_checkpoint()
targets = build_targets(configs.targets)
roi_mapper = build_roi_mapping(configs.targets.roi)

batched_predictions = run_batch_inference(
model,
clips,
targets=targets,
roi_mapper=roi_mapper,
batch_size=3,
num_workers=0,
)
single_predictions = [
run_batch_inference(
model,
[clip],
targets=targets,
roi_mapper=roi_mapper,
batch_size=1,
num_workers=0,
)[0]
for clip in clips
]

assert len(batched_predictions) == len(single_predictions)

for batched, single in zip(
batched_predictions,
single_predictions,
strict=True,
):
assert_clip_detections_equal(batched, single)
Loading
Loading