phase 4
@@ -1,76 +0,0 @@
"""
Shared fixtures for chunker tests.

Demonstrates: TDD and unit testing best practices (Interview Topic 8) — fixtures, temp files.
"""

import os
import tempfile

import pytest

from core.chunker.models import Chunk, ChunkResult


@pytest.fixture
def temp_file():
    """Create a temporary file with known content, cleaned up after test."""
    files = []

    def _create(content: bytes = b"x" * 4096):
        f = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        f.write(content)
        f.close()
        files.append(f.name)
        return f.name

    yield _create

    for path in files:
        if os.path.exists(path):
            os.unlink(path)


@pytest.fixture
def sample_chunk(temp_file):
    """Create a sample time-based Chunk with valid time range."""
    path = temp_file(b"x" * 1024)
    return Chunk(
        sequence=0,
        start_time=0.0,
        end_time=10.0,
        source_path=path,
        duration=10.0,
    )


@pytest.fixture
def make_chunk(temp_file):
    """Factory fixture for creating time-based chunks with specific sequence numbers."""
    path = temp_file(b"x" * 1024)

    def _make(sequence: int, duration: float = 10.0) -> Chunk:
        start = sequence * duration
        return Chunk(
            sequence=sequence,
            start_time=start,
            end_time=start + duration,
            source_path=path,
            duration=duration,
        )

    return _make


@pytest.fixture
def make_result():
    """Factory fixture for creating ChunkResults."""

    def _make(sequence: int, success: bool = True, processing_time: float = 0.01) -> ChunkResult:
        return ChunkResult(
            sequence=sequence,
            success=success,
            processing_time=processing_time,
        )

    return _make
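Note: core.chunker.models itself is not shown in this commit view. As a reading aid, a minimal sketch of the shapes these fixtures construct, with every field inferred from fixture and test usage (defaults are guesses):

# Hypothetical reconstruction, not the actual core/chunker/models.py.
from dataclasses import dataclass
from typing import Optional

@dataclass
class Chunk:
    sequence: int
    start_time: float
    end_time: float
    source_path: str
    duration: float

@dataclass
class ChunkResult:
    sequence: int
    success: bool
    processing_time: float = 0.0
    checksum_valid: bool = False        # read by the processor tests
    retries: int = 0                    # read by the worker tests
    error: Optional[str] = None
    worker_id: Optional[str] = None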
@@ -1,149 +0,0 @@
"""
Tests for Chunker — time-based segmentation, chunk counts, sequence numbers, generator behavior.

Demonstrates: TDD (Interview Topic 8) — parametrized tests, edge cases, mocking.
"""

from unittest.mock import patch, MagicMock

import pytest

from core.chunker import Chunker
from core.chunker.exceptions import ChunkReadError


def mock_probe(duration):
    """Create a mock probe result with the given duration."""
    result = MagicMock()
    result.duration = duration
    return result


class TestChunker:
    @patch("core.chunker.chunker.probe_file")
    def test_basic_chunking(self, mock_pf, temp_file):
        """File splits into expected number of time-based chunks."""
        path = temp_file(b"x" * 1000)
        mock_pf.return_value = mock_probe(30.0)

        chunker = Chunker(path, chunk_duration=10.0)
        chunks = list(chunker.chunks())

        assert len(chunks) == 3
        assert chunks[0].start_time == 0.0
        assert chunks[0].end_time == 10.0
        assert chunks[0].duration == 10.0
        assert chunks[1].start_time == 10.0
        assert chunks[2].start_time == 20.0

    @patch("core.chunker.chunker.probe_file")
    def test_sequence_numbers(self, mock_pf, temp_file):
        """Chunks have sequential sequence numbers starting at 0."""
        path = temp_file(b"x" * 100)
        mock_pf.return_value = mock_probe(40.0)

        chunker = Chunker(path, chunk_duration=10.0)
        chunks = list(chunker.chunks())
        sequences = [c.sequence for c in chunks]

        assert sequences == [0, 1, 2, 3]

    @patch("core.chunker.chunker.probe_file")
    def test_time_ranges(self, mock_pf, temp_file):
        """Each chunk has correct start_time and end_time."""
        path = temp_file(b"x" * 100)
        mock_pf.return_value = mock_probe(25.0)

        chunker = Chunker(path, chunk_duration=10.0)
        chunks = list(chunker.chunks())

        assert chunks[0].start_time == 0.0
        assert chunks[0].end_time == 10.0
        assert chunks[1].start_time == 10.0
        assert chunks[1].end_time == 20.0
        assert chunks[2].start_time == 20.0
        assert chunks[2].end_time == 25.0  # last chunk shorter
        assert chunks[2].duration == 5.0

    @patch("core.chunker.chunker.probe_file")
    def test_expected_chunks_property(self, mock_pf, temp_file):
        """expected_chunks calculates correctly before iteration."""
        path = temp_file(b"x" * 100)
        mock_pf.return_value = mock_probe(25.0)

        chunker = Chunker(path, chunk_duration=10.0)
        assert chunker.expected_chunks == 3  # ceil(25/10)

    @patch("core.chunker.chunker.probe_file")
    def test_source_path_on_chunks(self, mock_pf, temp_file):
        """Each chunk carries the source file path."""
        path = temp_file(b"x" * 100)
        mock_pf.return_value = mock_probe(10.0)

        chunker = Chunker(path, chunk_duration=10.0)
        chunks = list(chunker.chunks())

        assert all(c.source_path == path for c in chunks)

    def test_file_not_found(self):
        """Non-existent file raises ChunkReadError."""
        with pytest.raises(ChunkReadError, match="File not found"):
            Chunker("/nonexistent/file.mp4")

    @patch("core.chunker.chunker.probe_file")
    def test_invalid_chunk_duration(self, mock_pf, temp_file):
        """Zero or negative chunk_duration raises ValueError."""
        path = temp_file(b"x" * 100)

        with pytest.raises(ValueError, match="chunk_duration must be positive"):
            Chunker(path, chunk_duration=0)

        with pytest.raises(ValueError, match="chunk_duration must be positive"):
            Chunker(path, chunk_duration=-1)

    @patch("core.chunker.chunker.probe_file")
    def test_generator_laziness(self, mock_pf, temp_file):
        """Chunks are yielded lazily, not pre-loaded."""
        path = temp_file(b"x" * 100)
        mock_pf.return_value = mock_probe(30.0)

        chunker = Chunker(path, chunk_duration=10.0)
        gen = chunker.chunks()
        first = next(gen)
        assert first.sequence == 0
        # Generator is not exhausted — remaining chunks still pending

    @pytest.mark.parametrize("duration,chunk_dur,expected", [
        (10.0, 10.0, 1),
        (10.1, 10.0, 2),
        (1.0, 1.0, 1),
        (100.0, 1.0, 100),
        (5.0, 100.0, 1),
    ])
    @patch("core.chunker.chunker.probe_file")
    def test_expected_chunks_parametrized(self, mock_pf, temp_file, duration, chunk_dur, expected):
        """Parametrized: various duration/chunk_duration combos."""
        path = temp_file(b"x" * 100)
        mock_pf.return_value = mock_probe(duration)
        chunker = Chunker(path, chunk_duration=chunk_dur)
        assert chunker.expected_chunks == expected

    @patch("core.chunker.chunker.probe_file")
    def test_exact_multiple(self, mock_pf, temp_file):
        """Duration exactly divisible by chunk_duration."""
        path = temp_file(b"x" * 100)
        mock_pf.return_value = mock_probe(30.0)

        chunker = Chunker(path, chunk_duration=10.0)
        chunks = list(chunker.chunks())
        assert len(chunks) == 3
        assert all(c.duration == 10.0 for c in chunks)

    @patch("core.chunker.chunker.probe_file")
    def test_probe_failure(self, mock_pf, temp_file):
        """Probe failure raises ChunkReadError."""
        path = temp_file(b"x" * 100)
        mock_pf.side_effect = Exception("ffprobe failed")

        with pytest.raises(ChunkReadError, match="Failed to probe"):
            Chunker(path, chunk_duration=10.0)
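The expected_chunks value pinned down by these tests (including every parametrized case) is a ceiling division of file duration by chunk duration. A minimal sketch of the arithmetic, assuming the real Chunker computes it the same way:

import math

def expected_chunks(duration: float, chunk_duration: float) -> int:
    # ceil(25 / 10) == 3, ceil(10.1 / 10) == 2, ceil(5 / 100) == 1
    return math.ceil(duration / chunk_duration)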
@@ -1,103 +0,0 @@
"""
Tests for ResultCollector — ordered reassembly, out-of-order buffering, duplicates.

Demonstrates: TDD (Interview Topic 8) — testing algorithms (heapq reassembly).
"""

import pytest

from core.chunker.collector import ResultCollector
from core.chunker.exceptions import ReassemblyError


class TestResultCollector:
    def test_in_order_emission(self, make_result):
        """Results arriving in order are emitted immediately."""
        collector = ResultCollector(total_chunks=3)

        emitted = collector.add(make_result(0))
        assert len(emitted) == 1
        assert emitted[0].sequence == 0

        emitted = collector.add(make_result(1))
        assert len(emitted) == 1

        emitted = collector.add(make_result(2))
        assert len(emitted) == 1

        assert collector.is_complete

    def test_out_of_order_buffering(self, make_result):
        """Out-of-order results are buffered until gaps fill."""
        collector = ResultCollector(total_chunks=3)

        # Arrive: 2, 0, 1
        emitted = collector.add(make_result(2))
        assert len(emitted) == 0
        assert collector.buffered_count == 1

        emitted = collector.add(make_result(0))
        assert len(emitted) == 1  # Only 0 emitted, 1 still missing

        emitted = collector.add(make_result(1))
        assert len(emitted) == 2  # 1 and 2 now emittable
        assert collector.is_complete

    def test_reverse_order(self, make_result):
        """All results arrive in reverse — only last add emits everything."""
        collector = ResultCollector(total_chunks=4)

        for seq in [3, 2, 1]:
            emitted = collector.add(make_result(seq))
            assert len(emitted) == 0

        emitted = collector.add(make_result(0))
        assert len(emitted) == 4
        assert collector.is_complete

    def test_duplicate_raises(self, make_result):
        """Duplicate sequence number raises ReassemblyError."""
        collector = ResultCollector(total_chunks=3)
        collector.add(make_result(0))

        with pytest.raises(ReassemblyError, match="Duplicate"):
            collector.add(make_result(0))

    def test_emitted_count(self, make_result):
        """emitted_count tracks correctly."""
        collector = ResultCollector(total_chunks=3)
        assert collector.emitted_count == 0

        collector.add(make_result(0))
        assert collector.emitted_count == 1

        collector.add(make_result(2))  # buffered
        assert collector.emitted_count == 1

        collector.add(make_result(1))  # releases 1 and 2
        assert collector.emitted_count == 3

    def test_get_ordered_results(self, make_result):
        """get_ordered_results returns all emitted results in order."""
        collector = ResultCollector(total_chunks=3)
        collector.add(make_result(2))
        collector.add(make_result(0))
        collector.add(make_result(1))

        ordered = collector.get_ordered_results()
        assert [r.sequence for r in ordered] == [0, 1, 2]

    def test_avg_processing_time(self, make_result):
        """Average processing time from sliding window."""
        collector = ResultCollector(total_chunks=2)
        collector.add(make_result(0, processing_time=0.1))
        collector.add(make_result(1, processing_time=0.3))

        assert abs(collector.avg_processing_time - 0.2) < 0.001

    def test_not_complete_when_partial(self, make_result):
        """is_complete is False until all chunks emitted."""
        collector = ResultCollector(total_chunks=3)
        collector.add(make_result(0))
        collector.add(make_result(1))
        assert not collector.is_complete
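The docstring names the technique: heapq reassembly. A minimal sketch of an add() consistent with these tests: buffer out-of-order results in a min-heap keyed by sequence, then drain while the heap's smallest sequence is the next expected one. (The real collector raises ReassemblyError on duplicates; a plain ValueError keeps the sketch self-contained.)

import heapq

class ResultCollectorSketch:
    def __init__(self, total_chunks: int):
        self.total_chunks = total_chunks
        self._heap = []            # min-heap of (sequence, result)
        self._seen = set()
        self._next = 0             # next sequence eligible for emission
        self.emitted_count = 0

    @property
    def buffered_count(self):
        return len(self._heap)

    @property
    def is_complete(self):
        return self.emitted_count == self.total_chunks

    def add(self, result):
        if result.sequence in self._seen:
            raise ValueError(f"Duplicate sequence: {result.sequence}")
        self._seen.add(result.sequence)
        heapq.heappush(self._heap, (result.sequence, result))
        emitted = []
        # Drain the contiguous run starting at the next expected sequence.
        while self._heap and self._heap[0][0] == self._next:
            emitted.append(heapq.heappop(self._heap)[1])
            self._next += 1
        self.emitted_count += len(emitted)
        return emitted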
@@ -1,69 +0,0 @@
"""
Tests for exception hierarchy — catch patterns, attributes.

Demonstrates: TDD (Interview Topic 8) — testing exception design.
"""

import pytest

from core.chunker.exceptions import (
    ChunkChecksumError,
    ChunkError,
    ChunkReadError,
    PipelineError,
    ProcessingError,
    ProcessorFailureError,
    ProcessorTimeoutError,
    ReassemblyError,
)


class TestExceptionHierarchy:
    """Verify the exception class hierarchy and catch patterns."""

    def test_pipeline_error_is_base(self):
        """All chunker exceptions inherit from PipelineError."""
        assert issubclass(ChunkError, PipelineError)
        assert issubclass(ProcessingError, PipelineError)
        assert issubclass(ReassemblyError, PipelineError)

    def test_chunk_error_subtypes(self):
        """ChunkReadError and ChunkChecksumError are ChunkErrors."""
        assert issubclass(ChunkReadError, ChunkError)
        assert issubclass(ChunkChecksumError, ChunkError)

    def test_processing_error_subtypes(self):
        """ProcessorTimeoutError and ProcessorFailureError are ProcessingErrors."""
        assert issubclass(ProcessorTimeoutError, ProcessingError)
        assert issubclass(ProcessorFailureError, ProcessingError)

    def test_catch_pipeline_error_catches_all(self):
        """Catching PipelineError catches any subtype."""
        with pytest.raises(PipelineError):
            raise ChunkReadError("test")

        with pytest.raises(PipelineError):
            raise ReassemblyError("test")

    def test_checksum_error_attributes(self):
        """ChunkChecksumError carries sequence, expected, actual."""
        err = ChunkChecksumError(sequence=5, expected="aaa", actual="bbb")
        assert err.sequence == 5
        assert err.expected == "aaa"
        assert err.actual == "bbb"
        assert "5" in str(err)

    def test_timeout_error_attributes(self):
        """ProcessorTimeoutError carries sequence and timeout."""
        err = ProcessorTimeoutError(sequence=3, timeout=30.0)
        assert err.sequence == 3
        assert err.timeout == 30.0

    def test_failure_error_attributes(self):
        """ProcessorFailureError carries sequence, retries, original error."""
        original = RuntimeError("boom")
        err = ProcessorFailureError(sequence=1, retries=3, original_error=original)
        assert err.sequence == 1
        assert err.retries == 3
        assert err.original_error is original
        assert "boom" in str(err)
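None of the exception definitions appear in this view; here is a sketch of a hierarchy that satisfies every assertion above (the messages are guesses; only the substrings "5" and "boom" are actually pinned down by the tests):

class PipelineError(Exception):
    """Base for all chunker pipeline errors."""

class ChunkError(PipelineError): ...
class ProcessingError(PipelineError): ...
class ReassemblyError(PipelineError): ...
class ChunkReadError(ChunkError): ...

class ChunkChecksumError(ChunkError):
    def __init__(self, sequence: int, expected: str, actual: str):
        self.sequence, self.expected, self.actual = sequence, expected, actual
        super().__init__(f"chunk {sequence}: expected {expected}, got {actual}")

class ProcessorTimeoutError(ProcessingError):
    def __init__(self, sequence: int, timeout: float):
        self.sequence, self.timeout = sequence, timeout
        super().__init__(f"chunk {sequence} timed out after {timeout}s")

class ProcessorFailureError(ProcessingError):
    def __init__(self, sequence: int, retries: int, original_error: Exception):
        self.sequence, self.retries, self.original_error = sequence, retries, original_error
        super().__init__(f"chunk {sequence} failed after {retries} retries: {original_error}")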
@@ -1,144 +0,0 @@
"""
Tests for Pipeline — end-to-end orchestration, stats, error handling.

Demonstrates: TDD (Interview Topic 8) — integration testing with mocked FFmpeg probe.
"""

from unittest.mock import MagicMock, patch

import pytest

from core.chunker import Pipeline
from core.chunker.exceptions import PipelineError


def mock_probe(duration):
    """Create a mock ProbeResult with the given duration."""
    result = MagicMock()
    result.duration = duration
    return result


class TestPipeline:
    @patch("core.chunker.chunker.probe_file")
    def test_end_to_end(self, mock_pf, temp_file):
        """Full pipeline processes a file successfully."""
        path = temp_file(b"x" * 4096)
        mock_pf.return_value = mock_probe(40.0)

        result = Pipeline(
            source=path,
            chunk_duration=10.0,
            num_workers=2,
            processor_type="checksum",
        ).run()

        assert result.total_chunks == 4
        assert result.processed == 4
        assert result.failed == 0
        assert result.elapsed_time > 0
        assert result.chunks_in_order is True

    @patch("core.chunker.chunker.probe_file")
    def test_throughput_calculated(self, mock_pf, temp_file):
        """Pipeline calculates throughput."""
        path = temp_file(b"x" * 10000)
        mock_pf.return_value = mock_probe(30.0)

        result = Pipeline(source=path, chunk_duration=10.0, num_workers=2).run()

        assert result.throughput_mbps > 0

    @patch("core.chunker.chunker.probe_file")
    def test_worker_stats(self, mock_pf, temp_file):
        """Pipeline reports per-worker stats."""
        path = temp_file(b"x" * 4000)
        mock_pf.return_value = mock_probe(40.0)

        result = Pipeline(
            source=path, chunk_duration=10.0, num_workers=2
        ).run()

        assert len(result.worker_stats) == 2
        for worker_id, stats in result.worker_stats.items():
            assert "processed" in stats
            assert "errors" in stats

    def test_nonexistent_file(self):
        """Non-existent file raises PipelineError."""
        with pytest.raises(PipelineError):
            Pipeline(source="/nonexistent/file.mp4").run()

    @patch("core.chunker.chunker.probe_file")
    def test_event_callback(self, mock_pf, temp_file):
        """Pipeline emits events through callback."""
        path = temp_file(b"x" * 2048)
        mock_pf.return_value = mock_probe(20.0)
        events = []

        def capture(event_type, data):
            events.append(event_type)

        Pipeline(
            source=path,
            chunk_duration=10.0,
            num_workers=1,
            event_callback=capture,
        ).run()

        assert "pipeline_start" in events
        assert "pipeline_complete" in events
        assert "chunk_queued" in events

    @patch("core.chunker.chunker.probe_file")
    def test_simulated_decode_processor(self, mock_pf, temp_file):
        """Pipeline works with simulated_decode processor."""
        path = temp_file(b"x" * 2048)
        mock_pf.return_value = mock_probe(20.0)

        result = Pipeline(
            source=path,
            chunk_duration=10.0,
            num_workers=2,
            processor_type="simulated_decode",
        ).run()

        assert result.total_chunks == 2
        assert result.failed == 0

    @patch("core.chunker.chunker.probe_file")
    def test_single_chunk_file(self, mock_pf, temp_file):
        """Duration shorter than chunk_duration produces one chunk."""
        path = temp_file(b"x" * 100)
        mock_pf.return_value = mock_probe(5.0)

        result = Pipeline(source=path, chunk_duration=10.0).run()

        assert result.total_chunks == 1
        assert result.processed == 1

    @patch("core.chunker.chunker.probe_file")
    def test_retries_tracked(self, mock_pf, temp_file):
        """Pipeline result tracks total retries."""
        path = temp_file(b"x" * 2048)
        mock_pf.return_value = mock_probe(20.0)

        result = Pipeline(source=path, chunk_duration=10.0).run()

        assert result.retries >= 0  # Might be 0 if no failures

    @patch("core.chunker.chunker.probe_file")
    def test_output_dir_and_chunk_files(self, mock_pf, temp_file):
        """Pipeline tracks output_dir and chunk_files when set."""
        path = temp_file(b"x" * 1024)
        mock_pf.return_value = mock_probe(10.0)

        result = Pipeline(
            source=path,
            chunk_duration=10.0,
            processor_type="checksum",
        ).run()

        # No output_dir set, so chunk_files should be empty
        assert result.output_dir is None
        assert result.chunk_files == []
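The object returned by Pipeline.run() is visible here only through its assertions; collecting them gives roughly this shape (a sketch, not the real class; anything beyond the asserted field names is a guess):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class PipelineResultSketch:
    total_chunks: int = 0
    processed: int = 0
    failed: int = 0
    retries: int = 0
    elapsed_time: float = 0.0
    throughput_mbps: float = 0.0                       # presumably source bytes / elapsed time
    chunks_in_order: bool = True
    worker_stats: dict = field(default_factory=dict)   # worker_id -> {"processed": ..., "errors": ...}
    output_dir: Optional[str] = None
    chunk_files: list = field(default_factory=list)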
@@ -1,98 +0,0 @@
"""
Tests for Processor implementations — ChecksumProcessor, SimulatedDecodeProcessor, CompositeProcessor.

Demonstrates: TDD (Interview Topic 8) — ABC contract, parametrized tests.
"""

import pytest

from core.chunker.exceptions import ChunkChecksumError
from core.chunker.models import Chunk
from core.chunker.processor import (
    ChecksumProcessor,
    CompositeProcessor,
    Processor,
    SimulatedDecodeProcessor,
)


class TestChecksumProcessor:
    def test_valid_time_range(self, sample_chunk):
        """Valid time range passes."""
        proc = ChecksumProcessor()
        result = proc.process(sample_chunk)
        assert result.success is True
        assert result.checksum_valid is True
        assert result.processing_time > 0

    def test_invalid_time_range(self):
        """Invalid time range raises ChunkChecksumError."""
        chunk = Chunk(
            sequence=0,
            start_time=10.0,
            end_time=10.0,  # zero duration
            source_path="/fake.mp4",
            duration=0.0,
        )
        proc = ChecksumProcessor()
        with pytest.raises(ChunkChecksumError) as exc_info:
            proc.process(chunk)
        assert exc_info.value.sequence == 0

    def test_sequence_preserved(self, make_chunk):
        """Result carries the chunk's sequence number."""
        chunk = make_chunk(42)
        proc = ChecksumProcessor()
        result = proc.process(chunk)
        assert result.sequence == 42


class TestSimulatedDecodeProcessor:
    def test_processes_successfully(self, sample_chunk):
        """Simulated decode always succeeds."""
        proc = SimulatedDecodeProcessor(ms_per_second=1.0)
        result = proc.process(sample_chunk)
        assert result.success is True
        assert result.processing_time > 0

    def test_time_proportional_to_duration(self):
        """Longer chunks take longer."""
        short = Chunk(0, 0.0, 1.0, "/fake.mp4", 1.0)
        long = Chunk(1, 0.0, 10.0, "/fake.mp4", 10.0)

        proc = SimulatedDecodeProcessor(ms_per_second=50.0)
        r_short = proc.process(short)
        r_long = proc.process(long)

        assert r_long.processing_time > r_short.processing_time


class TestCompositeProcessor:
    def test_chains_processors(self, sample_chunk):
        """Composite runs all processors in sequence."""
        proc = CompositeProcessor([
            ChecksumProcessor(),
            SimulatedDecodeProcessor(ms_per_second=1.0),
        ])
        result = proc.process(sample_chunk)
        assert result.success is True

    def test_stops_on_failure(self):
        """If first processor raises, composite propagates the error."""
        bad_chunk = Chunk(0, 10.0, 10.0, "/fake.mp4", 0.0)  # invalid range
        proc = CompositeProcessor([
            ChecksumProcessor(),
            SimulatedDecodeProcessor(ms_per_second=1.0),
        ])
        with pytest.raises(ChunkChecksumError):
            proc.process(bad_chunk)

    def test_requires_at_least_one(self):
        """Empty processor list raises ValueError."""
        with pytest.raises(ValueError, match="at least one"):
            CompositeProcessor([])

    def test_is_processor(self):
        """CompositeProcessor is a Processor."""
        proc = CompositeProcessor([ChecksumProcessor()])
        assert isinstance(proc, Processor)
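The ABC contract and composite semantics implied by these tests, as a sketch (the real Processor may define more than one method; only process() is exercised here):

from abc import ABC, abstractmethod

class ProcessorSketch(ABC):
    @abstractmethod
    def process(self, chunk):
        """Return a ChunkResult, or raise a ChunkError/ProcessingError subtype."""

class CompositeProcessorSketch(ProcessorSketch):
    """Run processors in order; the first exception aborts the chain."""

    def __init__(self, processors):
        if not processors:
            raise ValueError("CompositeProcessor requires at least one processor")
        self.processors = list(processors)

    def process(self, chunk):
        result = None
        for proc in self.processors:
            result = proc.process(chunk)   # exceptions propagate, stopping the chain
        return result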
@@ -1,115 +0,0 @@
"""
Tests for ChunkQueue — backpressure, sentinel shutdown, timeout behavior.

Demonstrates: TDD (Interview Topic 8) — concurrency testing.
"""

import queue
import threading

import pytest

from core.chunker.queue import ChunkQueue


class TestChunkQueue:
    def test_put_and_get(self, make_chunk):
        """Basic put/get cycle."""
        q = ChunkQueue(maxsize=5)
        chunk = make_chunk(0)
        q.put(chunk)
        result = q.get(timeout=1.0)
        assert result.sequence == 0

    def test_fifo_order(self, make_chunk):
        """Items come out in FIFO order."""
        q = ChunkQueue(maxsize=5)
        for i in range(3):
            q.put(make_chunk(i))

        for i in range(3):
            assert q.get(timeout=1.0).sequence == i

    def test_close_returns_none(self, make_chunk):
        """After close(), get() returns None (sentinel)."""
        q = ChunkQueue(maxsize=5)
        q.put(make_chunk(0))
        q.close()

        result = q.get(timeout=1.0)
        assert result.sequence == 0

        # Next get should hit sentinel
        result = q.get(timeout=1.0)
        assert result is None

    def test_close_propagates_to_multiple_consumers(self, make_chunk):
        """Sentinel propagates: multiple consumers all get None."""
        q = ChunkQueue(maxsize=5)
        q.close()

        # Multiple consumers should all see None
        assert q.get(timeout=1.0) is None
        assert q.get(timeout=1.0) is None

    def test_is_closed(self):
        """is_closed reflects state."""
        q = ChunkQueue()
        assert not q.is_closed
        q.close()
        assert q.is_closed

    def test_qsize(self, make_chunk):
        """qsize tracks approximate queue depth."""
        q = ChunkQueue(maxsize=10)
        assert q.qsize() == 0

        q.put(make_chunk(0))
        q.put(make_chunk(1))
        assert q.qsize() == 2

        q.get(timeout=1.0)
        assert q.qsize() == 1

    def test_backpressure_blocks(self, make_chunk):
        """Put blocks when queue is full (backpressure)."""
        q = ChunkQueue(maxsize=2)
        q.put(make_chunk(0))
        q.put(make_chunk(1))

        # Queue is full — put with short timeout should raise
        with pytest.raises(queue.Full):
            q.put(make_chunk(2), timeout=0.05)

    def test_get_timeout(self):
        """Get on empty queue with timeout raises Empty."""
        q = ChunkQueue(maxsize=5)

        with pytest.raises(queue.Empty):
            q.get(timeout=0.05)

    def test_concurrent_put_get(self, make_chunk):
        """Producer/consumer threads work correctly."""
        q = ChunkQueue(maxsize=3)
        results = []

        def producer():
            for i in range(10):
                q.put(make_chunk(i))
            q.close()

        def consumer():
            while True:
                item = q.get(timeout=2.0)
                if item is None:
                    break
                results.append(item.sequence)

        t1 = threading.Thread(target=producer)
        t2 = threading.Thread(target=consumer)
        t1.start()
        t2.start()
        t1.join(timeout=5.0)
        t2.join(timeout=5.0)

        assert sorted(results) == list(range(10))
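The sentinel-shutdown pattern these tests pin down: close() enqueues a sentinel, and each consumer that pops it re-enqueues it before returning None, so the signal fans out to every consumer. A minimal sketch over queue.Queue (the real ChunkQueue may differ in detail):

import queue

_SENTINEL = object()

class ChunkQueueSketch:
    def __init__(self, maxsize: int = 0):
        self._q = queue.Queue(maxsize=maxsize)
        self._closed = False

    def put(self, item, timeout=None):
        # Blocks when full (backpressure); raises queue.Full if a timeout expires.
        self._q.put(item, timeout=timeout)

    def get(self, timeout=None):
        item = self._q.get(timeout=timeout)   # raises queue.Empty on timeout
        if item is _SENTINEL:
            self._q.put(_SENTINEL)   # re-enqueue so the next consumer also sees it
            return None
        return item

    def close(self):
        if not self._closed:
            self._closed = True
            self._q.put(_SENTINEL)

    @property
    def is_closed(self):
        return self._closed

    def qsize(self):
        return self._q.qsize()   # approximate; counts the sentinel once closed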
@@ -1,127 +0,0 @@
"""
Tests for Worker — processing, retry with backoff, error handling.

Demonstrates: TDD (Interview Topic 8) — mocking processors, testing retry logic.
"""

from unittest.mock import MagicMock

import pytest

from core.chunker.models import Chunk, ChunkResult
from core.chunker.processor import Processor
from core.chunker.queue import ChunkQueue
from core.chunker.worker import Worker


class FailNTimesProcessor(Processor):
    """Test processor that fails N times then succeeds."""

    def __init__(self, fail_count: int):
        self.fail_count = fail_count
        self.call_count = 0

    def process(self, chunk: Chunk) -> ChunkResult:
        self.call_count += 1
        if self.call_count <= self.fail_count:
            raise RuntimeError(f"Simulated failure #{self.call_count}")
        return ChunkResult(
            sequence=chunk.sequence,
            success=True,
            processing_time=0.001,
        )


class AlwaysFailProcessor(Processor):
    """Test processor that always fails."""

    def process(self, chunk: Chunk) -> ChunkResult:
        raise RuntimeError("Always fails")


class TestWorker:
    def test_processes_chunks(self, make_chunk):
        """Worker processes all chunks from queue."""
        q = ChunkQueue(maxsize=5)
        for i in range(3):
            q.put(make_chunk(i))
        q.close()

        from core.chunker.processor import ChecksumProcessor
        worker = Worker("w-0", q, ChecksumProcessor(), max_retries=0)
        results = worker.run()

        assert len(results) == 3
        assert all(r.success for r in results)

    def test_retry_on_failure(self, make_chunk):
        """Worker retries on processor failure."""
        q = ChunkQueue(maxsize=5)
        q.put(make_chunk(0))
        q.close()

        proc = FailNTimesProcessor(fail_count=2)
        worker = Worker("w-0", q, proc, max_retries=3)
        results = worker.run()

        assert len(results) == 1
        assert results[0].success is True
        assert results[0].retries == 2
        assert proc.call_count == 3  # 2 failures + 1 success

    def test_max_retries_exceeded(self, make_chunk):
        """Worker gives up after max retries."""
        q = ChunkQueue(maxsize=5)
        q.put(make_chunk(0))
        q.close()

        worker = Worker("w-0", q, AlwaysFailProcessor(), max_retries=2)
        results = worker.run()

        assert len(results) == 1
        assert results[0].success is False
        assert results[0].error is not None
        assert worker.error_count == 1

    def test_worker_id_on_results(self, make_chunk):
        """Worker stamps its ID on results."""
        q = ChunkQueue(maxsize=5)
        q.put(make_chunk(0))
        q.close()

        from core.chunker.processor import ChecksumProcessor
        worker = Worker("worker-7", q, ChecksumProcessor())
        results = worker.run()

        assert results[0].worker_id == "worker-7"

    def test_event_callback(self, make_chunk):
        """Worker emits events via callback."""
        q = ChunkQueue(maxsize=5)
        q.put(make_chunk(0))
        q.close()

        events = []
        callback = MagicMock(side_effect=lambda t, d: events.append((t, d)))

        from core.chunker.processor import ChecksumProcessor
        worker = Worker("w-0", q, ChecksumProcessor(), event_callback=callback)
        worker.run()

        event_types = [e[0] for e in events]
        assert "worker_status" in event_types
        assert "chunk_processing" in event_types
        assert "chunk_done" in event_types

    def test_processed_count(self, make_chunk):
        """Worker tracks processed count."""
        q = ChunkQueue(maxsize=10)
        for i in range(5):
            q.put(make_chunk(i))
        q.close()

        from core.chunker.processor import ChecksumProcessor
        worker = Worker("w-0", q, ChecksumProcessor())
        worker.run()

        assert worker.processed_count == 5
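The retry behavior these tests fix: N failures followed by a success yields retries == N, and exceeding max_retries yields a failed result rather than an exception. A sketch of the inner loop, assuming ChunkResult accepts error/retries keywords and a simple exponential backoff (the real Worker's delay schedule is not visible here):

import time

from core.chunker.models import ChunkResult

def process_with_retries(processor, chunk, max_retries: int, base_delay: float = 0.05):
    attempt = 0
    while True:
        try:
            result = processor.process(chunk)
            result.retries = attempt          # 2 failures + 1 success -> retries == 2
            return result
        except Exception as exc:
            if attempt >= max_retries:
                # Give up: report failure instead of raising.
                return ChunkResult(sequence=chunk.sequence, success=False,
                                   processing_time=0.0, error=str(exc), retries=attempt)
            time.sleep(base_delay * (2 ** attempt))   # exponential backoff
            attempt += 1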
@@ -24,9 +24,9 @@ logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(m
 sys.path.insert(0, ".")
 
-from detect.profiles.soccer import SoccerBroadcastProfile
-from detect.stages.frame_extractor import extract_frames
-from detect.stages.scene_filter import scene_filter
+from core.detect.profile import get_profile
+from core.detect.stages.frame_extractor import extract_frames
+from core.detect.stages.scene_filter import scene_filter
 
 logger = logging.getLogger(__name__)

@@ -24,8 +24,8 @@ logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(m
 sys.path.insert(0, ".")
 
-from detect.graph import get_pipeline
-from detect.state import DetectState
+from core.detect.graph import get_pipeline
+from core.detect.state import DetectState
 
 logger = logging.getLogger(__name__)

@@ -39,13 +39,13 @@ sys.path.insert(0, ".")
 from langgraph.graph import END, StateGraph
 
-from detect import emit
-from detect.models import PipelineStats
-from detect.profiles.soccer import SoccerBroadcastProfile
-from detect.stages.frame_extractor import extract_frames
-from detect.stages.scene_filter import scene_filter
-from detect.stages.edge_detector import detect_edge_regions
-from detect.state import DetectState
+from core.detect import emit
+from core.detect.models import PipelineStats
+from core.detect.profile import get_profile
+from core.detect.stages.frame_extractor import extract_frames
+from core.detect.stages.scene_filter import scene_filter
+from core.detect.stages.edge_detector import detect_edge_regions
+from core.detect.state import DetectState
 
 logger = logging.getLogger(__name__)

@@ -166,7 +166,7 @@ def main():
     # --- Parameter sensitivity ---
     logger.info("=== Parameter sensitivity (local debug) ===")
 
-    from detect.stages.edge_detector import _load_cv_edges
+    from core.detect.stages.edge_detector import _load_cv_edges
     edges_mod = _load_cv_edges()
 
     filtered = result.get("filtered_frames", [])

@@ -58,7 +58,7 @@ def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int):
     import numpy as np
     from PIL import Image
 
-    from detect.models import Frame
+    from core.detect.models import Frame
 
     tmpdir = tempfile.mkdtemp(prefix="scenario_")
     pattern = os.path.join(tmpdir, "frame_%04d.jpg")

@@ -111,7 +111,7 @@ def main():
     logger.info("Extracted %d frames", len(frames))
 
     # Create timeline + branch + checkpoint
-    from detect.checkpoint.storage import create_timeline, save_stage_output
+    from core.detect.checkpoint.storage import create_timeline, save_stage_output
 
     timeline_id, branch_id = create_timeline(
         source_video=video_path,

@@ -58,7 +58,7 @@ def make_brand_image(text: str, width: int = 300, height: int = 100) -> str:
 def main():
-    from detect.providers import get_provider, has_api_key, PROVIDERS
+    from core.detect.providers import get_provider, has_api_key, PROVIDERS
 
     provider_name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
     logger.info("Provider: %s", provider_name)

@@ -13,8 +13,8 @@ import sys
 sys.path.insert(0, ".")
 
-from detect.profiles.soccer import SoccerBroadcastProfile
-from detect.stages.frame_extractor import extract_frames
+from core.detect.profile import get_profile
+from core.detect.stages.frame_extractor import extract_frames
 
 logger = logging.getLogger(__name__)

@@ -86,9 +86,9 @@ def test_ocr_stage_remote(url: str):
     logger.info("--- OCR stage (remote mode) ---")
 
     sys.path.insert(0, ".")
-    from detect.models import BoundingBox, Frame
-    from detect.profiles.base import OCRConfig
-    from detect.stages.ocr_stage import run_ocr
+    from core.detect.models import BoundingBox, Frame
+    from core.detect.stages.models import OCRConfig
+    from core.detect.stages.ocr_stage import run_ocr
 
     # Create a frame with text baked in
     image = make_text_image("EMIRATES")

@@ -48,10 +48,10 @@ def main():
     # Override Redis to localhost (ctrl/.env has k8s hostname)
     os.environ["REDIS_URL"] = f"redis://localhost:{args.port}/0"
 
-    from detect.graph import get_pipeline, NODES
-    from detect.checkpoint import list_checkpoints
-    from detect.checkpoint import replay_from
-    from detect.state import DetectState
+    from core.detect.graph import get_pipeline, NODES
+    from core.detect.checkpoint import list_checkpoints
+    from core.detect.checkpoint import replay_from
+    from core.detect.state import DetectState
 
 VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4"
@@ -2,8 +2,8 @@
 import pytest
 
-from detect.models import BoundingBox, BrandDetection, PipelineStats
-from detect.stages.aggregator import compile_report, _merge_contiguous
+from core.detect.models import BoundingBox, BrandDetection, PipelineStats
+from core.detect.stages.aggregator import compile_report, _merge_contiguous
 
 
 def _make_detection(brand: str, timestamp: float, duration: float = 0.5,

@@ -43,7 +43,7 @@ def test_merge_empty():
 def test_compile_report(monkeypatch):
     events = []
-    monkeypatch.setattr("detect.emit.push_detect_event",
+    monkeypatch.setattr("core.detect.emit.push_detect_event",
                         lambda job_id, etype, data: events.append((etype, data)))
 
     dets = [

@@ -3,9 +3,9 @@
 import numpy as np
 import pytest
 
-from detect.models import BoundingBox, Frame, TextCandidate
-from detect.profiles.base import ResolverConfig
-from detect.stages.brand_resolver import resolve_brands, _normalize, _match_session
+from core.detect.models import BoundingBox, Frame, TextCandidate
+from core.detect.stages.models import ResolverConfig
+from core.detect.stages.brand_resolver import resolve_brands, _normalize, _match_session
 
 
 CONFIG = ResolverConfig(fuzzy_threshold=75)

@@ -28,7 +28,7 @@ def test_session_match():
 def test_resolve_with_session(monkeypatch):
     events = []
-    monkeypatch.setattr("detect.emit.push_detect_event",
+    monkeypatch.setattr("core.detect.emit.push_detect_event",
                         lambda job_id, etype, data: events.append((etype, data)))
 
     session = {"nike": "Nike", "emirates": "Emirates"}

@@ -46,7 +46,7 @@ def test_resolve_with_session(monkeypatch):
 def test_resolve_unresolved_without_db(monkeypatch):
     events = []
-    monkeypatch.setattr("detect.emit.push_detect_event",
+    monkeypatch.setattr("core.detect.emit.push_detect_event",
                         lambda job_id, etype, data: events.append((etype, data)))
 
     candidates = [_make_candidate("random garbage text")]

@@ -61,7 +61,7 @@ def test_resolve_unresolved_without_db(monkeypatch):
 def test_resolve_empty(monkeypatch):
     events = []
-    monkeypatch.setattr("detect.emit.push_detect_event",
+    monkeypatch.setattr("core.detect.emit.push_detect_event",
                         lambda job_id, etype, data: events.append((etype, data)))
 
     matched, unresolved = resolve_brands([], CONFIG, session_brands={})

@@ -73,7 +73,7 @@ def test_resolve_empty(monkeypatch):
 def test_resolve_builds_session_during_run(monkeypatch):
     """Session brands accumulate during a single run — second candidate benefits."""
     events = []
-    monkeypatch.setattr("detect.emit.push_detect_event",
+    monkeypatch.setattr("core.detect.emit.push_detect_event",
                         lambda job_id, etype, data: events.append((etype, data)))
 
     session = {"nike": "Nike"}

@@ -93,7 +93,7 @@ def test_resolve_builds_session_during_run(monkeypatch):
 def test_events_emitted(monkeypatch):
     events = []
-    monkeypatch.setattr("detect.emit.push_detect_event",
+    monkeypatch.setattr("core.detect.emit.push_detect_event",
                         lambda job_id, etype, data: events.append((etype, data)))
 
     session = {"nike": "Nike"}

@@ -5,7 +5,7 @@ import json
 import numpy as np
 import pytest
 
-from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
+from core.detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
 from core.schema.serializers._common import safe_construct
 from core.schema.serializers.pipeline import (
     serialize_frame_meta,

@@ -163,34 +163,39 @@ def test_all_serialized_is_json_compatible():
     assert roundtrip["frame_meta"]["sequence"] == frame.sequence
 
 
-# --- OverrideProfile ---
+# --- Config overrides (dict merging, replaces OverrideProfile) ---
 
-def test_override_profile_region_analysis():
-    """OverrideProfile must patch region_analysis_config with overrides."""
-    from detect.checkpoint.replay import OverrideProfile
-    from detect.profiles.soccer import SoccerBroadcastProfile
-    from detect.profiles.base import RegionAnalysisConfig
+def test_config_override_region_analysis():
+    """Config overrides must patch stage config values."""
+    from core.detect.profile import get_profile, get_stage_config
+    from core.detect.stages.models import RegionAnalysisConfig
 
-    base = SoccerBroadcastProfile()
-    original = base.region_analysis_config()
+    profile = get_profile("soccer_broadcast")
+    original = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
 
-    overrides = {"region_analysis": {"edge_canny_low": 25, "edge_canny_high": 200}}
-    wrapped = OverrideProfile(base, overrides)
-    patched = wrapped.region_analysis_config()
+    overrides = {"detect_edges": {"edge_canny_low": 25, "edge_canny_high": 200}}
+    merged_configs = {**profile["configs"]}
+    merged_configs["detect_edges"] = {**merged_configs["detect_edges"], **overrides["detect_edges"]}
+    patched_profile = {**profile, "configs": merged_configs}
+
+    patched = RegionAnalysisConfig(**get_stage_config(patched_profile, "detect_edges"))
 
     assert isinstance(patched, RegionAnalysisConfig)
    assert patched.edge_canny_low == 25
     assert patched.edge_canny_high == 200
     # Unmodified fields keep their defaults
     assert patched.edge_hough_threshold == original.edge_hough_threshold
 
 
-def test_override_profile_passthrough():
-    """OverrideProfile without region_analysis key passes through unchanged."""
-    from detect.checkpoint.replay import OverrideProfile
-    from detect.profiles.soccer import SoccerBroadcastProfile
+def test_config_override_passthrough():
+    """Overrides for other stages don't affect unrelated stages."""
+    from core.detect.profile import get_profile, get_stage_config
+    from core.detect.stages.models import RegionAnalysisConfig
 
-    base = SoccerBroadcastProfile()
-    wrapped = OverrideProfile(base, {"ocr": {"min_confidence": 0.1}})
-    config = wrapped.region_analysis_config()
-    assert config.edge_canny_low == base.region_analysis_config().edge_canny_low
+    profile = get_profile("soccer_broadcast")
+    original = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
+
+    overrides = {"run_ocr": {"min_confidence": 0.1}}
+    merged_configs = {**profile["configs"], **overrides}
+    patched_profile = {**profile, "configs": merged_configs}
+
+    patched = RegionAnalysisConfig(**get_stage_config(patched_profile, "detect_edges"))
+    assert patched.edge_canny_low == original.edge_canny_low
@@ -1,6 +1,6 @@
 """Tests for the config endpoint and stage palette."""
 
-from detect.stages import list_stages, get_palette
+from core.detect.stages import list_stages, get_palette
 
 
 def test_stage_palette_has_config_fields():

@@ -15,7 +15,7 @@ import pytest
 # Load edges module directly
 _spec = importlib.util.spec_from_file_location(
-    "cv_edges", Path("gpu/models/cv/edges.py"),
+    "cv_edges", Path("core/gpu/models/cv/edges.py"),
 )
 _edges_mod = importlib.util.module_from_spec(_spec)
 _spec.loader.exec_module(_edges_mod)

@@ -5,8 +5,8 @@ from pathlib import Path
 import pytest
 
-from detect.profiles.base import FrameExtractionConfig
-from detect.stages.frame_extractor import extract_frames
+from core.detect.stages.models import FrameExtractionConfig
+from core.detect.stages.frame_extractor import extract_frames
 
 SAMPLE_DIR = Path("media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e")

@@ -61,7 +61,7 @@ def test_extract_frames_with_events(monkeypatch):
     def mock_push(job_id, event_type, data):
         events.append((job_id, event_type, data))
 
-    monkeypatch.setattr("detect.emit.push_detect_event", mock_push)
+    monkeypatch.setattr("core.detect.emit.push_detect_event", mock_push)
 
     video = _get_sample_video()
     config = FrameExtractionConfig(fps=1, max_frames=5)

@@ -4,9 +4,9 @@ import os
 import pytest
 
-from detect.graph import NODES, build_graph, get_pipeline
-from detect.models import PipelineStats
-from detect.state import DetectState
+from core.detect.graph import NODES, build_graph, get_pipeline
+from core.detect.models import PipelineStats
+from core.detect.state import DetectState
 
 VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4"

@@ -42,7 +42,7 @@ def test_graph_has_all_nodes():
 def test_graph_runs_end_to_end(monkeypatch):
     """Run the full graph with mocked event emission."""
     events = []
-    monkeypatch.setattr("detect.emit.push_detect_event",
+    monkeypatch.setattr("core.detect.emit.push_detect_event",
                         lambda job_id, etype, data: events.append((etype, data)))
 
     pipeline = get_pipeline()

@@ -75,7 +75,7 @@ def test_graph_runs_end_to_end(monkeypatch):
 def test_graph_node_transitions(monkeypatch):
     """Verify each node emits running → done transitions."""
     events = []
-    monkeypatch.setattr("detect.emit.push_detect_event",
+    monkeypatch.setattr("core.detect.emit.push_detect_event",
                         lambda job_id, etype, data: events.append((etype, data)))
 
     pipeline = get_pipeline()

@@ -3,9 +3,9 @@
 import numpy as np
 import pytest
 
-from detect.models import BoundingBox, Frame
-from detect.profiles.base import OCRConfig
-from detect.stages.ocr_stage import _crop_region, _parse_ocr_raw, run_ocr
+from core.detect.models import BoundingBox, Frame
+from core.detect.stages.models import OCRConfig
+from core.detect.stages.ocr_stage import _crop_region, _parse_ocr_raw, run_ocr
 
 
 def _has_paddleocr() -> bool:

@@ -80,7 +80,7 @@ def test_parse_empty():
 def test_run_ocr_remote(monkeypatch):
     events = []
-    monkeypatch.setattr("detect.emit.push_detect_event",
+    monkeypatch.setattr("core.detect.emit.push_detect_event",
                         lambda job_id, etype, data: events.append((etype, data)))
 
     class FakeResult:

@@ -94,11 +94,11 @@ def test_run_ocr_remote(monkeypatch):
         def ocr(self, image, languages):
             return [FakeResult("NIKE", 0.92)]
 
-    monkeypatch.setattr("detect.stages.ocr_stage.InferenceClient", FakeClient,
+    monkeypatch.setattr("core.detect.stages.ocr_stage.InferenceClient", FakeClient,
                         raising=False)
     # Patch the import path used in the function
-    import detect.stages.ocr_stage as mod
-    monkeypatch.setattr("detect.inference.InferenceClient", FakeClient)
+    import core.detect.stages.ocr_stage as mod
+    monkeypatch.setattr("core.detect.inference.InferenceClient", FakeClient)
 
     frame = _make_frame()
     box = _make_box()

@@ -123,7 +123,7 @@ test_run_ocr_remote
 )
 def test_run_ocr_skips_empty_crop(monkeypatch):
     events = []
-    monkeypatch.setattr("detect.emit.push_detect_event",
+    monkeypatch.setattr("core.detect.emit.push_detect_event",
                         lambda job_id, etype, data: events.append((etype, data)))
 
     frame = _make_frame(w=10, h=10)

@@ -26,7 +26,7 @@ def _make_image(w: int = 200, h: int = 60) -> np.ndarray:
 @requires_cv2
 def test_binarize():
-    from gpu.models.preprocess import binarize
+    from core.gpu.models.preprocess import binarize
 
     img = _make_image()
     result = binarize(img)

@@ -40,7 +40,7 @@ def test_binarize():
 @requires_cv2
 def test_enhance_contrast():
-    from gpu.models.preprocess import enhance_contrast
+    from core.gpu.models.preprocess import enhance_contrast
 
     img = _make_image()
     result = enhance_contrast(img)

@@ -51,7 +51,7 @@ ...
 @requires_cv2
 def test_deskew_no_rotation():
-    from gpu.models.preprocess import deskew
+    from core.gpu.models.preprocess import deskew
 
     img = _make_image()
     result = deskew(img)

@@ -63,7 +63,7 @@ ...
 @requires_cv2
 def test_preprocess_pipeline():
-    from gpu.models.preprocess import preprocess
+    from core.gpu.models.preprocess import preprocess
 
     img = _make_image()

@@ -76,7 +76,7 @@ ...
 @requires_cv2
 def test_preprocess_all_disabled():
-    from gpu.models.preprocess import preprocess
+    from core.gpu.models.preprocess import preprocess
 
     img = _make_image()
     result = preprocess(img, do_binarize=False, do_deskew=False, do_contrast=False)
@@ -1,55 +1,70 @@
-"""Tests for ContentTypeProfile implementations."""
+"""Tests for profile data and helper functions."""
 
 import pytest
 
-from detect.models import BrandDetection
-from detect.profiles.base import ContentTypeProfile, CropContext
-from detect.profiles.soccer import SoccerBroadcastProfile
-from detect.profiles.stubs import AdvertisingProfile, NewsBroadcastProfile, TranscriptProfile
+from core.detect.models import BrandDetection, CropContext
+from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections, pipeline_config_from_dict
+from core.detect.stages.models import FrameExtractionConfig, DetectionConfig, ResolverConfig
 
 
-def test_soccer_satisfies_protocol():
-    profile: ContentTypeProfile = SoccerBroadcastProfile()
-    assert profile.name == "soccer_broadcast"
+def test_soccer_profile_exists():
+    profile = get_profile("soccer_broadcast")
+    assert profile["name"] == "soccer_broadcast"
+
+
+def test_soccer_has_pipeline():
+    profile = get_profile("soccer_broadcast")
+    assert "stages" in profile["pipeline"]
+    assert "edges" in profile["pipeline"]
+
+
+def test_soccer_has_configs():
+    profile = get_profile("soccer_broadcast")
+    configs = profile["configs"]
+    assert "extract_frames" in configs
+    assert "filter_scenes" in configs
+    assert "detect_edges" in configs
 
 
 def test_soccer_frame_extraction_config():
-    cfg = SoccerBroadcastProfile().frame_extraction_config()
+    profile = get_profile("soccer_broadcast")
+    cfg = FrameExtractionConfig(**get_stage_config(profile, "extract_frames"))
     assert cfg.fps > 0
     assert cfg.max_frames > 0
 
 
 def test_soccer_detection_config():
-    cfg = SoccerBroadcastProfile().detection_config()
+    profile = get_profile("soccer_broadcast")
+    cfg = DetectionConfig(**get_stage_config(profile, "detect_objects"))
     assert 0 < cfg.confidence_threshold < 1
     assert isinstance(cfg.target_classes, list)
 
 
 def test_soccer_resolver_config():
-    cfg = SoccerBroadcastProfile().resolver_config()
+    profile = get_profile("soccer_broadcast")
+    cfg = ResolverConfig(**get_stage_config(profile, "match_brands"))
     assert cfg.fuzzy_threshold > 0
 
 
-def test_soccer_vlm_prompt():
+def test_vlm_prompt():
     ctx = CropContext(image=b"fake", surrounding_text="Emirates", position_hint="top-center")
-    prompt = SoccerBroadcastProfile().vlm_prompt(ctx)
+    template = get_profile("soccer_broadcast")["configs"]["escalate_vlm"]["vlm_prompt_template"]
+    prompt = build_vlm_prompt(ctx, template)
     assert "brand" in prompt.lower()
     assert "Emirates" in prompt
 
 
-def test_soccer_aggregate_empty():
-    report = SoccerBroadcastProfile().aggregate([])
+def test_aggregate_empty():
+    report = aggregate_detections([], "soccer_broadcast")
     assert len(report.brands) == 0
     assert len(report.timeline) == 0
 
 
-def test_soccer_aggregate_groups():
+def test_aggregate_groups():
     detections = [
         BrandDetection(brand="Nike", timestamp=1.0, duration=0.5, confidence=0.9, source="ocr"),
         BrandDetection(brand="Nike", timestamp=2.0, duration=0.5, confidence=0.8, source="ocr"),
         BrandDetection(brand="Adidas", timestamp=3.0, duration=0.5, confidence=0.7, source="logo_match"),
     ]
-    report = SoccerBroadcastProfile().aggregate(detections)
+    report = aggregate_detections(detections, "soccer_broadcast")
     assert "Nike" in report.brands
     assert "Adidas" in report.brands
     assert report.brands["Nike"].total_appearances == 2

@@ -57,15 +72,9 @@ def test_soccer_aggregate_groups():
     assert report.timeline == sorted(report.timeline, key=lambda d: d.timestamp)
 
 
-def test_soccer_auxiliary_returns_empty():
-    assert SoccerBroadcastProfile().auxiliary_detections("test.mp4") == []
-
-
-@pytest.mark.parametrize("stub_cls", [NewsBroadcastProfile, AdvertisingProfile, TranscriptProfile])
-def test_stubs_raise(stub_cls):
-    stub = stub_cls()
-    assert isinstance(stub.name, str)
-    with pytest.raises(NotImplementedError):
-        stub.frame_extraction_config()
-    with pytest.raises(NotImplementedError):
-        stub.resolver_config()
+def test_pipeline_config():
+    profile = get_profile("soccer_broadcast")
+    config = pipeline_config_from_dict(profile["pipeline"])
+    assert config.name == "soccer_broadcast"
+    assert len(config.stages) > 0
+    assert len(config.edges) > 0
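These rewritten tests read profiles as plain dicts instead of ContentTypeProfile classes. get_stage_config is presumably a thin lookup into profile["configs"]; a sketch of the assumed behavior and the usage pattern the tests rely on:

def get_stage_config_sketch(profile: dict, stage: str) -> dict:
    # Presumed: return the stage's config mapping, empty if the stage has none.
    return profile.get("configs", {}).get(stage, {})

# Usage mirroring the tests:
#   profile = get_profile("soccer_broadcast")
#   cfg = FrameExtractionConfig(**get_stage_config_sketch(profile, "extract_frames"))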
@@ -6,14 +6,14 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, Frame
|
||||
from detect.profiles.base import RegionAnalysisConfig
|
||||
from detect.profiles.soccer import SoccerBroadcastProfile
|
||||
from core.detect.models import BoundingBox, Frame
|
||||
from core.detect.stages.models import RegionAnalysisConfig
|
||||
from core.detect.profile import get_profile, get_stage_config
|
||||
|
||||
|
||||
# Load edges module directly — gpu/models/__init__.py has GPU-only imports
|
||||
_spec = importlib.util.spec_from_file_location(
|
||||
"cv_edges", Path("gpu/models/cv/edges.py"),
|
||||
"cv_edges", Path("core/gpu/models/cv/edges.py"),
|
||||
)
|
||||
_edges_mod = importlib.util.module_from_spec(_spec)
|
||||
_spec.loader.exec_module(_edges_mod)
@@ -40,8 +40,8 @@ def _make_frame_with_lines(seq: int = 0) -> Frame:
# --- Config ---

def test_soccer_profile_has_region_analysis_config():
    profile = SoccerBroadcastProfile()
    config = profile.region_analysis_config()
    profile = get_profile("soccer_broadcast")
    config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
    assert isinstance(config, RegionAnalysisConfig)
    assert config.enabled is True

@@ -133,9 +133,9 @@ def test_detect_edges_debug_blank_frame():

def test_stage_disabled(monkeypatch):
    """When disabled, returns empty dict."""
    monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None)
    monkeypatch.setattr("core.detect.emit.push_detect_event", lambda *a, **kw: None)

    from detect.stages.edge_detector import detect_edge_regions
    from core.detect.stages.edge_detector import detect_edge_regions

    config = RegionAnalysisConfig(enabled=False)
    result = detect_edge_regions([_make_frame()], config, job_id="test")
@@ -144,9 +144,9 @@ def test_stage_disabled(monkeypatch):

def test_stage_local_blank(monkeypatch):
    """Local mode on blank frames returns empty boxes."""
    monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None)
    monkeypatch.setattr("core.detect.emit.push_detect_event", lambda *a, **kw: None)

    from detect.stages.edge_detector import detect_edge_regions
    from core.detect.stages.edge_detector import detect_edge_regions

    config = RegionAnalysisConfig()
    result = detect_edge_regions([_make_frame()], config, job_id="test")
@@ -156,9 +156,9 @@ def test_stage_local_blank(monkeypatch):

def test_stage_local_with_lines(monkeypatch):
    """Local mode on frame with lines should find regions."""
    monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None)
    monkeypatch.setattr("core.detect.emit.push_detect_event", lambda *a, **kw: None)

    from detect.stages.edge_detector import detect_edge_regions
    from core.detect.stages.edge_detector import detect_edge_regions

    config = RegionAnalysisConfig()
    frame = _make_frame_with_lines()
@@ -174,22 +174,22 @@ def test_stage_local_with_lines(monkeypatch):

def test_detect_edges_in_nodes():
    """detect_edges must be in the pipeline node list."""
    from detect.graph import NODES, NODE_FUNCTIONS
    from core.detect.graph import NODES, NODE_FUNCTIONS

    assert "detect_edges" in NODES
    node_names = [name for name, _ in NODE_FUNCTIONS]
    assert "detect_edges" in node_names

    # Must be after filter_scenes, before detect_objects
    # Must be after field_segmentation, before detect_objects
    idx = NODES.index("detect_edges")
    assert NODES[idx - 1] == "filter_scenes"
    assert NODES[idx - 1] == "field_segmentation"
    assert NODES[idx + 1] == "detect_objects"


# --- State ---

def test_state_has_edge_regions_field():
    from detect.state import DetectState
    from core.detect.state import DetectState

    hints = DetectState.__annotations__
    assert "edge_regions_by_frame" in hints

@@ -1,87 +1,67 @@
"""Tests for replay and OverrideProfile."""
"""Tests for config overrides and replay."""

import pytest

from detect.profiles.soccer import SoccerBroadcastProfile
from detect.profiles.base import RegionAnalysisConfig
from detect.checkpoint.replay import OverrideProfile, replay_single_stage
from core.detect.profile import get_profile, get_stage_config
from core.detect.stages.models import RegionAnalysisConfig, OCRConfig, ResolverConfig
from core.detect.checkpoint.replay import replay_single_stage


def test_override_profile_patches_ocr():
    base = SoccerBroadcastProfile()
    overrides = {"ocr": {"min_confidence": 0.3, "languages": ["en", "es", "pt"]}}
    profile = OverrideProfile(base, overrides)
def _apply_overrides(profile, overrides):
    """Apply config overrides to a profile dict (same logic as nodes._load_profile)."""
    merged_configs = dict(profile.get("configs", {}))
    for stage_name, stage_overrides in overrides.items():
        if stage_name in merged_configs:
            merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides}
        else:
            merged_configs[stage_name] = stage_overrides
    return {**profile, "configs": merged_configs}
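# Illustrative merge semantics (hypothetical values, not from the real profile):
#   profile   = {"configs": {"run_ocr": {"min_confidence": 0.5, "languages": ["en"]}}}
#   overrides = {"run_ocr": {"min_confidence": 0.3}}
#   -> configs["run_ocr"] == {"min_confidence": 0.3, "languages": ["en"]}
# Overrides patch individual keys per stage; untouched keys keep profile values.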

    config = profile.ocr_config()

def test_override_patches_ocr():
    profile = get_profile("soccer_broadcast")
    overrides = {"run_ocr": {"min_confidence": 0.3, "languages": ["en", "es", "pt"]}}
    patched = _apply_overrides(profile, overrides)

    config = OCRConfig(**get_stage_config(patched, "run_ocr"))

    assert config.min_confidence == 0.3
    assert config.languages == ["en", "es", "pt"]


def test_override_profile_patches_resolver():
    base = SoccerBroadcastProfile()
    overrides = {"resolver": {"fuzzy_threshold": 60}}
    profile = OverrideProfile(base, overrides)
def test_override_patches_resolver():
    profile = get_profile("soccer_broadcast")
    overrides = {"match_brands": {"fuzzy_threshold": 60}}
    patched = _apply_overrides(profile, overrides)

    config = profile.resolver_config()
    config = ResolverConfig(**get_stage_config(patched, "match_brands"))

    assert config.fuzzy_threshold == 60


def test_override_profile_patches_detection():
    base = SoccerBroadcastProfile()
    overrides = {"detection": {"confidence_threshold": 0.5}}
    profile = OverrideProfile(base, overrides)
def test_override_no_overrides():
    profile = get_profile("soccer_broadcast")
    patched = _apply_overrides(profile, {})

    config = profile.detection_config()

    assert config.confidence_threshold == 0.5


def test_override_profile_no_overrides():
    base = SoccerBroadcastProfile()
    profile = OverrideProfile(base, {})

    ocr = profile.ocr_config()
    base_ocr = base.ocr_config()
    ocr = OCRConfig(**get_stage_config(patched, "run_ocr"))
    base_ocr = OCRConfig(**get_stage_config(profile, "run_ocr"))

    assert ocr.min_confidence == base_ocr.min_confidence
    assert ocr.languages == base_ocr.languages


def test_override_profile_delegates_non_config():
    base = SoccerBroadcastProfile()
    profile = OverrideProfile(base, {"ocr": {"min_confidence": 0.1}})
def test_override_patches_region_analysis():
    profile = get_profile("soccer_broadcast")
    overrides = {"detect_edges": {"edge_canny_low": 25, "edge_canny_high": 200}}
    patched = _apply_overrides(profile, overrides)

    assert profile.name == "soccer_broadcast"
    assert profile.resolver_config().fuzzy_threshold > 0
    config = RegionAnalysisConfig(**get_stage_config(patched, "detect_edges"))


def test_override_profile_ignores_unknown_fields():
    base = SoccerBroadcastProfile()
    overrides = {"ocr": {"nonexistent_field": 42}}
    profile = OverrideProfile(base, overrides)

    config = profile.ocr_config()

    assert not hasattr(config, "nonexistent_field")
    assert config.min_confidence == base.ocr_config().min_confidence


# --- OverrideProfile for region_analysis ---

def test_override_profile_patches_region_analysis():
    base = SoccerBroadcastProfile()
    overrides = {"region_analysis": {"edge_canny_low": 25, "edge_canny_high": 200}}
    profile = OverrideProfile(base, overrides)

    config = profile.region_analysis_config()

    assert isinstance(config, RegionAnalysisConfig)
    assert config.edge_canny_low == 25
    assert config.edge_canny_high == 200
    # Unchanged fields keep defaults
    assert config.edge_hough_threshold == base.region_analysis_config().edge_hough_threshold
    # Unchanged fields keep defaults from profile
    base_config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
    assert config.edge_hough_threshold == base_config.edge_hough_threshold


# --- replay_single_stage ---

@@ -3,9 +3,9 @@
import numpy as np
import pytest

from detect.models import Frame
from detect.profiles.base import SceneFilterConfig
from detect.stages.scene_filter import scene_filter
from core.detect.models import Frame
from core.detect.stages.models import SceneFilterConfig
from core.detect.stages.scene_filter import scene_filter


def _make_frame(seq: int, color: tuple[int, int, int] = (128, 128, 128)) -> Frame:
@@ -72,7 +72,7 @@ def test_hashes_populated():

def test_events_emitted(monkeypatch):
    events = []
    monkeypatch.setattr("detect.emit.push_detect_event",
    monkeypatch.setattr("core.detect.emit.push_detect_event",
                        lambda job_id, etype, data: events.append((etype, data)))
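    # The lambda stands in for the real SSE emitter: each event is captured as an
    # (etype, data) tuple so the test can assert on what was emitted.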

    frames = [_make_frame(i) for i in range(5)]

@@ -1,6 +1,6 @@
"""Round-trip serialization tests for SSE contract models."""

from detect.sse_contract import (
from core.detect.sse import (
    BoundingBoxEvent,
    BrandSummary,
    Detection,

@@ -1,7 +1,7 @@
"""Tests for the stage registry."""

from detect.stages import list_stages, get_stage, get_palette
from detect.stages.base import get_stage_class
from core.detect.stages import list_stages, get_stage, get_palette
from core.detect.stages.base import get_stage_class


EXPECTED_STAGES = [

@@ -2,7 +2,7 @@

import pytest

from detect.tracing import trace_node, SpanContext, flush
from core.detect.tracing import trace_node, SpanContext, flush


def test_trace_node_noop():

@@ -3,8 +3,8 @@
import numpy as np
import pytest

from detect.models import BoundingBox, Frame, PipelineStats, TextCandidate
from detect.stages.vlm_cloud import escalate_cloud, _parse_response
from core.detect.models import BoundingBox, Frame, PipelineStats, TextCandidate
from core.detect.stages.vlm_cloud import escalate_cloud, _parse_response


def _make_candidate(text: str = "unknown", confidence: float = 0.4) -> TextCandidate:
@@ -30,14 +30,14 @@ def test_parse_response_no_confidence():

def test_escalate_skips_without_api_key(monkeypatch):
    events = []
    monkeypatch.setattr("detect.emit.push_detect_event",
    monkeypatch.setattr("core.detect.emit.push_detect_event",
                        lambda job_id, etype, data: events.append((etype, data)))
    monkeypatch.delenv("GROQ_API_KEY", raising=False)
    monkeypatch.delenv("GEMINI_API_KEY", raising=False)
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
    # Reset cached provider
    import detect.providers as prov
    import core.detect.providers as prov
    monkeypatch.setattr(prov, "_cached", None)
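    # _cached looks like a module-level memo of the resolved provider; clearing it
    # forces re-resolution from the env vars configured above.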

    candidates = [_make_candidate()]
@@ -54,7 +54,7 @@ def test_escalate_skips_without_api_key(monkeypatch):

def test_escalate_empty_candidates(monkeypatch):
    events = []
    monkeypatch.setattr("detect.emit.push_detect_event",
    monkeypatch.setattr("core.detect.emit.push_detect_event",
                        lambda job_id, etype, data: events.append((etype, data)))

    stats = PipelineStats()
@@ -66,18 +66,18 @@ def test_escalate_empty_candidates(monkeypatch):

def test_escalate_with_mock_api(monkeypatch):
    events = []
    monkeypatch.setattr("detect.emit.push_detect_event",
    monkeypatch.setattr("core.detect.emit.push_detect_event",
                        lambda job_id, etype, data: events.append((etype, data)))
    monkeypatch.setenv("GROQ_API_KEY", "test-key")
    monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
    # Reset cached provider
    import detect.providers as prov
    import core.detect.providers as prov
    monkeypatch.setattr(prov, "_cached", None)

    def mock_call(image_b64, prompt):
        return {"brand": "Heineken", "confidence": 0.75, "reasoning": "green logo", "tokens": 300}

    monkeypatch.setattr("detect.stages.vlm_cloud._call_cloud_api", mock_call)
    monkeypatch.setattr("core.detect.stages.vlm_cloud._call_cloud_api", mock_call)

    candidates = [_make_candidate("unknown logo")]
    stats = PipelineStats()
