Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Runtime dependencies
tensorflow>=2.13.0
tensorflow-hub>=0.14.0
soundfile>=0.12.1
numpy>=1.24.0
opencv-python>=4.8.0
mediapipe>=0.10.0

# Dev / test
pytest>=7.4.0
Empty file added src/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions src/audio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from src.audio.detector import AudioEvent, SoundEventDetector

__all__ = ["AudioEvent", "SoundEventDetector"]
138 changes: 138 additions & 0 deletions src/audio/detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Sound Event Detection using YAMNet (Goal 1).

Heavy ML imports (tensorflow_hub, soundfile) are deferred to method bodies
so the module is importable without any ML stack installed — useful for
running unit tests and importing constants in lightweight contexts.
"""

from __future__ import annotations

import dataclasses

from src.audio.labels import SPEECH_LABELS, to_cc_label

# YAMNet produces one score vector per ~0.48 s of audio.
_YAMNET_HOP_S: float = 0.48

# Adjacent same-label events closer than this are merged into one.
_MERGE_GAP_S: float = 0.5

# Default minimum confidence for an event to be kept.
DEFAULT_CONFIDENCE_THRESHOLD: float = 0.35


@dataclasses.dataclass
class AudioEvent:
"""A single detected non-speech audio event."""

label: str
start_s: float
end_s: float
confidence: float


def _merge_adjacent(events: list[AudioEvent]) -> list[AudioEvent]:
"""Merge same-label events whose gap is within *_MERGE_GAP_S*."""
if not events:
return []
merged = [events[0]]
for ev in events[1:]:
prev = merged[-1]
if ev.label == prev.label and (ev.start_s - prev.end_s) <= _MERGE_GAP_S:
merged[-1] = dataclasses.replace(
prev,
end_s=ev.end_s,
confidence=max(prev.confidence, ev.confidence),
)
else:
merged.append(ev)
return merged


class SoundEventDetector:
"""Detects non-speech audio events in a 16 kHz mono WAV file via YAMNet.

Usage::

detector = SoundEventDetector()
events = detector.detect("extracted_audio.wav")
for ev in events:
print(ev.label, ev.start_s, ev.end_s, ev.confidence)
"""

def __init__(self, confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD) -> None:
self.confidence_threshold = confidence_threshold
self._model = None
self._class_names: list[str] | None = None

# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------

def detect(self, audio_path: str) -> list[AudioEvent]:
"""Run YAMNet on *audio_path* and return filtered, merged events.

*audio_path* must be a 16 kHz mono WAV file. Use ``ffmpeg`` or
``librosa`` to resample if the source has a different sample rate.

Raises
------
ValueError
If *audio_path* is not 16 kHz.
"""
import numpy as np
import soundfile as sf

self._load_model()

waveform, sr = sf.read(audio_path, dtype="float32", always_2d=False)
if sr != 16000:
raise ValueError(
f"Expected 16 kHz audio, got {sr} Hz. "
"Resample with ffmpeg: ffmpeg -i input.wav -ar 16000 out.wav"
)

scores, _, _ = self._model(waveform)
scores = scores.numpy() # shape: (n_frames, 521)

events: list[AudioEvent] = []
for frame_idx, frame_scores in enumerate(scores):
top_idx = int(frame_scores.argmax())
confidence = float(frame_scores[top_idx])
if confidence < self.confidence_threshold:
continue
class_name = self._class_names[top_idx]
if class_name in SPEECH_LABELS:
continue
label = to_cc_label(class_name)
if label is None:
continue
start_s = frame_idx * _YAMNET_HOP_S
events.append(
AudioEvent(
label=label,
start_s=round(start_s, 3),
end_s=round(start_s + _YAMNET_HOP_S, 3),
confidence=round(confidence, 4),
)
)

return _merge_adjacent(events)

# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------

def _load_model(self) -> None:
"""Load YAMNet from TF-Hub (idempotent)."""
if self._model is not None:
return
import csv
import io

import tensorflow_hub as hub

self._model = hub.load("https://tfhub.dev/google/yamnet/1")
class_map_bytes: bytes = self._model.class_map_path().numpy()
reader = csv.DictReader(io.StringIO(class_map_bytes.decode()))
self._class_names = [row["display_name"] for row in reader]
110 changes: 110 additions & 0 deletions src/audio/labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""AudioSet label utilities for the CC suggestion pipeline."""

from __future__ import annotations

# YAMNet class names that represent speech — always suppressed.
SPEECH_LABELS: frozenset[str] = frozenset({
"Speech",
"Male speech, man speaking",
"Female speech, woman speaking",
"Child speech, kid speaking",
"Conversation",
"Narration, monologue",
"Babbling",
"Speech synthesizer",
"Shout",
"Bellow",
"Whoop",
"Yell",
"Children shouting",
"Screaming",
"Whispering",
"Laughter",
"Baby laughter",
"Giggling",
"Snicker",
"Breathing",
"Wheeze",
"Snoring",
"Cough",
"Sneeze",
"Gasp",
"Sigh",
})

# Maps YAMNet display names to human-readable CC labels.
# Includes India-specific sounds absent from generic AudioSet mappings.
LABEL_MAP: dict[str, str] = {
"Gunshot, gunfire": "[gunshot]",
"Machine gun": "[gunshot]",
"Explosion": "[explosion]",
"Blowing up": "[explosion]",
"Fire alarm": "[alarm]",
"Alarm": "[alarm]",
"Smoke detector, smoke alarm": "[alarm]",
"Car alarm": "[alarm]",
"Siren": "[siren]",
"Civil defense siren": "[siren]",
"Ambulance (siren)": "[siren]",
"Police car (siren)": "[siren]",
"Glass": "[glass breaking]",
"Breaking": "[glass breaking]",
"Applause": "[applause]",
"Crowd": "[crowd noise]",
"Cheering": "[cheering]",
"Baby cry, infant cry": "[baby crying]",
"Crying, sobbing": "[crying]",
"Dog": "[dog barking]",
"Bark": "[dog barking]",
"Cat": "[cat meowing]",
"Meow": "[cat meowing]",
# India-specific: Diwali crackers map from AudioSet "Fireworks"
"Fireworks": "[firecrackers]",
# India-specific classical/folk percussion
"Tabla": "[tabla]",
"Dhol": "[dhol]",
"Temple bells": "[temple bells]",
"Knock": "[knocking]",
"Telephone": "[phone ringing]",
"Bell": "[bell]",
"Thunder": "[thunder]",
"Rain": "[rain]",
"Wind": "[wind]",
"Traffic noise, roadway noise": "[traffic]",
"Honk": "[honking]",
"Music": "[music]",
"Musical instrument": "[music]",
}

# High-impact events: audio confidence alone is usually sufficient.
HIGH_IMPACT: frozenset[str] = frozenset({
"[gunshot]",
"[explosion]",
"[alarm]",
"[siren]",
"[glass breaking]",
})

# Ambient events: require strong visual confirmation to avoid overcaptioning.
AMBIENT: frozenset[str] = frozenset({
"[music]",
"[rain]",
"[wind]",
"[traffic]",
})


def to_cc_label(yamnet_class: str) -> str | None:
"""Return a CC label string for *yamnet_class*, or ``None`` for speech.

Falls back to ``[<lowercased class name>]`` for unmapped non-speech events.
Matching is case-insensitive and uses substring search so partial YAMNet
class names (e.g. "Bark" matching "Dog bark") resolve correctly.
"""
if yamnet_class in SPEECH_LABELS:
return None
lower = yamnet_class.lower()
for key, label in LABEL_MAP.items():
if key.lower() in lower or lower in key.lower():
return label
return f"[{lower}]"
Empty file added src/vision/__init__.py
Empty file.
107 changes: 107 additions & 0 deletions src/vision/extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Frame extraction utilities for the visual reaction module (Goal 2).

Heavy imports (cv2) are deferred to method bodies so the module is
importable without OpenCV installed — keeps tests fast.
"""

from __future__ import annotations

import dataclasses


@dataclasses.dataclass
class FrameWindow:
"""A window of frames extracted around a single audio event."""

event_start_s: float
event_end_s: float
# Each entry is (timestamp_s, BGR ndarray).
frames: list[tuple[float, object]]


class FrameExtractor:
"""Extracts video frames in a reaction window around detected audio events.

The reaction window starts 100 ms before the event onset (to capture
anticipatory reactions) and extends 1 500 ms past the event end. Frames
are sampled at *fps* frames-per-second inside that window.

Parameters
----------
reaction_before_s:
How many seconds before event onset to start the window (default 0.1 s).
reaction_after_s:
How many seconds after event end to extend the window (default 1.5 s).
sample_fps:
Frame sampling rate inside the window (default 5 fps → one frame per
200 ms, which is enough to detect head-turns and startled gestures).
"""

def __init__(
self,
reaction_before_s: float = 0.1,
reaction_after_s: float = 1.5,
sample_fps: float = 5.0,
) -> None:
self.reaction_before_s = reaction_before_s
self.reaction_after_s = reaction_after_s
self.sample_fps = sample_fps

def extract(self, video_path: str, event_start_s: float, event_end_s: float) -> FrameWindow:
"""Return frames from *video_path* in the reaction window for one event.

Parameters
----------
video_path:
Path to any video format supported by OpenCV.
event_start_s, event_end_s:
Start and end timestamps (in seconds) of the audio event.

Raises
------
IOError
If *video_path* cannot be opened.
"""
import cv2 # deferred

cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise IOError(f"Cannot open video file: {video_path}")

try:
return self._extract_frames(cap, event_start_s, event_end_s)
finally:
cap.release()

# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------

def _extract_frames(self, cap, event_start_s: float, event_end_s: float) -> FrameWindow:
import numpy as np

win_start = max(0.0, event_start_s - self.reaction_before_s)
win_end = event_end_s + self.reaction_after_s
step_s = 1.0 / self.sample_fps

timestamps = []
t = win_start
while t <= win_end:
timestamps.append(round(t, 4))
t += step_s

frames: list[tuple[float, object]] = []
for ts in timestamps:
cap.set(1, int(ts * self._video_fps(cap)))
ret, frame = cap.read()
if ret and frame is not None:
frames.append((ts, frame))

return FrameWindow(event_start_s=event_start_s, event_end_s=event_end_s, frames=frames)

@staticmethod
def _video_fps(cap) -> float:
import cv2 # deferred

fps = cap.get(cv2.CAP_PROP_FPS)
return fps if fps and fps > 0 else 25.0
Loading