78 lines
2.2 KiB
Python
78 lines
2.2 KiB
Python
"""Tests for FrameExtractor stage."""
|
|
|
|
import glob
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from detect.profiles.base import FrameExtractionConfig
|
|
from detect.stages.frame_extractor import extract_frames
|
|
|
|
# Chunk directory of one specific pipeline run (identified by its UUID).
# The path is relative, so it is resolved against the current working
# directory at the time it is globbed.
SAMPLE_DIR: Path = Path("media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e")
|
|
|
|
|
|
def _get_sample_video() -> str:
    """Return the path of a sample video chunk, skipping the test if none exist.

    Chunks in ``SAMPLE_DIR`` (one specific pipeline run) are preferred. If
    that directory holds none, for example on a machine where that exact
    run was never produced, any ``chunk_*.mp4`` from a sibling run directory
    under ``SAMPLE_DIR.parent`` is used instead. This way the tests do not
    skip merely because the hard-coded run ID is absent.

    Returns:
        The path of the lexicographically first matching chunk, as a string.
    """
    chunks = sorted(SAMPLE_DIR.glob("chunk_*.mp4"))
    if not chunks:
        # Fall back to any run directory next to SAMPLE_DIR; Path.glob on a
        # missing directory simply yields nothing.
        chunks = sorted(SAMPLE_DIR.parent.glob("*/chunk_*.mp4"))
    if not chunks:
        # pytest.skip raises, so nothing below runs without a chunk.
        pytest.skip("No sample video found in media/out/chunks/")
    return str(chunks[0])
|
|
|
|
|
|
def test_extract_frames_basic():
    """Extract at 1 fps with a 10-frame cap and sanity-check every frame."""
    sample_path = _get_sample_video()
    extraction_cfg = FrameExtractionConfig(fps=1, max_frames=10)
    extracted = extract_frames(sample_path, extraction_cfg)

    # At least one frame, and never more than the configured cap.
    assert 0 < len(extracted) <= 10

    for frame in extracted:
        pixels = frame.image
        assert pixels.ndim == 3  # height x width x channels
        assert pixels.shape[2] == 3  # 3 channels; channel order is not checked
        assert frame.sequence >= 0
        assert frame.timestamp >= 0.0
|
|
|
|
|
|
def test_extract_frames_respects_fps():
    """Sampling at 2 fps must yield more frames than sampling at 1 fps.

    The ideal ratio is about 2x, but the exact counts depend on how the
    extractor rounds the video duration. The test therefore demands a strict
    increase, which is enough to catch an extractor that ignores ``fps``. It
    only does so when the 1 fps run is neither trivially short (a single
    frame) nor already at the ``max_frames`` cap.
    """
    video = _get_sample_video()
    max_frames = 100
    config_1fps = FrameExtractionConfig(fps=1, max_frames=max_frames)
    config_2fps = FrameExtractionConfig(fps=2, max_frames=max_frames)

    frames_1 = extract_frames(video, config_1fps)
    frames_2 = extract_frames(video, config_2fps)

    # Doubling the sampling rate should never reduce the frame count.
    assert len(frames_2) >= len(frames_1)

    # With at least two frames at 1 fps and headroom below the cap, a 2 fps
    # run should contain strictly more frames; equal counts would mean the
    # fps setting had no effect. (The bare >= check above cannot detect that.)
    if 1 < len(frames_1) < max_frames:
        assert len(frames_2) > len(frames_1), (
            f"fps=2 gave {len(frames_2)} frames vs {len(frames_1)} at fps=1"
        )
|
|
|
|
|
|
def test_extract_frames_respects_max():
    """A ``max_frames`` cap bounds the output even at a high sampling rate.

    ``fps=10`` is chosen so that the rate alone would presumably exceed the
    cap of 3 on any chunk longer than ~0.3 s, making the cap the limiting
    factor.
    """
    video = _get_sample_video()
    max_frames = 3
    config = FrameExtractionConfig(fps=10, max_frames=max_frames)
    frames = extract_frames(video, config)

    # Without this, an extractor returning [] would pass the cap check
    # vacuously.
    assert len(frames) > 0
    assert len(frames) <= max_frames
|
|
|
|
|
|
def test_extract_frames_with_events(monkeypatch):
    """Passing ``job_id`` should make the extractor emit SSE events for that job."""
    captured = []

    def fake_push(job_id, event_type, data):
        # Record each call instead of pushing it anywhere.
        captured.append((job_id, event_type, data))

    # NOTE(review): setattr on ``detect.emit`` only intercepts calls that go
    # through that module attribute, not a name bound earlier via
    # ``from detect.emit import push_detect_event`` — confirm against
    # detect.stages.frame_extractor.
    monkeypatch.setattr("detect.emit.push_detect_event", fake_push)

    sample_path = _get_sample_video()
    extraction_cfg = FrameExtractionConfig(fps=1, max_frames=5)
    extracted = extract_frames(sample_path, extraction_cfg, job_id="test-123")

    assert len(extracted) > 0

    seen_types = {kind for _, kind, _ in captured}
    assert "log" in seen_types
    assert "stats_update" in seen_types

    # Every captured event must be addressed to the job we passed in.
    assert all(target == "test-123" for target, _, _ in captured)
|