diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1a5cf45 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +# Runtime dependencies +tensorflow>=2.13.0 +tensorflow-hub>=0.14.0 +soundfile>=0.12.1 +numpy>=1.24.0 +opencv-python>=4.8.0 +mediapipe>=0.10.0 + +# Dev / test +pytest>=7.4.0 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/audio/__init__.py b/src/audio/__init__.py new file mode 100644 index 0000000..87e6133 --- /dev/null +++ b/src/audio/__init__.py @@ -0,0 +1,3 @@ +from src.audio.detector import AudioEvent, SoundEventDetector + +__all__ = ["AudioEvent", "SoundEventDetector"] diff --git a/src/audio/detector.py b/src/audio/detector.py new file mode 100644 index 0000000..b875eaa --- /dev/null +++ b/src/audio/detector.py @@ -0,0 +1,138 @@ +"""Sound Event Detection using YAMNet (Goal 1). + +Heavy ML imports (tensorflow_hub, soundfile) are deferred to method bodies +so the module is importable without any ML stack installed — useful for +running unit tests and importing constants in lightweight contexts. +""" + +from __future__ import annotations + +import dataclasses + +from src.audio.labels import SPEECH_LABELS, to_cc_label + +# YAMNet produces one score vector per ~0.48 s of audio. +_YAMNET_HOP_S: float = 0.48 + +# Adjacent same-label events closer than this are merged into one. +_MERGE_GAP_S: float = 0.5 + +# Default minimum confidence for an event to be kept. 
DEFAULT_CONFIDENCE_THRESHOLD: float = 0.35


@dataclasses.dataclass
class AudioEvent:
    """A single detected non-speech audio event."""

    label: str  # closed-caption label, e.g. "[alarm]"
    start_s: float  # event onset, in seconds from the start of the audio
    end_s: float  # event end, in seconds from the start of the audio
    confidence: float  # highest per-frame score among the frames merged into this event


def _merge_adjacent(
    events: list[AudioEvent],
    max_gap_s: float | None = None,
) -> list[AudioEvent]:
    """Merge consecutive same-label events separated by a small gap.

    Parameters
    ----------
    events:
        Events in chronological order. Only neighbouring list entries are
        compared, so unsorted input is not merged correctly.
    max_gap_s:
        Largest gap (seconds) between one event's end and the next event's
        start for the two to be merged. ``None`` (the default) uses the
        module constant ``_MERGE_GAP_S``.

    Returns
    -------
    list[AudioEvent]
        A new list; the input list and its events are not modified.
    """
    if not events:
        return []
    gap_limit = _MERGE_GAP_S if max_gap_s is None else max_gap_s

    merged: list[AudioEvent] = [events[0]]
    for ev in events[1:]:
        prev = merged[-1]
        if ev.label == prev.label and (ev.start_s - prev.end_s) <= gap_limit:
            merged[-1] = dataclasses.replace(
                prev,
                # max() keeps the span from shrinking when *ev* lies inside *prev*.
                end_s=max(prev.end_s, ev.end_s),
                confidence=max(prev.confidence, ev.confidence),
            )
        else:
            merged.append(ev)
    return merged


class SoundEventDetector:
    """Detects non-speech audio events in a 16 kHz mono WAV file via YAMNet.

    Usage::

        detector = SoundEventDetector()
        events = detector.detect("extracted_audio.wav")
        for ev in events:
            print(ev.label, ev.start_s, ev.end_s, ev.confidence)
    """

    def __init__(self, confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD) -> None:
        """Store the score threshold; the model itself is loaded lazily."""
        self.confidence_threshold = confidence_threshold
        self._model = None
        self._class_names: list[str] | None = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def detect(self, audio_path: str) -> list[AudioEvent]:
        """Run YAMNet on *audio_path* and return filtered, merged events.

        *audio_path* must be a 16 kHz mono WAV file.  Use ``ffmpeg`` or
        ``librosa`` to resample if the source has a different sample rate.

        Raises
        ------
        ValueError
            If *audio_path* is not 16 kHz, or is not single-channel.
        """
        import soundfile as sf

        # Validate the audio before loading the model so that bad input fails
        # fast, without paying for a model load.
        waveform, sr = sf.read(audio_path, dtype="float32", always_2d=False)
        if sr != 16000:
            raise ValueError(
                f"Expected 16 kHz audio, got {sr} Hz. "
                "Resample with ffmpeg: ffmpeg -i input.wav -ar 16000 out.wav"
            )
        # With always_2d=False a mono file is read as a 1-D array; anything
        # else means the file carries more than one channel.
        if waveform.ndim != 1:
            raise ValueError(
                f"Expected mono audio, got an array of shape {waveform.shape}. "
                "Downmix with ffmpeg: ffmpeg -i input.wav -ac 1 -ar 16000 out.wav"
            )

        self._load_model()
        scores, _, _ = self._model(waveform)
        return self._events_from_scores(scores.numpy())

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _events_from_scores(self, scores) -> list[AudioEvent]:
        """Turn a per-frame score matrix into merged caption events.

        Only the top-scoring class of each frame is considered. Frames whose
        top score is below ``confidence_threshold`` or whose top class is
        speech are skipped.
        """
        events: list[AudioEvent] = []
        for frame_idx, frame_scores in enumerate(scores):
            top_idx = int(frame_scores.argmax())
            confidence = float(frame_scores[top_idx])
            if confidence < self.confidence_threshold:
                continue
            class_name = self._class_names[top_idx]
            if class_name in SPEECH_LABELS:
                continue
            label = to_cc_label(class_name)
            if label is None:
                continue
            start_s = frame_idx * _YAMNET_HOP_S
            events.append(
                AudioEvent(
                    label=label,
                    start_s=round(start_s, 3),
                    end_s=round(start_s + _YAMNET_HOP_S, 3),
                    confidence=round(confidence, 4),
                )
            )
        return _merge_adjacent(events)

    @staticmethod
    def _parse_class_map(csv_text: str) -> list[str]:
        """Return the ``display_name`` column of a YAMNet class-map CSV.

        Raises
        ------
        ValueError
            If *csv_text* has no ``display_name`` header, e.g. because a file
            path was passed instead of the file's contents.
        """
        import csv
        import io

        reader = csv.DictReader(io.StringIO(csv_text))
        if not reader.fieldnames or "display_name" not in reader.fieldnames:
            raise ValueError("YAMNet class map has no 'display_name' column")
        return [row["display_name"] for row in reader]

    def _load_model(self) -> None:
        """Load YAMNet from TF-Hub (idempotent)."""
        if self._model is not None:
            return

        import tensorflow as tf
        import tensorflow_hub as hub

        model = hub.load("https://tfhub.dev/google/yamnet/1")
        # class_map_path() yields the *path* of the class-map CSV, not its
        # contents, so the file has to be opened and read before parsing.
        class_map_path = model.class_map_path().numpy()
        if isinstance(class_map_path, bytes):
            class_map_path = class_map_path.decode("utf-8")
        with tf.io.gfile.GFile(class_map_path) as csv_file:
            class_names = self._parse_class_map(csv_file.read())

        # Publish both attributes only after everything succeeded, so a failed
        # load never leaves a model cached without its class names.
        self._class_names = class_names
        self._model = model
SPEECH_LABELS: frozenset[str] = frozenset({
    "Speech",
    "Male speech, man speaking",
    "Female speech, woman speaking",
    "Child speech, kid speaking",
    "Conversation",
    "Narration, monologue",
    "Babbling",
    "Speech synthesizer",
    "Shout",
    "Bellow",
    "Whoop",
    "Yell",
    "Children shouting",
    "Screaming",
    "Whispering",
    "Laughter",
    "Baby laughter",
    "Giggling",
    "Snicker",
    "Breathing",
    "Wheeze",
    "Snoring",
    "Cough",
    "Sneeze",
    "Gasp",
    "Sigh",
})

# Lower-cased copy of SPEECH_LABELS so that speech suppression is
# case-insensitive, like the rest of the label matching.
_SPEECH_LABELS_LOWER: frozenset[str] = frozenset(name.lower() for name in SPEECH_LABELS)

# Maps YAMNet display names to human-readable CC labels.
# Includes India-specific sounds absent from generic AudioSet mappings.
LABEL_MAP: dict[str, str] = {
    "Gunshot, gunfire": "[gunshot]",
    "Machine gun": "[gunshot]",
    "Explosion": "[explosion]",
    "Blowing up": "[explosion]",
    "Fire alarm": "[alarm]",
    "Alarm": "[alarm]",
    "Smoke detector, smoke alarm": "[alarm]",
    "Car alarm": "[alarm]",
    "Siren": "[siren]",
    "Civil defense siren": "[siren]",
    "Ambulance (siren)": "[siren]",
    "Police car (siren)": "[siren]",
    "Glass": "[glass breaking]",
    "Breaking": "[glass breaking]",
    "Applause": "[applause]",
    "Crowd": "[crowd noise]",
    "Cheering": "[cheering]",
    "Baby cry, infant cry": "[baby crying]",
    "Crying, sobbing": "[crying]",
    "Dog": "[dog barking]",
    "Bark": "[dog barking]",
    "Cat": "[cat meowing]",
    "Meow": "[cat meowing]",
    # India-specific: Diwali crackers map from AudioSet "Fireworks"
    "Fireworks": "[firecrackers]",
    # India-specific classical/folk percussion
    "Tabla": "[tabla]",
    "Dhol": "[dhol]",
    "Temple bells": "[temple bells]",
    "Knock": "[knocking]",
    "Telephone": "[phone ringing]",
    "Bell": "[bell]",
    "Thunder": "[thunder]",
    "Rain": "[rain]",
    "Wind": "[wind]",
    "Traffic noise, roadway noise": "[traffic]",
    "Honk": "[honking]",
    "Music": "[music]",
    "Musical instrument": "[music]",
}

# High-impact events: audio confidence alone is usually sufficient.
HIGH_IMPACT: frozenset[str] = frozenset({
    "[gunshot]",
    "[explosion]",
    "[alarm]",
    "[siren]",
    "[glass breaking]",
})

# Ambient events: require strong visual confirmation to avoid overcaptioning.
AMBIENT: frozenset[str] = frozenset({
    "[music]",
    "[rain]",
    "[wind]",
    "[traffic]",
})


def _contains_phrase(haystack: str, needle: str) -> bool:
    """Return True if *needle* occurs in *haystack* delimited by non-alphanumerics.

    Both arguments are expected to be lower-cased already. Requiring word
    boundaries stops "rain" from matching inside "train" or "cat" inside
    "cattle".
    """
    if not needle:
        return False
    start = haystack.find(needle)
    while start != -1:
        end = start + len(needle)
        starts_word = start == 0 or not haystack[start - 1].isalnum()
        ends_word = end == len(haystack) or not haystack[end].isalnum()
        if starts_word and ends_word:
            return True
        start = haystack.find(needle, start + 1)
    return False


def to_cc_label(yamnet_class: str) -> str | None:
    """Return a CC label string for *yamnet_class*, or ``None`` if suppressed.

    Resolution order (all comparisons are case-insensitive):

    1. Blank names and names listed in ``SPEECH_LABELS`` return ``None``.
    2. A name equal to a ``LABEL_MAP`` key returns that key's label.
    3. Otherwise, if one or more ``LABEL_MAP`` keys occur in the name as whole
       words (e.g. key "Bark" inside "Dog bark"), the longest such key wins;
       ties go to the key listed first.
    4. Otherwise the fallback is the lower-cased name wrapped in brackets,
       e.g. "Helicopter" becomes "[helicopter]".

    A key is never matched merely because the class name is a fragment of it:
    a generic class such as "Car" must not inherit the label of the more
    specific key "Car alarm".
    """
    lower = yamnet_class.strip().lower()
    if not lower or lower in _SPEECH_LABELS_LOWER:
        return None

    best_key_len = 0
    best_label: str | None = None
    for key, label in LABEL_MAP.items():
        key_lower = key.lower()
        if key_lower == lower:
            return label
        # Strict ">" keeps the first-listed key when two keys tie on length.
        if len(key_lower) > best_key_len and _contains_phrase(lower, key_lower):
            best_key_len = len(key_lower)
            best_label = label

    if best_label is not None:
        return best_label
    return f"[{lower}]"
class FrameExtractor:
    """Extracts video frames in a reaction window around detected audio events.

    The reaction window starts *reaction_before_s* before the event onset (to
    capture anticipatory reactions) and extends *reaction_after_s* past the
    event end.  Frames are sampled at *sample_fps* frames-per-second inside
    that window.

    Parameters
    ----------
    reaction_before_s:
        How many seconds before event onset to start the window (default 0.1 s).
    reaction_after_s:
        How many seconds after event end to extend the window (default 1.5 s).
    sample_fps:
        Frame sampling rate inside the window (default 5 fps → one frame per
        200 ms).  Must be greater than zero.

    Raises
    ------
    ValueError
        If *sample_fps* is not positive.
    """

    # Property id passed to ``cap.set`` to seek by frame index.  OpenCV names
    # this value ``CAP_PROP_POS_FRAMES``; the literal is kept so that seeking
    # only needs the capture object itself.
    _CAP_PROP_POS_FRAMES: int = 1

    def __init__(
        self,
        reaction_before_s: float = 0.1,
        reaction_after_s: float = 1.5,
        sample_fps: float = 5.0,
    ) -> None:
        # A zero or negative rate would otherwise surface later as a
        # ZeroDivisionError or a sampling loop that never terminates.
        if sample_fps <= 0:
            raise ValueError(f"sample_fps must be > 0, got {sample_fps}")
        self.reaction_before_s = reaction_before_s
        self.reaction_after_s = reaction_after_s
        self.sample_fps = sample_fps

    def extract(self, video_path: str, event_start_s: float, event_end_s: float) -> FrameWindow:
        """Return frames from *video_path* in the reaction window for one event.

        Parameters
        ----------
        video_path:
            Path to any video format supported by OpenCV.
        event_start_s, event_end_s:
            Start and end timestamps (in seconds) of the audio event.

        Raises
        ------
        IOError
            If *video_path* cannot be opened.
        """
        import cv2  # deferred

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise IOError(f"Cannot open video file: {video_path}")

        try:
            return self._extract_frames(cap, event_start_s, event_end_s)
        finally:
            cap.release()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _extract_frames(self, cap, event_start_s: float, event_end_s: float) -> FrameWindow:
        """Seek *cap* to each sample time and collect the frames that decode.

        Timestamps whose read fails (for example, past the end of the video)
        are skipped, so the returned window may hold fewer frames than were
        requested.
        """
        win_start = max(0.0, event_start_s - self.reaction_before_s)
        win_end = event_end_s + self.reaction_after_s
        timestamps = self._sample_timestamps(win_start, win_end, 1.0 / self.sample_fps)

        # The source frame rate does not change while seeking; query it once.
        video_fps = self._video_fps(cap)

        frames: list[tuple[float, object]] = []
        for ts in timestamps:
            # round() instead of int(): products such as 0.58 * 100 land just
            # below the intended integer and truncation would pick the
            # previous frame.
            cap.set(self._CAP_PROP_POS_FRAMES, int(round(ts * video_fps)))
            ret, frame = cap.read()
            if ret and frame is not None:
                frames.append((ts, frame))

        return FrameWindow(event_start_s=event_start_s, event_end_s=event_end_s, frames=frames)

    @staticmethod
    def _sample_timestamps(win_start: float, win_end: float, step_s: float) -> list[float]:
        """Return evenly spaced timestamps from *win_start* through *win_end*.

        Each timestamp is computed as ``win_start + i * step_s`` rather than by
        repeated addition, so rounding error does not accumulate and the final
        sample is not lost when it lands on *win_end*.  Values are rounded to
        four decimals.

        Raises
        ------
        ValueError
            If *step_s* is not positive.
        """
        if step_s <= 0:
            raise ValueError(f"step_s must be > 0, got {step_s}")
        if win_end < win_start:
            return []
        # The small epsilon absorbs float error in the division (0.3 / 0.1
        # evaluates to just under 3).
        count = int((win_end - win_start) / step_s + 1e-9) + 1
        return [round(win_start + i * step_s, 4) for i in range(count)]

    @staticmethod
    def _video_fps(cap) -> float:
        """Return the capture's frame rate, or 25.0 if it reports none."""
        import cv2  # deferred

        fps = cap.get(cv2.CAP_PROP_FPS)
        return fps if fps and fps > 0 else 25.0
# Reaction window split: frames before this offset (seconds relative to event
# onset) form the "baseline"; frames after form the "reaction" window.
_BASELINE_END_S: float = 0.0      # frames before onset = baseline
_REACTION_START_S: float = 0.05   # reaction starts 50 ms after onset

# Weights for the three signal components (must sum to 1.0).
_W_MOTION: float = 0.45
_W_HEAD: float = 0.35
_W_MOUTH: float = 0.20

# Motion scale: a mean optical-flow magnitude of this value or more maps to a
# motion score of 1.0.
_MOTION_SCALE: float = 4.0

# Head-shift scale: a nose-tip displacement of this fraction of the frame
# width (or more) maps to a head-shift score of 1.0.
_HEAD_SHIFT_SCALE: float = 0.06

# Mouth-open ratio that represents a clearly open mouth.
_MOUTH_OPEN_SCALE: float = 0.35


@dataclasses.dataclass
class ReactionResult:
    """Reaction analysis result for one audio event."""

    event_start_s: float
    event_end_s: float
    reaction_score: float     # [0, 1]
    motion_score: float       # [0, 1] optical-flow component
    head_shift_score: float   # [0, 1] face-landmark component
    mouth_open_score: float   # [0, 1] mouth-open component
    faces_detected: int       # number of frames where a face was found


class ReactionDetector:
    """Detects speaker / scene reactions in a FrameWindow.

    Usage::

        from src.vision.extractor import FrameExtractor
        from src.vision.reaction import ReactionDetector

        extractor = FrameExtractor()
        detector = ReactionDetector()

        window = extractor.extract("video.mp4", event_start_s=5.2, event_end_s=5.7)
        result = detector.analyse(window)
        print(result.reaction_score)
    """

    def analyse(self, window: FrameWindow) -> ReactionResult:
        """Return a ReactionResult for *window*.

        If *window* has fewer than 2 frames, all scores are 0.0.  Otherwise
        the frames at least ``_REACTION_START_S`` after the event onset are
        analysed; if fewer than two such frames exist, every frame is used.
        """
        if len(window.frames) < 2:
            return ReactionResult(
                event_start_s=window.event_start_s,
                event_end_s=window.event_end_s,
                reaction_score=0.0,
                motion_score=0.0,
                head_shift_score=0.0,
                mouth_open_score=0.0,
                faces_detected=0,
            )

        reaction_frames = [
            (ts, frame)
            for ts, frame in window.frames
            if ts >= window.event_start_s + _REACTION_START_S
        ]
        if len(reaction_frames) < 2:
            reaction_frames = window.frames  # fallback: use all frames

        motion = self._optical_flow_score(reaction_frames)
        head, mouth, faces = self._mediapipe_scores(reaction_frames)

        combined = _W_MOTION * motion + _W_HEAD * head + _W_MOUTH * mouth

        return ReactionResult(
            event_start_s=window.event_start_s,
            event_end_s=window.event_end_s,
            reaction_score=round(min(combined, 1.0), 4),
            motion_score=round(motion, 4),
            head_shift_score=round(head, 4),
            mouth_open_score=round(mouth, 4),
            faces_detected=faces,
        )

    # ------------------------------------------------------------------
    # Signal 1: optical-flow motion
    # ------------------------------------------------------------------

    def _optical_flow_score(self, frames: list[tuple[float, object]]) -> float:
        """Return the peak mean Farneback flow magnitude, scaled to [0, 1]."""
        import cv2
        import numpy as np

        magnitudes: list[float] = []
        for (_, f1), (_, f2) in zip(frames, frames[1:]):
            g1 = cv2.cvtColor(f1, cv2.COLOR_BGR2GRAY)
            g2 = cv2.cvtColor(f2, cv2.COLOR_BGR2GRAY)
            flow = cv2.calcOpticalFlowFarneback(
                g1, g2, None, 0.5, 3, 15, 3, 5, 1.2, 0
            )
            mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
            magnitudes.append(float(np.mean(mag)))

        if not magnitudes:
            return 0.0
        return min(max(magnitudes) / _MOTION_SCALE, 1.0)

    # ------------------------------------------------------------------
    # Signal 2 & 3: face landmarks
    # ------------------------------------------------------------------

    def _mediapipe_scores(
        self, frames: list[tuple[float, object]]
    ) -> tuple[float, float, int]:
        """Return (head_shift_score, mouth_open_score, faces_detected).

        Only the first face reported for each frame is used.  If ``mediapipe``
        is not installed, the result is ``(0.0, 0.0, 0)`` so the optical-flow
        signal can still be used on its own.
        """
        try:
            import mediapipe as mp
        except ImportError:
            return 0.0, 0.0, 0

        import cv2

        face_mesh = mp.solutions.face_mesh.FaceMesh(
            static_image_mode=True,
            max_num_faces=2,
            refine_landmarks=True,
            min_detection_confidence=0.5,
        )

        # Nose-tip positions with both axes expressed in frame widths, so the
        # head-shift score does not depend on the video resolution.
        nose_tips: list[tuple[float, float]] = []
        mouth_ratios: list[float] = []
        faces_found = 0

        try:
            for _, frame in frames:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                result = face_mesh.process(rgb)
                if not result.multi_face_landmarks:
                    continue
                lm = result.multi_face_landmarks[0].landmark
                h, w = frame.shape[:2]
                faces_found += 1

                # Nose tip = landmark 1.  x is used as a fraction of the
                # width; y (a fraction of the height) is rescaled by h / w so
                # both components share the same unit.
                nose_tips.append((lm[1].x, lm[1].y * h / w))

                # Mouth: upper-lip centre = 13, lower-lip centre = 14
                # Inter-eye: left-eye outer = 33, right-eye outer = 263
                mouth_h = abs(lm[13].y - lm[14].y) * h
                eye_dist = abs(lm[33].x - lm[263].x) * w
                if eye_dist > 0:
                    mouth_ratios.append(mouth_h / eye_dist)
        finally:
            face_mesh.close()

        return (
            self._head_shift_score(nose_tips),
            self._mouth_open_score(mouth_ratios),
            faces_found,
        )

    @staticmethod
    def _head_shift_score(nose_tips: list[tuple[float, float]]) -> float:
        """Score the largest nose-tip jump between consecutive detections.

        *nose_tips* holds (x, y) positions measured in frame widths.  The
        result is the peak displacement divided by ``_HEAD_SHIFT_SCALE``,
        capped at 1.0; fewer than two positions give 0.0.
        """
        if len(nose_tips) < 2:
            return 0.0
        peak = max(
            math.hypot(x2 - x1, y2 - y1)
            for (x1, y1), (x2, y2) in zip(nose_tips, nose_tips[1:])
        )
        return min(peak / _HEAD_SHIFT_SCALE, 1.0)

    @staticmethod
    def _mouth_open_score(mouth_ratios: list[float]) -> float:
        """Score the largest mouth-height / inter-eye-distance ratio in [0, 1]."""
        if not mouth_ratios:
            return 0.0
        return min(max(mouth_ratios) / _MOUTH_OPEN_SCALE, 1.0)
def _mock_soundfile_module(waveform: np.ndarray, sr: int) -> ModuleType:
    """Build a stand-in ``soundfile`` module whose ``read()`` yields ``(waveform, sr)``."""
    fake_module = ModuleType("soundfile")
    fake_module.read = MagicMock(return_value=(waveform, sr))  # type: ignore[attr-defined]
    return fake_module


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _ev(label: str, start: float, end: float, conf: float = 0.8) -> AudioEvent:
    """Shorthand constructor for an AudioEvent."""
    return AudioEvent(label, start, end, conf)


def _make_detector() -> SoundEventDetector:
    """Return a detector configured with the default confidence threshold."""
    return SoundEventDetector(confidence_threshold=DEFAULT_CONFIDENCE_THRESHOLD)


def _patch_detector(detector: SoundEventDetector, class_names: list[str], scores_array: np.ndarray):
    """Give *detector* a pre-loaded fake model so TF-Hub is never contacted.

    Calling the fake model returns a 3-tuple whose first element exposes
    ``numpy()`` returning *scores_array*.
    """
    scores_tensor = MagicMock()
    scores_tensor.numpy.return_value = scores_array
    detector._model = MagicMock(return_value=(scores_tensor, None, None))
    detector._class_names = class_names


# ---------------------------------------------------------------------------
# _merge_adjacent
# ---------------------------------------------------------------------------

class TestMergeAdjacent:
    def test_empty_returns_empty(self):
        assert _merge_adjacent([]) == []

    def test_single_event_unchanged(self):
        only = _ev("[alarm]", 0.0, 0.48)
        assert _merge_adjacent([only]) == [only]

    def test_different_labels_not_merged(self):
        merged = _merge_adjacent([_ev("[alarm]", 0.0, 0.48), _ev("[explosion]", 0.5, 0.96)])
        assert len(merged) == 2

    def test_same_label_within_gap_merged(self):
        merged = _merge_adjacent([_ev("[alarm]", 0.0, 0.48), _ev("[alarm]", 0.6, 1.08)])
        assert len(merged) == 1
        (event,) = merged
        assert event.start_s == pytest.approx(0.0)
        assert event.end_s == pytest.approx(1.08)

    def test_same_label_beyond_gap_not_merged(self):
        far_apart = [_ev("[alarm]", 0.0, 0.48), _ev("[alarm]", 2.0, 2.48)]
        assert len(_merge_adjacent(far_apart)) == 2

    def test_merge_keeps_max_confidence(self):
        quiet = _ev("[alarm]", 0.0, 0.48, conf=0.6)
        loud = _ev("[alarm]", 0.5, 0.96, conf=0.9)
        merged = _merge_adjacent([quiet, loud])
        assert merged[0].confidence == pytest.approx(0.9)

    def test_chain_of_three_merged(self):
        spans = [(0.0, 0.48), (0.5, 0.96), (1.0, 1.44)]
        merged = _merge_adjacent([_ev("[siren]", start, end) for start, end in spans])
        assert len(merged) == 1
        assert merged[0].end_s == pytest.approx(1.44)


# ---------------------------------------------------------------------------
# SoundEventDetector.detect
# ---------------------------------------------------------------------------

def _run_detect(detector: SoundEventDetector, waveform: np.ndarray, sr: int) -> list[AudioEvent]:
    """Invoke ``detector.detect`` with ``soundfile`` replaced in ``sys.modules``."""
    replacement = {"soundfile": _mock_soundfile_module(waveform, sr)}
    with patch.dict(sys.modules, replacement):
        return detector.detect("fake.wav")


class TestSoundEventDetectorDetect:
    def test_speech_event_suppressed(self):
        detector = _make_detector()
        # Frame 0 is dominated by "Speech" (0.9) and must be dropped;
        # frame 1 is dominated by "Gunshot, gunfire" (0.8) and must be kept.
        frame_scores = np.array([[0.9, 0.1], [0.1, 0.8]])
        _patch_detector(
            detector,
            class_names=["Speech", "Gunshot, gunfire"],
            scores_array=frame_scores,
        )
        events = _run_detect(detector, np.zeros(16000, dtype="float32"), 16000)
        assert len(events) == 1
        assert events[0].label == "[gunshot]"

    def test_low_confidence_filtered(self):
        detector = _make_detector()
        # 0.2 is below the 0.35 default threshold, so nothing is reported.
        _patch_detector(detector, class_names=["Explosion"], scores_array=np.array([[0.2]]))
        events = _run_detect(detector, np.zeros(16000, dtype="float32"), 16000)
        assert events == []

    def test_wrong_sample_rate_raises(self):
        detector = _make_detector()
        detector._model = MagicMock()
        detector._class_names = []
        audio = np.zeros(22050, dtype="float32")
        with pytest.raises(ValueError, match="16 kHz"):
            _run_detect(detector, audio, 22050)

    def test_india_label_preserved(self):
        detector = _make_detector()
        _patch_detector(detector, class_names=["Fireworks"], scores_array=np.array([[0.85]]))
        events = _run_detect(detector, np.zeros(16000, dtype="float32"), 16000)
        assert events[0].label == "[firecrackers]"

    def test_adjacent_events_merged(self):
        detector = _make_detector()
        # Two back-to-back frames carrying the same label collapse into one event.
        _patch_detector(detector, class_names=["Alarm"], scores_array=np.array([[0.8], [0.75]]))
        events = _run_detect(detector, np.zeros(16000, dtype="float32"), 16000)
        assert len(events) == 1
        assert events[0].confidence == pytest.approx(0.8)

    def test_event_timestamps_are_correct(self):
        detector = _make_detector()
        # Only frame index 2 scores, so start_s = 2 * 0.48 = 0.96.
        frame_scores = np.zeros((3, 1), dtype="float32")
        frame_scores[2, 0] = 0.9
        _patch_detector(detector, class_names=["Explosion"], scores_array=frame_scores)
        events = _run_detect(detector, np.zeros(16000, dtype="float32"), 16000)
        assert len(events) == 1
        assert events[0].start_s == pytest.approx(0.96)

    def test_multiple_different_events(self):
        detector = _make_detector()
        # Gunshot in frame 0 and applause in frame 3: different labels, so
        # they stay separate events.
        frame_scores = np.zeros((4, 2), dtype="float32")
        frame_scores[0, 0] = 0.9   # Gunshot
        frame_scores[3, 1] = 0.75  # Applause
        _patch_detector(
            detector,
            class_names=["Gunshot, gunfire", "Applause"],
            scores_array=frame_scores,
        )
        events = _run_detect(detector, np.zeros(16000, dtype="float32"), 16000)
        assert len(events) == 2
        assert {event.label for event in events} == {"[gunshot]", "[applause]"}
class TestSpeechLabels:
    def test_common_speech_suppressed(self):
        # Every one of these display names must be in the suppression set.
        expected = ("Speech", "Male speech, man speaking", "Whispering", "Cough")
        assert all(name in SPEECH_LABELS for name in expected)

    def test_non_speech_not_in_speech_labels(self):
        assert "Gunshot, gunfire" not in SPEECH_LABELS
        assert "Explosion" not in SPEECH_LABELS


class TestToCCLabel:
    def test_speech_returns_none(self):
        assert to_cc_label("Speech") is None
        assert to_cc_label("Baby laughter") is None

    def test_exact_known_mapping(self):
        expected = {
            "Gunshot, gunfire": "[gunshot]",
            "Explosion": "[explosion]",
            "Fire alarm": "[alarm]",
            "Applause": "[applause]",
        }
        for yamnet_name, cc_label in expected.items():
            assert to_cc_label(yamnet_name) == cc_label

    def test_india_specific_labels(self):
        expected = {
            "Fireworks": "[firecrackers]",
            "Tabla": "[tabla]",
            "Dhol": "[dhol]",
            "Temple bells": "[temple bells]",
        }
        for yamnet_name, cc_label in expected.items():
            assert to_cc_label(yamnet_name) == cc_label

    def test_case_insensitive_match(self):
        assert to_cc_label("gunshot, gunfire") == "[gunshot]"
        assert to_cc_label("EXPLOSION") == "[explosion]"

    def test_substring_match(self):
        # "Bark" should match "Dog bark"-style entries
        assert to_cc_label("Bark") == "[dog barking]"

    def test_unmapped_non_speech_gets_fallback(self):
        assert to_cc_label("Helicopter") == "[helicopter]"

    def test_fallback_is_lowercased(self):
        fallback = to_cc_label("Chainsaw")
        assert fallback == fallback.lower()


class TestHighImpactAndAmbient:
    def test_gunshot_is_high_impact(self):
        for cc_label in ("[gunshot]", "[explosion]", "[alarm]"):
            assert cc_label in HIGH_IMPACT

    def test_music_is_ambient(self):
        for cc_label in ("[music]", "[rain]", "[traffic]"):
            assert cc_label in AMBIENT

    def test_no_overlap(self):
        assert HIGH_IMPACT.isdisjoint(AMBIENT)
def _make_cv2_mock(
    video_fps: float = 25.0,
    frame_shape: tuple[int, int, int] = (480, 640, 3),
    open_ok: bool = True,
) -> tuple[ModuleType, MagicMock]:
    """Return ``(fake_cv2_module, fake_capture)`` for FrameExtractor tests.

    ``fake_cv2_module.VideoCapture(...)`` always returns ``fake_capture``,
    whose ``isOpened()`` reports *open_ok*, whose ``get()`` reports
    *video_fps* and whose ``read()`` always succeeds with a black frame of
    *frame_shape*.
    """
    cv2 = ModuleType("cv2")
    cv2.CAP_PROP_FPS = 5  # arbitrary int constant

    cap = MagicMock()
    cap.isOpened.return_value = open_ok
    cap.get.return_value = video_fps
    # read() always returns a valid frame
    cap.read.return_value = (True, np.zeros(frame_shape, dtype=np.uint8))
    cap.release = MagicMock()
    cap.set = MagicMock()

    cv2.VideoCapture = MagicMock(return_value=cap)

    return cv2, cap


# ---------------------------------------------------------------------------
# FrameExtractor tests
# ---------------------------------------------------------------------------

class TestFrameExtractorInit:
    def test_defaults(self):
        fe = FrameExtractor()
        assert fe.reaction_before_s == pytest.approx(0.1)
        assert fe.reaction_after_s == pytest.approx(1.5)
        assert fe.sample_fps == pytest.approx(5.0)

    def test_custom_params(self):
        fe = FrameExtractor(reaction_before_s=0.2, reaction_after_s=2.0, sample_fps=10.0)
        assert fe.reaction_before_s == pytest.approx(0.2)
        assert fe.reaction_after_s == pytest.approx(2.0)
        assert fe.sample_fps == pytest.approx(10.0)


class TestFrameExtractorExtract:
    def test_raises_on_unopenable_video(self):
        cv2_mock, _ = _make_cv2_mock(open_ok=False)
        with patch.dict(sys.modules, {"cv2": cv2_mock}):
            fe = FrameExtractor()
            with pytest.raises(IOError, match="Cannot open video file"):
                fe.extract("nonexistent.mp4", 1.0, 2.0)

    def test_returns_frame_window(self):
        cv2_mock, _ = _make_cv2_mock()
        with patch.dict(sys.modules, {"cv2": cv2_mock}):
            fe = FrameExtractor(sample_fps=2.0)
            window = fe.extract("video.mp4", event_start_s=1.0, event_end_s=1.5)

        assert isinstance(window, FrameWindow)
        assert window.event_start_s == pytest.approx(1.0)
        assert window.event_end_s == pytest.approx(1.5)

    def test_window_starts_before_event(self):
        """First frame timestamp should be reaction_before_s before event onset."""
        cv2_mock, _ = _make_cv2_mock()
        with patch.dict(sys.modules, {"cv2": cv2_mock}):
            fe = FrameExtractor(reaction_before_s=0.1, sample_fps=10.0)
            window = fe.extract("video.mp4", event_start_s=2.0, event_end_s=2.5)

        timestamps = [ts for ts, _ in window.frames]
        assert timestamps[0] == pytest.approx(1.9, abs=1e-3)

    def test_window_extends_after_event(self):
        """Last timestamp should fall within one sampling step of the window end."""
        cv2_mock, _ = _make_cv2_mock()
        with patch.dict(sys.modules, {"cv2": cv2_mock}):
            fe = FrameExtractor(reaction_after_s=1.0, sample_fps=5.0)
            window = fe.extract("video.mp4", event_start_s=0.0, event_end_s=0.5)

        timestamps = [ts for ts, _ in window.frames]
        # Window end is 0.5 + 1.0 = 1.5 s; with a 0.2 s step the last sample
        # may sit slightly before it, so allow a margin.
        assert timestamps[-1] >= 1.25

    def test_no_negative_timestamps(self):
        """Window start is clamped to 0 when event is near the beginning."""
        cv2_mock, _ = _make_cv2_mock()
        with patch.dict(sys.modules, {"cv2": cv2_mock}):
            fe = FrameExtractor(reaction_before_s=0.5, sample_fps=5.0)
            window = fe.extract("video.mp4", event_start_s=0.1, event_end_s=0.5)

        timestamps = [ts for ts, _ in window.frames]
        assert all(ts >= 0.0 for ts in timestamps)

    def test_cap_released_on_success(self):
        cv2_mock, cap = _make_cv2_mock()
        with patch.dict(sys.modules, {"cv2": cv2_mock}):
            FrameExtractor().extract("video.mp4", 0.0, 0.5)
        cap.release.assert_called_once()

    def test_cap_released_on_error(self):
        """An unopenable capture raises IOError before any frame is read."""
        cv2_mock, cap = _make_cv2_mock(open_ok=False)
        with patch.dict(sys.modules, {"cv2": cv2_mock}):
            with pytest.raises(IOError):
                FrameExtractor().extract("bad.mp4", 0.0, 0.5)
        # Whether release() is invoked on a capture that never opened is left
        # unspecified; what matters is that nothing tried to seek or decode.
        cap.set.assert_not_called()
        cap.read.assert_not_called()

    def test_fallback_fps_when_cap_returns_zero(self):
        """If cap.get(FPS) returns 0, default 25 fps is used without error."""
        cv2_mock, _ = _make_cv2_mock(video_fps=0.0)
        with patch.dict(sys.modules, {"cv2": cv2_mock}):
            fe = FrameExtractor(sample_fps=2.0)
            window = fe.extract("video.mp4", 1.0, 1.5)
        assert len(window.frames) > 0
+ angle = math.pi / 4 # 45° — gives equal x and y components + vx = flow_magnitude * math.cos(angle) + vy = flow_magnitude * math.sin(angle) + + def fake_flow(g1, g2, *args, **kwargs): + h, w = g1.shape[:2] if g1.ndim >= 2 else (480, 640) + arr = np.zeros((h, w, 2), dtype=np.float32) + arr[:, :, 0] = vx + arr[:, :, 1] = vy + return arr + + cv2.calcOpticalFlowFarneback = MagicMock(side_effect=fake_flow) + + # cartToPolar: return actual magnitude from x/y arrays + def fake_cart_to_polar(x, y, *args, **kwargs): + mag = np.sqrt(x ** 2 + y ** 2) + ang = np.arctan2(y, x) + return mag, ang + + cv2.cartToPolar = MagicMock(side_effect=fake_cart_to_polar) + + return cv2 + + +def _make_mediapipe_mock( + detect_face: bool = True, + nose_positions: list[tuple[float, float]] | None = None, + mouth_open_ratio: float = 0.0, +) -> ModuleType: + """Return a fake mediapipe module. + + Parameters + ---------- + detect_face: + Whether face_mesh.process() returns any landmarks. + nose_positions: + List of (x_norm, y_norm) for landmark[1] per frame. If shorter than + the number of frames, it cycles. + mouth_open_ratio: + The y-distance between landmarks 13 and 14 (normalised by frame height). 
+ """ + mp = ModuleType("mediapipe") + solutions = SimpleNamespace() + mp.solutions = solutions + + face_mesh_mod = SimpleNamespace() + solutions.face_mesh = face_mesh_mod + + def make_landmark(x: float, y: float) -> SimpleNamespace: + return SimpleNamespace(x=x, y=y) + + call_count = [0] + + def build_result(*args, **kwargs): + idx = call_count[0] + call_count[0] += 1 + + if not detect_face: + return SimpleNamespace(multi_face_landmarks=None) + + positions = nose_positions or [(0.5, 0.5)] + nose_x, nose_y = positions[idx % len(positions)] + + # eye landmarks 33 and 263 are at x=0.3 and x=0.7 (eye_dist=0.4 * W) + landmarks = { + 1: make_landmark(nose_x, nose_y), # nose tip + 13: make_landmark(0.5, 0.5), # upper lip + 14: make_landmark(0.5, 0.5 + mouth_open_ratio), # lower lip + 33: make_landmark(0.3, 0.4), # left eye outer + 263: make_landmark(0.7, 0.4), # right eye outer + } + + class FakeLandmarkList: + def __getitem__(self, i): + return landmarks.get(i, make_landmark(0.5, 0.5)) + + face_lm = SimpleNamespace(landmark=FakeLandmarkList()) + return SimpleNamespace(multi_face_landmarks=[face_lm]) + + face_mesh_instance = MagicMock() + face_mesh_instance.process.side_effect = build_result + face_mesh_instance.close = MagicMock() + + face_mesh_mod.FaceMesh = MagicMock(return_value=face_mesh_instance) + return mp + + +# --------------------------------------------------------------------------- +# ReactionDetector tests +# --------------------------------------------------------------------------- + +class TestReactionDetectorEdgeCases: + def test_empty_window_returns_zero_scores(self): + detector = ReactionDetector() + window = FrameWindow(event_start_s=1.0, event_end_s=1.5, frames=[]) + result = detector.analyse(window) + assert result.reaction_score == pytest.approx(0.0) + assert result.motion_score == pytest.approx(0.0) + assert result.head_shift_score == pytest.approx(0.0) + + def test_single_frame_returns_zero_scores(self): + detector = ReactionDetector() + window 
= _window([1.0]) + result = detector.analyse(window) + assert result.reaction_score == pytest.approx(0.0) + + def test_result_preserves_event_timestamps(self): + detector = ReactionDetector() + cv2_mock = _make_cv2_mock(flow_magnitude=0.0) + mp_mock = _make_mediapipe_mock(detect_face=False) + with patch.dict(sys.modules, {"cv2": cv2_mock, "mediapipe": mp_mock}): + result = detector.analyse(_window([1.0, 1.2, 1.4], event_start=2.0, event_end=3.0)) + assert result.event_start_s == pytest.approx(2.0) + assert result.event_end_s == pytest.approx(3.0) + + +class TestMotionScore: + def test_zero_motion_gives_zero_score(self): + cv2_mock = _make_cv2_mock(flow_magnitude=0.0) + mp_mock = _make_mediapipe_mock(detect_face=False) + detector = ReactionDetector() + with patch.dict(sys.modules, {"cv2": cv2_mock, "mediapipe": mp_mock}): + result = detector.analyse(_window([1.05, 1.25, 1.45])) + assert result.motion_score == pytest.approx(0.0) + + def test_full_motion_clamps_to_one(self): + cv2_mock = _make_cv2_mock(flow_magnitude=_MOTION_SCALE * 10) + mp_mock = _make_mediapipe_mock(detect_face=False) + detector = ReactionDetector() + with patch.dict(sys.modules, {"cv2": cv2_mock, "mediapipe": mp_mock}): + result = detector.analyse(_window([1.05, 1.25, 1.45])) + assert result.motion_score == pytest.approx(1.0) + + def test_moderate_motion_score_in_range(self): + cv2_mock = _make_cv2_mock(flow_magnitude=_MOTION_SCALE * 0.5) + mp_mock = _make_mediapipe_mock(detect_face=False) + detector = ReactionDetector() + with patch.dict(sys.modules, {"cv2": cv2_mock, "mediapipe": mp_mock}): + result = detector.analyse(_window([1.05, 1.25, 1.45])) + assert 0.0 < result.motion_score <= 1.0 + + +class TestHeadShiftScore: + def test_no_face_gives_zero_head_score(self): + cv2_mock = _make_cv2_mock(flow_magnitude=0.0) + mp_mock = _make_mediapipe_mock(detect_face=False) + detector = ReactionDetector() + with patch.dict(sys.modules, {"cv2": cv2_mock, "mediapipe": mp_mock}): + result = 
detector.analyse(_window([1.05, 1.25, 1.45])) + assert result.head_shift_score == pytest.approx(0.0) + assert result.faces_detected == 0 + + def test_stable_face_gives_low_head_score(self): + cv2_mock = _make_cv2_mock(flow_magnitude=0.0) + # Same nose position every frame → no shift + mp_mock = _make_mediapipe_mock(detect_face=True, nose_positions=[(0.5, 0.5)]) + detector = ReactionDetector() + with patch.dict(sys.modules, {"cv2": cv2_mock, "mediapipe": mp_mock}): + result = detector.analyse(_window([1.05, 1.25, 1.45])) + assert result.head_shift_score == pytest.approx(0.0) + + def test_large_head_turn_gives_high_score(self): + cv2_mock = _make_cv2_mock(flow_magnitude=0.0) + # Nose jumps from 0.2 to 0.8 → shift = 0.6 * 640 = 384 px (very large) + mp_mock = _make_mediapipe_mock(detect_face=True, nose_positions=[(0.2, 0.5), (0.8, 0.5)]) + detector = ReactionDetector() + with patch.dict(sys.modules, {"cv2": cv2_mock, "mediapipe": mp_mock}): + result = detector.analyse(_window([1.05, 1.25, 1.45])) + assert result.head_shift_score == pytest.approx(1.0) + + def test_faces_detected_count(self): + cv2_mock = _make_cv2_mock(flow_magnitude=0.0) + mp_mock = _make_mediapipe_mock(detect_face=True) + detector = ReactionDetector() + with patch.dict(sys.modules, {"cv2": cv2_mock, "mediapipe": mp_mock}): + result = detector.analyse(_window([1.05, 1.25, 1.45])) + assert result.faces_detected == 3 + + +class TestMouthOpenScore: + def test_closed_mouth_gives_zero_score(self): + cv2_mock = _make_cv2_mock(flow_magnitude=0.0) + mp_mock = _make_mediapipe_mock(detect_face=True, mouth_open_ratio=0.0) + detector = ReactionDetector() + with patch.dict(sys.modules, {"cv2": cv2_mock, "mediapipe": mp_mock}): + result = detector.analyse(_window([1.05, 1.25, 1.45])) + assert result.mouth_open_score == pytest.approx(0.0) + + def test_wide_open_mouth_clamps_to_one(self): + cv2_mock = _make_cv2_mock(flow_magnitude=0.0) + # mouth_open_ratio of 1.0 → height = frame_height; way above scale + mp_mock = 
_make_mediapipe_mock(detect_face=True, mouth_open_ratio=1.0) + detector = ReactionDetector() + with patch.dict(sys.modules, {"cv2": cv2_mock, "mediapipe": mp_mock}): + result = detector.analyse(_window([1.05, 1.25, 1.45])) + assert result.mouth_open_score == pytest.approx(1.0) + + +class TestCombinedScore: + def test_all_zero_signals_give_zero_reaction(self): + cv2_mock = _make_cv2_mock(flow_magnitude=0.0) + mp_mock = _make_mediapipe_mock(detect_face=False) + detector = ReactionDetector() + with patch.dict(sys.modules, {"cv2": cv2_mock, "mediapipe": mp_mock}): + result = detector.analyse(_window([1.05, 1.25, 1.45])) + assert result.reaction_score == pytest.approx(0.0) + + def test_combined_score_bounded_to_one(self): + cv2_mock = _make_cv2_mock(flow_magnitude=_MOTION_SCALE * 100) + mp_mock = _make_mediapipe_mock( + detect_face=True, + nose_positions=[(0.0, 0.5), (1.0, 0.5)], + mouth_open_ratio=1.0, + ) + detector = ReactionDetector() + with patch.dict(sys.modules, {"cv2": cv2_mock, "mediapipe": mp_mock}): + result = detector.analyse(_window([1.05, 1.25, 1.45])) + assert result.reaction_score <= 1.0 + + def test_weights_sum_to_one(self): + assert _W_MOTION + _W_HEAD + _W_MOUTH == pytest.approx(1.0) + + def test_mediapipe_unavailable_falls_back_gracefully(self): + """If mediapipe is not installed, reaction_score uses only motion.""" + cv2_mock = _make_cv2_mock(flow_magnitude=_MOTION_SCALE * 0.5) + detector = ReactionDetector() + # Remove mediapipe from sys.modules to simulate it being absent + with patch.dict(sys.modules, {"cv2": cv2_mock, "mediapipe": None}): + result = detector.analyse(_window([1.05, 1.25, 1.45])) + # head and mouth scores should both be 0; motion contributes + assert result.head_shift_score == pytest.approx(0.0) + assert result.mouth_open_score == pytest.approx(0.0) + assert result.reaction_score > 0.0