Files
mediaproc/tests/detect/test_frame_extractor.py
2026-03-23 15:18:23 -03:00

78 lines
2.2 KiB
Python

"""Tests for FrameExtractor stage."""
import glob
from pathlib import Path
import pytest
from detect.profiles.base import FrameExtractionConfig
from detect.stages.frame_extractor import extract_frames
# Directory holding the pre-split sample chunks these tests read. The path is
# relative, so it resolves against the current working directory at test time.
SAMPLE_DIR: Path = Path("media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e")
def _get_sample_video() -> str:
    """Locate the sample clip shared by the extraction tests.

    Returns:
        The path, as a string, of the first ``chunk_*.mp4`` file in
        ``SAMPLE_DIR`` in sorted path order.

    The calling test is skipped via ``pytest.skip`` when no matching chunk
    is found.
    """
    # min() over the glob picks the same element as sorted(...)[0] without
    # materialising the full sorted list.
    first_chunk = min(SAMPLE_DIR.glob("chunk_*.mp4"), default=None)
    if first_chunk is None:
        pytest.skip("No sample video found in media/out/chunks/")
    return str(first_chunk)
def test_extract_frames_basic():
    """Extract up to ten frames at 1 fps and sanity-check each one.

    At least one frame must come back, never more than the cap of 10. Every
    frame must carry a three-dimensional image whose third axis has size 3,
    a non-negative sequence number, and a non-negative timestamp.
    """
    clip = _get_sample_video()
    config = FrameExtractionConfig(fps=1, max_frames=10)
    extracted = extract_frames(clip, config)

    count = len(extracted)
    assert count > 0
    assert count <= 10

    for frame in extracted:
        image = frame.image
        assert image.ndim == 3  # H x W x C
        assert image.shape[2] == 3  # three colour channels (RGB)
        assert frame.sequence >= 0
        assert frame.timestamp >= 0.0
def test_extract_frames_respects_fps():
    """Sampling at 2 fps must yield more frames than 1 fps on the same clip.

    A bare ``>=`` comparison also passes when ``fps`` is ignored entirely,
    because identical counts satisfy it. The strict comparison is therefore
    applied whenever it is meaningful:

    * the 1 fps run produced at least two frames, so the clip is long enough
      that doubling the sampling rate should add at least one sample;
    * the 2 fps run was not truncated by ``max_frames``, which would cap the
      count independently of ``fps``.
    """
    video = _get_sample_video()
    max_frames = 100
    config_1fps = FrameExtractionConfig(fps=1, max_frames=max_frames)
    config_2fps = FrameExtractionConfig(fps=2, max_frames=max_frames)

    frames_1 = extract_frames(video, config_1fps)
    frames_2 = extract_frames(video, config_2fps)

    # A higher sampling rate must never yield fewer frames.
    assert len(frames_2) >= len(frames_1)

    # NOTE(review): assumes extract_frames samples roughly uniformly in time
    # at config.fps; confirm against detect.stages.frame_extractor.
    if len(frames_1) >= 2 and len(frames_2) < max_frames:
        assert len(frames_2) > len(frames_1), (
            f"fps appears to be ignored: {len(frames_1)} frames at 1 fps, "
            f"{len(frames_2)} frames at 2 fps"
        )
def test_extract_frames_respects_max():
    """``max_frames`` caps the output even at a high sampling rate.

    The check also requires at least one frame, because ``len(frames) <= 3``
    on its own passes vacuously when extraction returns nothing. The sibling
    tests use the same ``> 0`` check on the same clip.
    """
    video = _get_sample_video()
    # fps=10 is high enough that, on any non-trivial clip, the cap rather
    # than the clip length should bound the frame count.
    config = FrameExtractionConfig(fps=10, max_frames=3)
    frames = extract_frames(video, config)
    assert 0 < len(frames) <= 3
def test_extract_frames_with_events(monkeypatch):
    """Passing a job_id makes extraction emit SSE events for that job.

    ``detect.emit.push_detect_event`` is swapped for a recorder, so the test
    can inspect every (job_id, event_type, data) triple that would have been
    pushed.
    """
    recorded = []

    def mock_push(job_id, event_type, data):
        # Parameter names are kept as-is so that keyword-style calls still bind.
        recorded.append((job_id, event_type, data))

    # NOTE(review): patching the attribute on detect.emit only intercepts
    # calls that look the function up through that module at call time. If
    # frame_extractor does `from detect.emit import push_detect_event`, the
    # patch target would need to be the name in that module instead; confirm.
    monkeypatch.setattr("detect.emit.push_detect_event", mock_push)

    frames = extract_frames(
        _get_sample_video(),
        FrameExtractionConfig(fps=1, max_frames=5),
        job_id="test-123",
    )
    assert len(frames) > 0

    seen_types = {event_type for _, event_type, _ in recorded}
    assert "log" in seen_types
    assert "stats_update" in seen_types

    # Every recorded event must be addressed to the job we passed in.
    assert all(job_id == "test-123" for job_id, _, _ in recorded)