schema clean up and refactor
This commit is contained in:
97
core/schema/models/detect_pipeline.py
Normal file
97
core/schema/models/detect_pipeline.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""
|
||||||
|
Detection pipeline runtime models.
|
||||||
|
|
||||||
|
These are the data structures that flow between LangGraph nodes.
|
||||||
|
They contain runtime types (np.ndarray) so they are NOT generated
|
||||||
|
by modelgen — they live here for the schema to be the complete
|
||||||
|
map of the application, but modelgen skips them.
|
||||||
|
|
||||||
|
Wire-format models (SSE events) are in detect.py.
|
||||||
|
DB models (jobs, checkpoints) are in detect_jobs.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Frame:
|
||||||
|
sequence: int
|
||||||
|
chunk_id: int
|
||||||
|
timestamp: float # position in video (seconds)
|
||||||
|
image: np.ndarray
|
||||||
|
perceptual_hash: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BoundingBox:
|
||||||
|
x: int
|
||||||
|
y: int
|
||||||
|
w: int
|
||||||
|
h: int
|
||||||
|
confidence: float
|
||||||
|
label: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextCandidate:
|
||||||
|
frame: Frame
|
||||||
|
bbox: BoundingBox
|
||||||
|
text: str
|
||||||
|
ocr_confidence: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BrandDetection:
|
||||||
|
brand: str
|
||||||
|
timestamp: float
|
||||||
|
duration: float
|
||||||
|
confidence: float
|
||||||
|
source: Literal["ocr", "local_vlm", "cloud_llm", "logo_match", "auxiliary"]
|
||||||
|
bbox: BoundingBox | None = None
|
||||||
|
frame_ref: int | None = None
|
||||||
|
content_type: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BrandStats:
|
||||||
|
total_appearances: int = 0
|
||||||
|
total_screen_time: float = 0.0
|
||||||
|
avg_confidence: float = 0.0
|
||||||
|
first_seen: float = 0.0
|
||||||
|
last_seen: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineStats:
|
||||||
|
frames_extracted: int = 0
|
||||||
|
frames_after_scene_filter: int = 0
|
||||||
|
regions_detected: int = 0
|
||||||
|
regions_resolved_by_ocr: int = 0
|
||||||
|
regions_escalated_to_local_vlm: int = 0
|
||||||
|
regions_escalated_to_cloud_llm: int = 0
|
||||||
|
auxiliary_detections: int = 0
|
||||||
|
cloud_llm_calls: int = 0
|
||||||
|
processing_time_seconds: float = 0.0
|
||||||
|
estimated_cloud_cost_usd: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DetectionReport:
|
||||||
|
video_source: str
|
||||||
|
content_type: str
|
||||||
|
duration_seconds: float
|
||||||
|
brands: dict[str, BrandStats] = field(default_factory=dict)
|
||||||
|
timeline: list[BrandDetection] = field(default_factory=list)
|
||||||
|
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
|
||||||
|
|
||||||
|
|
||||||
|
# Not in DATACLASSES — modelgen skips these (they contain np.ndarray)
|
||||||
|
RUNTIME_MODELS = [
|
||||||
|
Frame, BoundingBox, TextCandidate, BrandDetection,
|
||||||
|
BrandStats, PipelineStats, DetectionReport,
|
||||||
|
]
|
||||||
11
core/schema/serializers/__init__.py
Normal file
11
core/schema/serializers/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""
|
||||||
|
Model serializers — one module per model group, mirroring core/schema/models/.
|
||||||
|
|
||||||
|
models/detect_pipeline.py → serializers/detect_pipeline.py
|
||||||
|
models/detect_jobs.py → serializers/detect_jobs.py
|
||||||
|
models/detect.py → serializers/detect.py (SSE events)
|
||||||
|
|
||||||
|
Common utilities in _common.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from ._common import safe_construct, serialize_dataclass, serialize_dataclass_list
|
||||||
38
core/schema/serializers/_common.py
Normal file
38
core/schema/serializers/_common.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""Common serialization utilities."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_construct(cls, data: dict):
|
||||||
|
"""
|
||||||
|
Construct a dataclass from a dict, tolerant of schema changes.
|
||||||
|
|
||||||
|
- Ignores keys not in the dataclass (field was removed)
|
||||||
|
- Uses defaults for missing keys (field was added)
|
||||||
|
- Logs at debug level for mismatches
|
||||||
|
"""
|
||||||
|
field_names = {f.name for f in dataclasses.fields(cls)}
|
||||||
|
|
||||||
|
known = {}
|
||||||
|
for k, v in data.items():
|
||||||
|
if k in field_names:
|
||||||
|
known[k] = v
|
||||||
|
else:
|
||||||
|
logger.debug("Ignoring unknown field %s.%s", cls.__name__, k)
|
||||||
|
|
||||||
|
return cls(**known)
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_dataclass(obj) -> dict:
|
||||||
|
"""Serialize any dataclass to dict via dataclasses.asdict()."""
|
||||||
|
return dataclasses.asdict(obj)
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_dataclass_list(items) -> list[dict]:
|
||||||
|
"""Serialize a list of dataclasses."""
|
||||||
|
return [dataclasses.asdict(item) for item in items]
|
||||||
108
core/schema/serializers/detect_pipeline.py
Normal file
108
core/schema/serializers/detect_pipeline.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""
|
||||||
|
Serializers for detection pipeline runtime models.
|
||||||
|
|
||||||
|
Mirrors core/schema/models/detect_pipeline.py.
|
||||||
|
|
||||||
|
Special handling:
|
||||||
|
- Frame.image (np.ndarray → S3, excluded from JSON)
|
||||||
|
- TextCandidate.frame (object ref → frame_sequence integer)
|
||||||
|
Everything else uses dataclasses.asdict() via safe_construct.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
from core.schema.models.detect_pipeline import (
|
||||||
|
BoundingBox,
|
||||||
|
BrandDetection,
|
||||||
|
BrandStats,
|
||||||
|
DetectionReport,
|
||||||
|
Frame,
|
||||||
|
PipelineStats,
|
||||||
|
TextCandidate,
|
||||||
|
)
|
||||||
|
from ._common import safe_construct, serialize_dataclass, serialize_dataclass_list
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Frame — image goes to S3 separately
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def serialize_frame_meta(frame: Frame) -> dict:
|
||||||
|
"""Serialize Frame metadata only (no image)."""
|
||||||
|
result = dataclasses.asdict(frame)
|
||||||
|
del result["image"]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_frames_with_upload(frames: list[Frame], job_id: str) -> tuple[list[dict], dict[int, str]]:
|
||||||
|
"""Upload frame images to S3, return metadata + manifest."""
|
||||||
|
from detect.checkpoint.frames import save_frames
|
||||||
|
|
||||||
|
manifest = save_frames(job_id, frames)
|
||||||
|
meta = [serialize_frame_meta(f) for f in frames]
|
||||||
|
return meta, manifest
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_frames_with_download(meta: list[dict], manifest: dict, job_id: str) -> list[Frame]:
|
||||||
|
"""Load frames from S3 + metadata."""
|
||||||
|
from detect.checkpoint.frames import load_frames
|
||||||
|
|
||||||
|
int_manifest = {int(k): v for k, v in manifest.items()}
|
||||||
|
return load_frames(int_manifest, meta)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TextCandidate — frame ref is an object, stored as sequence int
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def serialize_text_candidate(tc: TextCandidate) -> dict:
|
||||||
|
bbox_dict = dataclasses.asdict(tc.bbox)
|
||||||
|
result = {
|
||||||
|
"frame_sequence": tc.frame.sequence,
|
||||||
|
"bbox": bbox_dict,
|
||||||
|
"text": tc.text,
|
||||||
|
"ocr_confidence": tc.ocr_confidence,
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_text_candidates(candidates: list[TextCandidate]) -> list[dict]:
|
||||||
|
return [serialize_text_candidate(tc) for tc in candidates]
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_text_candidate(data: dict, frame_map: dict[int, Frame]) -> TextCandidate:
|
||||||
|
frame = frame_map[data["frame_sequence"]]
|
||||||
|
bbox = safe_construct(BoundingBox, data["bbox"])
|
||||||
|
candidate = TextCandidate(
|
||||||
|
frame=frame,
|
||||||
|
bbox=bbox,
|
||||||
|
text=data["text"],
|
||||||
|
ocr_confidence=data["ocr_confidence"],
|
||||||
|
)
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_text_candidates(data: list[dict], frame_map: dict[int, Frame]) -> list[TextCandidate]:
|
||||||
|
return [deserialize_text_candidate(d, frame_map) for d in data]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# BoundingBox, BrandDetection, PipelineStats, etc — standard dataclasses
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def deserialize_bounding_box(data: dict) -> BoundingBox:
|
||||||
|
return safe_construct(BoundingBox, data)
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_brand_detection(data: dict) -> BrandDetection:
|
||||||
|
return safe_construct(BrandDetection, data)
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_pipeline_stats(data: dict) -> PipelineStats:
|
||||||
|
return safe_construct(PipelineStats, data)
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_detection_report(data: dict) -> DetectionReport:
|
||||||
|
return safe_construct(DetectionReport, data)
|
||||||
@@ -1,133 +1,108 @@
|
|||||||
"""State serialization — DetectState ↔ JSON-compatible dict."""
|
"""
|
||||||
|
State serialization — DetectState ↔ JSON-compatible dict.
|
||||||
|
|
||||||
|
Delegates to each stage's serialize_fn/deserialize_fn via the registry.
|
||||||
|
This file has no model-specific knowledge — stages own their data format.
|
||||||
|
|
||||||
|
The only things serialized here are the "envelope" fields (job_id, video_path, etc.)
|
||||||
|
that don't belong to any stage.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
from core.schema.serializers._common import serialize_dataclass
|
||||||
|
from core.schema.serializers.detect_pipeline import (
|
||||||
from detect.models import (
|
deserialize_pipeline_stats,
|
||||||
BoundingBox,
|
deserialize_text_candidates,
|
||||||
BrandDetection,
|
|
||||||
Frame,
|
|
||||||
PipelineStats,
|
|
||||||
TextCandidate,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# Envelope fields — not owned by any stage, always present
|
||||||
# Serialize helpers
|
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "config_overrides"]
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def serialize_frame_meta(frame: Frame) -> dict:
|
|
||||||
meta = {
|
|
||||||
"sequence": frame.sequence,
|
|
||||||
"chunk_id": frame.chunk_id,
|
|
||||||
"timestamp": frame.timestamp,
|
|
||||||
"perceptual_hash": frame.perceptual_hash,
|
|
||||||
}
|
|
||||||
return meta
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_text_candidate(tc: TextCandidate) -> dict:
|
|
||||||
bbox_dict = dataclasses.asdict(tc.bbox)
|
|
||||||
candidate = {
|
|
||||||
"frame_sequence": tc.frame.sequence,
|
|
||||||
"bbox": bbox_dict,
|
|
||||||
"text": tc.text,
|
|
||||||
"ocr_confidence": tc.ocr_confidence,
|
|
||||||
}
|
|
||||||
return candidate
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
|
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
|
||||||
"""
|
"""
|
||||||
Serialize DetectState to a JSON-compatible dict.
|
Serialize DetectState to a JSON-compatible dict.
|
||||||
|
|
||||||
Frame images are replaced with S3 key references.
|
Calls each registered stage's serialize_fn for stage-owned data.
|
||||||
TextCandidate.frame references become frame_sequence integers.
|
Envelope fields (job_id, etc.) are copied directly.
|
||||||
"""
|
"""
|
||||||
frames = state.get("frames", [])
|
from detect.stages.base import _REGISTRY
|
||||||
filtered = state.get("filtered_frames", [])
|
|
||||||
|
|
||||||
manifest_strs = {str(k): v for k, v in frames_manifest.items()}
|
checkpoint = {}
|
||||||
frames_meta = [serialize_frame_meta(f) for f in frames]
|
|
||||||
filtered_seqs = [f.sequence for f in filtered]
|
|
||||||
|
|
||||||
boxes_serialized = {}
|
# Envelope
|
||||||
for seq, boxes in state.get("boxes_by_frame", {}).items():
|
for key in ENVELOPE_KEYS:
|
||||||
boxes_serialized[str(seq)] = [dataclasses.asdict(b) for b in boxes]
|
default = {} if key == "config_overrides" else ""
|
||||||
|
checkpoint[key] = state.get(key, default)
|
||||||
|
|
||||||
text_candidates = [serialize_text_candidate(tc) for tc in state.get("text_candidates", [])]
|
# Frames manifest (needed by frame-loading stages)
|
||||||
unresolved = [serialize_text_candidate(tc) for tc in state.get("unresolved_candidates", [])]
|
checkpoint["frames_manifest"] = {str(k): v for k, v in frames_manifest.items()}
|
||||||
detections = [dataclasses.asdict(d) for d in state.get("detections", [])]
|
|
||||||
stats = dataclasses.asdict(state.get("stats", PipelineStats()))
|
# Stats (shared across stages, not owned by one)
|
||||||
|
stats = state.get("stats")
|
||||||
|
if stats is not None:
|
||||||
|
checkpoint["stats"] = serialize_dataclass(stats)
|
||||||
|
else:
|
||||||
|
checkpoint["stats"] = {}
|
||||||
|
|
||||||
|
# Per-stage data
|
||||||
|
for name, stage_def in _REGISTRY.items():
|
||||||
|
if stage_def.serialize_fn is None:
|
||||||
|
continue
|
||||||
|
job_id = state.get("job_id", "")
|
||||||
|
stage_data = stage_def.serialize_fn(state, job_id)
|
||||||
|
checkpoint[f"stage_{name}"] = stage_data
|
||||||
|
|
||||||
checkpoint = {
|
|
||||||
"job_id": state.get("job_id", ""),
|
|
||||||
"video_path": state.get("video_path", ""),
|
|
||||||
"profile_name": state.get("profile_name", ""),
|
|
||||||
"config_overrides": state.get("config_overrides", {}),
|
|
||||||
"frames_manifest": manifest_strs,
|
|
||||||
"frames_meta": frames_meta,
|
|
||||||
"filtered_frame_sequences": filtered_seqs,
|
|
||||||
"boxes_by_frame": boxes_serialized,
|
|
||||||
"text_candidates": text_candidates,
|
|
||||||
"unresolved_candidates": unresolved,
|
|
||||||
"detections": detections,
|
|
||||||
"stats": stats,
|
|
||||||
}
|
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
def deserialize_state(checkpoint: dict, frames: list) -> dict:
|
||||||
# Deserialize helpers
|
"""
|
||||||
# ---------------------------------------------------------------------------
|
Reconstitute DetectState from a checkpoint dict + loaded frames.
|
||||||
|
|
||||||
def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate:
|
Calls each stage's deserialize_fn to restore stage-owned data.
|
||||||
frame = frame_map[d["frame_sequence"]]
|
"""
|
||||||
bbox = BoundingBox(**d["bbox"])
|
from detect.stages.base import _REGISTRY
|
||||||
candidate = TextCandidate(
|
|
||||||
frame=frame,
|
|
||||||
bbox=bbox,
|
|
||||||
text=d["text"],
|
|
||||||
ocr_confidence=d["ocr_confidence"],
|
|
||||||
)
|
|
||||||
return candidate
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict:
|
|
||||||
"""Reconstitute DetectState from a checkpoint dict + loaded frames."""
|
|
||||||
frame_map = {f.sequence: f for f in frames}
|
frame_map = {f.sequence: f for f in frames}
|
||||||
|
|
||||||
filtered_seqs = set(checkpoint.get("filtered_frame_sequences", []))
|
state = {}
|
||||||
filtered_frames = [f for f in frames if f.sequence in filtered_seqs]
|
|
||||||
|
|
||||||
boxes_by_frame = {}
|
# Envelope
|
||||||
for seq_str, box_dicts in checkpoint.get("boxes_by_frame", {}).items():
|
for key in ENVELOPE_KEYS:
|
||||||
seq = int(seq_str)
|
default = {} if key == "config_overrides" else ""
|
||||||
boxes_by_frame[seq] = [BoundingBox(**b) for b in box_dicts]
|
state[key] = checkpoint.get(key, default)
|
||||||
|
|
||||||
text_candidates = [
|
# Frames (always present, loaded externally)
|
||||||
deserialize_text_candidate(d, frame_map)
|
state["frames"] = frames
|
||||||
for d in checkpoint.get("text_candidates", [])
|
|
||||||
]
|
# Stats
|
||||||
unresolved_candidates = [
|
state["stats"] = deserialize_pipeline_stats(checkpoint.get("stats", {}))
|
||||||
deserialize_text_candidate(d, frame_map)
|
|
||||||
for d in checkpoint.get("unresolved_candidates", [])
|
# Per-stage data
|
||||||
]
|
for name, stage_def in _REGISTRY.items():
|
||||||
detections = [BrandDetection(**d) for d in checkpoint.get("detections", [])]
|
if stage_def.deserialize_fn is None:
|
||||||
stats = PipelineStats(**checkpoint.get("stats", {}))
|
continue
|
||||||
|
|
||||||
|
stage_key = f"stage_{name}"
|
||||||
|
if stage_key not in checkpoint:
|
||||||
|
continue
|
||||||
|
|
||||||
|
job_id = state.get("job_id", "")
|
||||||
|
stage_data = stage_def.deserialize_fn(checkpoint[stage_key], job_id)
|
||||||
|
|
||||||
|
for k, v in stage_data.items():
|
||||||
|
if k == "_filtered_sequences":
|
||||||
|
# Reconnect filtered frames from sequence list
|
||||||
|
seq_set = set(v)
|
||||||
|
state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
|
||||||
|
elif k.endswith("_raw"):
|
||||||
|
# Raw text candidates need frame reference reconnection
|
||||||
|
real_key = k.removeprefix("_").removesuffix("_raw")
|
||||||
|
state[real_key] = deserialize_text_candidates(v, frame_map)
|
||||||
|
else:
|
||||||
|
state[k] = v
|
||||||
|
|
||||||
state = {
|
|
||||||
"job_id": checkpoint.get("job_id", ""),
|
|
||||||
"video_path": checkpoint.get("video_path", ""),
|
|
||||||
"profile_name": checkpoint.get("profile_name", ""),
|
|
||||||
"config_overrides": checkpoint.get("config_overrides", {}),
|
|
||||||
"frames": frames,
|
|
||||||
"filtered_frames": filtered_frames,
|
|
||||||
"boxes_by_frame": boxes_by_frame,
|
|
||||||
"text_candidates": text_candidates,
|
|
||||||
"unresolved_candidates": unresolved_candidates,
|
|
||||||
"detections": detections,
|
|
||||||
"stats": stats,
|
|
||||||
}
|
|
||||||
return state
|
return state
|
||||||
|
|||||||
102
detect/models.py
102
detect/models.py
@@ -1,86 +1,26 @@
|
|||||||
"""
|
"""
|
||||||
Core domain models for the detection pipeline.
|
Re-export pipeline runtime models from core/schema/models/detect_pipeline.py.
|
||||||
|
|
||||||
These are pipeline-internal models — the data structures that flow
|
All models are defined in core/schema/ — this module exists for backward
|
||||||
between LangGraph nodes. SSE event payloads (sse_contract.py) are
|
compatibility so existing imports (from detect.models import Frame) keep working.
|
||||||
derived from these when emitting to the UI.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from core.schema.models.detect_pipeline import (
|
||||||
|
BoundingBox,
|
||||||
|
BrandDetection,
|
||||||
|
BrandStats,
|
||||||
|
DetectionReport,
|
||||||
|
Frame,
|
||||||
|
PipelineStats,
|
||||||
|
TextCandidate,
|
||||||
|
)
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
__all__ = [
|
||||||
from typing import Literal
|
"BoundingBox",
|
||||||
|
"BrandDetection",
|
||||||
import numpy as np
|
"BrandStats",
|
||||||
|
"DetectionReport",
|
||||||
|
"Frame",
|
||||||
@dataclass
|
"PipelineStats",
|
||||||
class Frame:
|
"TextCandidate",
|
||||||
sequence: int
|
]
|
||||||
chunk_id: int
|
|
||||||
timestamp: float # position in video (seconds)
|
|
||||||
image: np.ndarray
|
|
||||||
perceptual_hash: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BoundingBox:
|
|
||||||
x: int
|
|
||||||
y: int
|
|
||||||
w: int
|
|
||||||
h: int
|
|
||||||
confidence: float
|
|
||||||
label: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TextCandidate:
|
|
||||||
frame: Frame
|
|
||||||
bbox: BoundingBox
|
|
||||||
text: str
|
|
||||||
ocr_confidence: float
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BrandDetection:
|
|
||||||
brand: str
|
|
||||||
timestamp: float
|
|
||||||
duration: float
|
|
||||||
confidence: float
|
|
||||||
source: Literal["ocr", "local_vlm", "cloud_llm", "logo_match", "auxiliary"]
|
|
||||||
bbox: BoundingBox | None = None
|
|
||||||
frame_ref: int | None = None
|
|
||||||
content_type: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BrandStats:
|
|
||||||
total_appearances: int = 0
|
|
||||||
total_screen_time: float = 0.0
|
|
||||||
avg_confidence: float = 0.0
|
|
||||||
first_seen: float = 0.0
|
|
||||||
last_seen: float = 0.0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PipelineStats:
|
|
||||||
frames_extracted: int = 0
|
|
||||||
frames_after_scene_filter: int = 0
|
|
||||||
regions_detected: int = 0
|
|
||||||
regions_resolved_by_ocr: int = 0
|
|
||||||
regions_escalated_to_local_vlm: int = 0
|
|
||||||
regions_escalated_to_cloud_llm: int = 0
|
|
||||||
auxiliary_detections: int = 0
|
|
||||||
cloud_llm_calls: int = 0
|
|
||||||
processing_time_seconds: float = 0.0
|
|
||||||
estimated_cloud_cost_usd: float = 0.0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DetectionReport:
|
|
||||||
video_source: str
|
|
||||||
content_type: str
|
|
||||||
duration_seconds: float
|
|
||||||
brands: dict[str, BrandStats] = field(default_factory=dict)
|
|
||||||
timeline: list[BrandDetection] = field(default_factory=list)
|
|
||||||
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
|
|
||||||
|
|||||||
@@ -0,0 +1,21 @@
|
|||||||
|
"""
|
||||||
|
Pipeline stages.
|
||||||
|
|
||||||
|
Each stage registers its StageDefinition on import,
|
||||||
|
declaring IO (what it reads/writes from state),
|
||||||
|
config fields (what's tunable from the editor),
|
||||||
|
and serialization (how to checkpoint its outputs).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import (
|
||||||
|
StageDefinition,
|
||||||
|
StageIO,
|
||||||
|
StageConfigField,
|
||||||
|
register_stage,
|
||||||
|
get_stage,
|
||||||
|
list_stages,
|
||||||
|
get_palette,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Populate registry with built-in stages
|
||||||
|
from . import registry # noqa: F401
|
||||||
|
|||||||
101
detect/stages/base.py
Normal file
101
detect/stages/base.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""
|
||||||
|
Stage protocol — common interface for all pipeline stages.
|
||||||
|
|
||||||
|
Every stage declares:
|
||||||
|
- IO: what it reads/writes from DetectState
|
||||||
|
- Config: tunable parameters for the editor
|
||||||
|
- Serialization: how to persist/restore its own outputs
|
||||||
|
|
||||||
|
The checkpoint layer is a black box — it asks each stage to serialize its
|
||||||
|
outputs and stores the result. Stages own their data format. Binary data
|
||||||
|
(frames, crops) goes to S3 via the stage itself. The checkpoint just
|
||||||
|
stores the JSON envelope.
|
||||||
|
|
||||||
|
The graph builder uses StageIO to validate that a stage's inputs are
|
||||||
|
satisfied by previous stages' outputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StageIO:
|
||||||
|
"""Declares what a stage reads and writes from/to DetectState."""
|
||||||
|
reads: list[str]
|
||||||
|
writes: list[str]
|
||||||
|
optional_reads: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StageConfigField:
|
||||||
|
"""A single tunable config parameter for the editor UI."""
|
||||||
|
name: str
|
||||||
|
type: str # "float", "int", "str", "bool", "list[str]"
|
||||||
|
default: Any
|
||||||
|
description: str = ""
|
||||||
|
min: float | None = None
|
||||||
|
max: float | None = None
|
||||||
|
options: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StageDefinition:
|
||||||
|
"""
|
||||||
|
Complete metadata for a pipeline stage.
|
||||||
|
|
||||||
|
The profile editor uses this to build the palette, generate config
|
||||||
|
forms, and validate graph connections. The checkpoint uses serialize_fn
|
||||||
|
and deserialize_fn to persist stage outputs without knowing the internals.
|
||||||
|
"""
|
||||||
|
name: str
|
||||||
|
label: str
|
||||||
|
description: str
|
||||||
|
io: StageIO
|
||||||
|
config_fields: list[StageConfigField] = field(default_factory=list)
|
||||||
|
category: str = "detection"
|
||||||
|
|
||||||
|
# The actual graph node function: (DetectState) → dict
|
||||||
|
fn: Callable | None = None
|
||||||
|
|
||||||
|
# Stage-owned serialization for checkpointing.
|
||||||
|
# serialize_fn: (state: dict, job_id: str) → json-compatible dict
|
||||||
|
# Stage picks its writes from state, serializes them.
|
||||||
|
# Binary data (frames) → S3 via stage, returns refs.
|
||||||
|
# deserialize_fn: (data: dict, job_id: str) → state update dict
|
||||||
|
# Stage restores its writes from the persisted data.
|
||||||
|
serialize_fn: Callable | None = None
|
||||||
|
deserialize_fn: Callable | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Registry
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_REGISTRY: dict[str, StageDefinition] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_stage(definition: StageDefinition):
|
||||||
|
_REGISTRY[definition.name] = definition
|
||||||
|
|
||||||
|
|
||||||
|
def get_stage(name: str) -> StageDefinition:
|
||||||
|
if name not in _REGISTRY:
|
||||||
|
raise KeyError(f"Unknown stage: {name!r}. Registered: {list(_REGISTRY)}")
|
||||||
|
return _REGISTRY[name]
|
||||||
|
|
||||||
|
|
||||||
|
def list_stages() -> list[StageDefinition]:
|
||||||
|
return list(_REGISTRY.values())
|
||||||
|
|
||||||
|
|
||||||
|
def get_palette() -> dict[str, list[StageDefinition]]:
|
||||||
|
"""Group stages by category for the editor palette."""
|
||||||
|
palette: dict[str, list[StageDefinition]] = {}
|
||||||
|
for stage in _REGISTRY.values():
|
||||||
|
if stage.category not in palette:
|
||||||
|
palette[stage.category] = []
|
||||||
|
palette[stage.category].append(stage)
|
||||||
|
return palette
|
||||||
28
detect/stages/registry/__init__.py
Normal file
28
detect/stages/registry/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""
|
||||||
|
Stage registry — registers all built-in stages.
|
||||||
|
|
||||||
|
Split by category:
|
||||||
|
preprocessing.py — extract_frames, filter_scenes
|
||||||
|
detection.py — detect_objects, run_ocr
|
||||||
|
resolution.py — match_brands
|
||||||
|
escalation.py — escalate_vlm, escalate_cloud
|
||||||
|
output.py — compile_report
|
||||||
|
_serializers.py — shared serialization helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
from . import preprocessing
|
||||||
|
from . import detection
|
||||||
|
from . import resolution
|
||||||
|
from . import escalation
|
||||||
|
from . import output
|
||||||
|
|
||||||
|
|
||||||
|
def register_all():
|
||||||
|
preprocessing.register()
|
||||||
|
detection.register()
|
||||||
|
resolution.register()
|
||||||
|
escalation.register()
|
||||||
|
output.register()
|
||||||
|
|
||||||
|
|
||||||
|
register_all()
|
||||||
25
detect/stages/registry/_serializers.py
Normal file
25
detect/stages/registry/_serializers.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""
|
||||||
|
Re-export serializers from core/schema/serializers/.
|
||||||
|
|
||||||
|
Stage registry modules import from here for convenience.
|
||||||
|
All serialization logic lives in core/schema/serializers/.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from core.schema.serializers._common import (
|
||||||
|
safe_construct,
|
||||||
|
serialize_dataclass,
|
||||||
|
serialize_dataclass_list,
|
||||||
|
)
|
||||||
|
from core.schema.serializers.detect_pipeline import (
|
||||||
|
serialize_frame_meta,
|
||||||
|
serialize_frames_with_upload as serialize_frames,
|
||||||
|
deserialize_frames_with_download as deserialize_frames,
|
||||||
|
serialize_text_candidate,
|
||||||
|
serialize_text_candidates,
|
||||||
|
deserialize_text_candidate,
|
||||||
|
deserialize_text_candidates,
|
||||||
|
deserialize_bounding_box,
|
||||||
|
deserialize_brand_detection,
|
||||||
|
deserialize_pipeline_stats,
|
||||||
|
deserialize_detection_report,
|
||||||
|
)
|
||||||
63
detect/stages/registry/detection.py
Normal file
63
detect/stages/registry/detection.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""Registration for detection stages: YOLO, OCR."""
|
||||||
|
|
||||||
|
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||||
|
from ._serializers import (
|
||||||
|
serialize_dataclass_list,
|
||||||
|
serialize_text_candidates,
|
||||||
|
deserialize_bounding_box,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ser_detect(state: dict, job_id: str) -> dict:
|
||||||
|
boxes = state.get("boxes_by_frame", {})
|
||||||
|
serialized = {str(seq): serialize_dataclass_list(bl) for seq, bl in boxes.items()}
|
||||||
|
return {"boxes_by_frame": serialized}
|
||||||
|
|
||||||
|
|
||||||
|
def _deser_detect(data: dict, job_id: str) -> dict:
|
||||||
|
boxes = {}
|
||||||
|
for seq_str, box_dicts in data.get("boxes_by_frame", {}).items():
|
||||||
|
boxes[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
|
||||||
|
return {"boxes_by_frame": boxes}
|
||||||
|
|
||||||
|
|
||||||
|
def _ser_ocr(state: dict, job_id: str) -> dict:
|
||||||
|
candidates = state.get("text_candidates", [])
|
||||||
|
return {"text_candidates": serialize_text_candidates(candidates)}
|
||||||
|
|
||||||
|
|
||||||
|
def _deser_ocr(data: dict, job_id: str) -> dict:
|
||||||
|
return {"_text_candidates_raw": data["text_candidates"]}
|
||||||
|
|
||||||
|
|
||||||
|
def register():
|
||||||
|
yolo = StageDefinition(
|
||||||
|
name="detect_objects",
|
||||||
|
label="Object Detection",
|
||||||
|
description="YOLO object detection on filtered frames",
|
||||||
|
category="detection",
|
||||||
|
io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
|
||||||
|
config_fields=[
|
||||||
|
StageConfigField("model_name", "str", "yolov8n.pt", "YOLO model file"),
|
||||||
|
StageConfigField("confidence_threshold", "float", 0.3, "Min detection confidence", min=0.0, max=1.0),
|
||||||
|
StageConfigField("target_classes", "list[str]", [], "YOLO classes to detect (empty = all)"),
|
||||||
|
],
|
||||||
|
serialize_fn=_ser_detect,
|
||||||
|
deserialize_fn=_deser_detect,
|
||||||
|
)
|
||||||
|
register_stage(yolo)
|
||||||
|
|
||||||
|
ocr = StageDefinition(
|
||||||
|
name="run_ocr",
|
||||||
|
label="OCR",
|
||||||
|
description="Extract text from detected regions",
|
||||||
|
category="detection",
|
||||||
|
io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
|
||||||
|
config_fields=[
|
||||||
|
StageConfigField("languages", "list[str]", ["en"], "OCR languages"),
|
||||||
|
StageConfigField("min_confidence", "float", 0.5, "Min OCR confidence", min=0.0, max=1.0),
|
||||||
|
],
|
||||||
|
serialize_fn=_ser_ocr,
|
||||||
|
deserialize_fn=_deser_ocr,
|
||||||
|
)
|
||||||
|
register_stage(ocr)
|
||||||
63
detect/stages/registry/escalation.py
Normal file
63
detect/stages/registry/escalation.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""Registration for escalation stages: local VLM, cloud LLM."""
|
||||||
|
|
||||||
|
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||||
|
from ._serializers import (
|
||||||
|
serialize_dataclass_list,
|
||||||
|
serialize_text_candidates,
|
||||||
|
deserialize_brand_detection,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ser_escalation(state: dict, job_id: str) -> dict:
    """Serialize escalation output: resolved detections plus leftovers."""
    return {
        "detections": serialize_dataclass_list(state.get("detections", [])),
        "unresolved_candidates": serialize_text_candidates(
            state.get("unresolved_candidates", [])
        ),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _deser_escalation(data: dict, job_id: str) -> dict:
|
||||||
|
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
|
||||||
|
return {
|
||||||
|
"detections": detections,
|
||||||
|
"_unresolved_raw": data.get("unresolved_candidates", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def register():
    """Register the escalation-category stages: local VLM, cloud LLM."""
    register_stage(StageDefinition(
        name="escalate_vlm",
        label="Local VLM",
        description="Process unresolved crops with moondream2",
        category="escalation",
        io=StageIO(
            reads=["unresolved_candidates"],
            writes=["detections", "unresolved_candidates"],
            optional_reads=["source_asset_id"],
        ),
        config_fields=[
            StageConfigField("min_confidence", "float", 0.5, "Min VLM confidence", min=0.0, max=1.0),
        ],
        serialize_fn=_ser_escalation,
        deserialize_fn=_deser_escalation,
    ))

    register_stage(StageDefinition(
        name="escalate_cloud",
        label="Cloud LLM",
        description="Escalate remaining crops to cloud provider",
        category="escalation",
        io=StageIO(
            reads=["unresolved_candidates"],
            writes=["detections"],
            optional_reads=["source_asset_id"],
        ),
        config_fields=[
            StageConfigField("min_confidence", "float", 0.4, "Min cloud confidence", min=0.0, max=1.0),
        ],
        serialize_fn=_ser_escalation,
        deserialize_fn=_deser_escalation,
    ))
|
||||||
32
detect/stages/registry/output.py
Normal file
32
detect/stages/registry/output.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Registration for output stages: report compilation."""
|
||||||
|
|
||||||
|
from detect.stages.base import StageDefinition, StageIO, register_stage
|
||||||
|
from ._serializers import serialize_dataclass, deserialize_detection_report
|
||||||
|
|
||||||
|
|
||||||
|
def _ser_report(state: dict, job_id: str) -> dict:
|
||||||
|
report = state.get("report")
|
||||||
|
if report is None:
|
||||||
|
return {"report": None}
|
||||||
|
return {"report": serialize_dataclass(report)}
|
||||||
|
|
||||||
|
|
||||||
|
def _deser_report(data: dict, job_id: str) -> dict:
|
||||||
|
raw = data.get("report")
|
||||||
|
if raw is None:
|
||||||
|
return {"report": None}
|
||||||
|
return {"report": deserialize_detection_report(raw)}
|
||||||
|
|
||||||
|
|
||||||
|
def register():
    """Register the output-category stage: report compilation."""
    register_stage(StageDefinition(
        name="compile_report",
        label="Report",
        description="Merge detections and compile final report",
        category="output",
        io=StageIO(reads=["detections"], writes=["report"]),
        config_fields=[],
        serialize_fn=_ser_report,
        deserialize_fn=_deser_report,
    ))
|
||||||
57
detect/stages/registry/preprocessing.py
Normal file
57
detect/stages/registry/preprocessing.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""Registration for preprocessing stages: frame extraction, scene filter."""
|
||||||
|
|
||||||
|
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||||
|
from ._serializers import serialize_frames, deserialize_frames
|
||||||
|
|
||||||
|
|
||||||
|
def _ser_extract(state: dict, job_id: str) -> dict:
    """Serialize extracted frames into checkpoint metadata plus a manifest.

    ``serialize_frames`` returns a (meta, manifest) pair keyed by
    ``job_id``; both halves are stored in the checkpoint payload.
    """
    meta, manifest = serialize_frames(state.get("frames", []), job_id)
    return {"frames_meta": meta, "frames_manifest": manifest}
|
||||||
|
|
||||||
|
|
||||||
|
def _deser_extract(data: dict, job_id: str) -> dict:
    """Rebuild Frame objects from checkpoint metadata and manifest.

    Both keys are required; a KeyError here signals a corrupt or
    incomplete extraction checkpoint.
    """
    restored = deserialize_frames(
        data["frames_meta"], data["frames_manifest"], job_id
    )
    return {"frames": restored}
|
||||||
|
|
||||||
|
|
||||||
|
def _ser_filter(state: dict, job_id: str) -> dict:
|
||||||
|
filtered = state.get("filtered_frames", [])
|
||||||
|
seqs = [f.sequence for f in filtered]
|
||||||
|
return {"filtered_frame_sequences": seqs}
|
||||||
|
|
||||||
|
|
||||||
|
def _deser_filter(data: dict, job_id: str) -> dict:
|
||||||
|
return {"_filtered_sequences": data["filtered_frame_sequences"]}
|
||||||
|
|
||||||
|
|
||||||
|
def register():
    """Register the preprocessing stages: frame extraction, scene filter."""
    register_stage(StageDefinition(
        name="extract_frames",
        label="Frame Extraction",
        description="Extract frames from video at configurable FPS",
        category="preprocessing",
        io=StageIO(reads=["video_path"], writes=["frames"]),
        config_fields=[
            StageConfigField("fps", "float", 2.0, "Frames per second", min=0.1, max=30.0),
            StageConfigField("max_frames", "int", 500, "Maximum frames to extract", min=1, max=10000),
        ],
        serialize_fn=_ser_extract,
        deserialize_fn=_deser_extract,
    ))

    register_stage(StageDefinition(
        name="filter_scenes",
        label="Scene Filter",
        description="Deduplicate similar frames using perceptual hashing",
        category="preprocessing",
        io=StageIO(reads=["frames"], writes=["filtered_frames"]),
        config_fields=[
            StageConfigField("hamming_threshold", "int", 8, "Hamming distance threshold", min=0, max=64),
            StageConfigField("enabled", "bool", True, "Enable scene filtering"),
        ],
        serialize_fn=_ser_filter,
        deserialize_fn=_deser_filter,
    ))
|
||||||
45
detect/stages/registry/resolution.py
Normal file
45
detect/stages/registry/resolution.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
"""Registration for resolution stages: brand resolver."""
|
||||||
|
|
||||||
|
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||||
|
from ._serializers import (
|
||||||
|
serialize_dataclass_list,
|
||||||
|
serialize_text_candidates,
|
||||||
|
deserialize_brand_detection,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ser_brands(state: dict, job_id: str) -> dict:
    """Serialize brand-resolver output: matches plus unresolved text."""
    payload = {}
    payload["detections"] = serialize_dataclass_list(state.get("detections", []))
    payload["unresolved_candidates"] = serialize_text_candidates(
        state.get("unresolved_candidates", [])
    )
    return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _deser_brands(data: dict, job_id: str) -> dict:
|
||||||
|
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
|
||||||
|
return {
|
||||||
|
"detections": detections,
|
||||||
|
"_unresolved_raw": data.get("unresolved_candidates", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def register():
    """Register the resolution-category stage: brand resolver."""
    register_stage(StageDefinition(
        name="match_brands",
        label="Brand Resolver",
        description="Match OCR text against known brands (session + global DB)",
        category="resolution",
        io=StageIO(
            reads=["text_candidates"],
            writes=["detections", "unresolved_candidates"],
            optional_reads=["session_brands", "source_asset_id"],
        ),
        config_fields=[
            StageConfigField("fuzzy_threshold", "int", 75, "Fuzzy match threshold", min=0, max=100),
        ],
        serialize_fn=_ser_brands,
        deserialize_fn=_deser_brands,
    ))
|
||||||
@@ -1,26 +1,30 @@
|
|||||||
"""Tests for checkpoint serialization — round-trip without S3."""
|
"""Tests for checkpoint serialization — unit tests without S3."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
|
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
|
||||||
from detect.checkpoint.serializer import (
|
from core.schema.serializers._common import safe_construct
|
||||||
serialize_state,
|
from core.schema.serializers.detect_pipeline import (
|
||||||
deserialize_state,
|
|
||||||
serialize_frame_meta,
|
serialize_frame_meta,
|
||||||
serialize_text_candidate,
|
serialize_text_candidate,
|
||||||
|
serialize_text_candidates,
|
||||||
|
serialize_dataclass_list,
|
||||||
deserialize_text_candidate,
|
deserialize_text_candidate,
|
||||||
|
deserialize_text_candidates,
|
||||||
|
deserialize_bounding_box,
|
||||||
|
deserialize_brand_detection,
|
||||||
|
deserialize_pipeline_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
    """Build a Frame carrying random uint8 pixels for serializer tests."""
    pixels = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
    return Frame(
        sequence=seq,
        chunk_id=0,
        timestamp=float(seq) * 0.5,
        image=pixels,
        perceptual_hash=f"hash_{seq}",
    )
|
||||||
|
|
||||||
|
|
||||||
@@ -35,13 +39,8 @@ def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate:
|
|||||||
|
|
||||||
def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection:
    """Build an OCR-sourced BrandDetection with fixed confidence/duration."""
    return BrandDetection(
        brand=brand,
        timestamp=timestamp,
        duration=0.5,
        confidence=0.92,
        source="ocr",
        content_type="soccer_broadcast",
        frame_ref=0,
    )
|
||||||
|
|
||||||
|
|
||||||
@@ -81,102 +80,84 @@ def test_deserialize_text_candidate():
|
|||||||
|
|
||||||
assert restored.text == "ADIDAS"
|
assert restored.text == "ADIDAS"
|
||||||
assert restored.ocr_confidence == 0.85
|
assert restored.ocr_confidence == 0.85
|
||||||
assert restored.frame is frame # same object reference
|
assert restored.frame is frame
|
||||||
assert restored.bbox.x == 10
|
assert restored.bbox.x == 10
|
||||||
|
|
||||||
|
|
||||||
# --- Full state round-trip ---
|
# --- BoundingBox ---
|
||||||
|
|
||||||
def test_state_round_trip():
|
def test_bounding_box_round_trip():
    """BoundingBox fields survive a serialize/deserialize cycle."""
    original = _make_box(x=15, y=25, w=40, h=30)
    payload = serialize_dataclass_list([original])[0]
    result = deserialize_bounding_box(payload)

    assert result.x == 15
    assert result.w == 40
    assert result.confidence == 0.9
|
||||||
|
|
||||||
candidates = [_make_candidate(frames[0], "NIKE"), _make_candidate(frames[1], "EMIRATES")]
|
|
||||||
unresolved = [_make_candidate(frames[2], "unknown")]
|
|
||||||
detections = [_make_detection("Nike", 0.5), _make_detection("Emirates", 1.0)]
|
|
||||||
|
|
||||||
|
# --- BrandDetection ---
|
||||||
|
|
||||||
|
def test_brand_detection_round_trip():
    """BrandDetection fields survive a serialize/deserialize cycle."""
    original = _make_detection("Emirates", 3.5)
    payload = serialize_dataclass_list([original])[0]
    result = deserialize_brand_detection(payload)

    assert result.brand == "Emirates"
    assert result.timestamp == 3.5
    assert result.source == "ocr"
|
||||||
|
|
||||||
|
|
||||||
|
# --- PipelineStats ---
|
||||||
|
|
||||||
|
def test_pipeline_stats_round_trip():
    """PipelineStats fields survive a serialize/deserialize cycle."""
    original = PipelineStats(
        frames_extracted=120,
        cloud_llm_calls=3,
        estimated_cloud_cost_usd=0.005,
    )
    payload = serialize_dataclass_list([original])[0]
    result = deserialize_pipeline_stats(payload)

    assert result.frames_extracted == 120
    assert result.cloud_llm_calls == 3
    assert result.estimated_cloud_cost_usd == 0.005
|
||||||
"profile_name": "soccer_broadcast",
|
|
||||||
"config_overrides": {"ocr": {"min_confidence": 0.3}},
|
|
||||||
"frames": frames,
|
# --- safe_construct tolerates schema changes ---
|
||||||
"filtered_frames": filtered,
|
|
||||||
"boxes_by_frame": boxes_by_frame,
|
def test_safe_construct_ignores_unknown_fields():
    """safe_construct drops keys the dataclass does not declare."""
    payload = {"x": 1, "y": 2, "w": 3, "h": 4, "confidence": 0.5, "label": "t", "extra_field": 99}
    result = safe_construct(BoundingBox, payload)

    assert result.x == 1
    assert not hasattr(result, "extra_field")
|
||||||
|
|
||||||
|
|
||||||
|
def test_safe_construct_uses_defaults():
    """safe_construct falls back to dataclass defaults for missing keys."""
    result = safe_construct(PipelineStats, {"frames_extracted": 50})

    assert result.frames_extracted == 50
    assert result.cloud_llm_calls == 0  # default
|
||||||
|
|
||||||
|
|
||||||
|
# --- JSON compatibility ---
|
||||||
|
|
||||||
|
def test_all_serialized_is_json_compatible():
    """All serializer outputs must be strictly JSON-compatible.

    Uses ``json.dumps`` WITHOUT a ``default=`` fallback: the previous
    ``default=str`` silently stringified any non-JSON value (numpy
    arrays, Frame objects), so the test could never catch a serializer
    that leaked runtime objects into the payload.
    """
    frame = _make_frame()
    candidate = _make_candidate(frame)
    detection = _make_detection()
    stats = PipelineStats(frames_extracted=10)

    all_data = {
        "frame_meta": serialize_frame_meta(frame),
        "candidate": serialize_text_candidate(candidate),
        "detection": serialize_dataclass_list([detection])[0],
        "stats": serialize_dataclass_list([stats])[0],
    }

    # Raises TypeError if any serializer emitted a non-JSON type.
    json_str = json.dumps(all_data)
    assert len(json_str) > 0

    roundtrip = json.loads(json_str)
    assert roundtrip["frame_meta"]["sequence"] == frame.sequence
|
||||||
|
|
||||||
# Verify round-trip
|
|
||||||
assert restored["job_id"] == "test-123"
|
|
||||||
assert restored["video_path"] == "/tmp/test.mp4"
|
|
||||||
assert restored["profile_name"] == "soccer_broadcast"
|
|
||||||
assert restored["config_overrides"] == {"ocr": {"min_confidence": 0.3}}
|
|
||||||
|
|
||||||
assert len(restored["frames"]) == 3
|
|
||||||
assert len(restored["filtered_frames"]) == 2
|
|
||||||
assert len(restored["boxes_by_frame"]) == 2
|
|
||||||
assert len(restored["text_candidates"]) == 2
|
|
||||||
assert len(restored["unresolved_candidates"]) == 1
|
|
||||||
assert len(restored["detections"]) == 2
|
|
||||||
|
|
||||||
restored_stats = restored["stats"]
|
|
||||||
assert restored_stats.frames_extracted == 3
|
|
||||||
assert restored_stats.cloud_llm_calls == 1
|
|
||||||
assert restored_stats.estimated_cloud_cost_usd == 0.003
|
|
||||||
|
|
||||||
# TextCandidate frame references should point to actual Frame objects
|
|
||||||
tc = restored["text_candidates"][0]
|
|
||||||
assert tc.frame is frames[0]
|
|
||||||
assert tc.text == "NIKE"
|
|
||||||
|
|
||||||
|
|
||||||
def test_state_round_trip_empty():
|
|
||||||
"""Empty state should serialize/deserialize cleanly."""
|
|
||||||
state = {
|
|
||||||
"job_id": "empty-job",
|
|
||||||
"video_path": "",
|
|
||||||
"profile_name": "soccer_broadcast",
|
|
||||||
"frames": [],
|
|
||||||
"filtered_frames": [],
|
|
||||||
"boxes_by_frame": {},
|
|
||||||
"text_candidates": [],
|
|
||||||
"unresolved_candidates": [],
|
|
||||||
"detections": [],
|
|
||||||
"stats": PipelineStats(),
|
|
||||||
}
|
|
||||||
|
|
||||||
serialized = serialize_state(state, {})
|
|
||||||
restored = deserialize_state(serialized, [])
|
|
||||||
|
|
||||||
assert restored["job_id"] == "empty-job"
|
|
||||||
assert len(restored["frames"]) == 0
|
|
||||||
assert len(restored["detections"]) == 0
|
|
||||||
assert restored["stats"].frames_extracted == 0
|
|
||||||
|
|||||||
58
tests/detect/test_stage_registry.py
Normal file
58
tests/detect/test_stage_registry.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""Tests for the stage registry."""
|
||||||
|
|
||||||
|
from detect.stages import list_stages, get_stage, get_palette
|
||||||
|
|
||||||
|
|
||||||
|
EXPECTED_STAGES = [
|
||||||
|
"extract_frames", "filter_scenes", "detect_objects", "run_ocr",
|
||||||
|
"match_brands", "escalate_vlm", "escalate_cloud", "compile_report",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_stages_registered():
    """Every expected stage name is present in the registry."""
    registered = {stage.name for stage in list_stages()}
    for expected in EXPECTED_STAGES:
        assert expected in registered, f"Missing stage: {expected}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_stage_has_io():
    """Each stage declares at least one output and readable metadata."""
    for name in EXPECTED_STAGES:
        stage = get_stage(name)
        assert stage.io.writes, f"{name} has no writes"
        assert stage.label, f"{name} has no label"
        assert stage.description, f"{name} has no description"
|
||||||
|
|
||||||
|
|
||||||
|
def test_stage_has_serialization():
    """Each stage provides both checkpoint serialization hooks."""
    for name in EXPECTED_STAGES:
        stage = get_stage(name)
        assert stage.serialize_fn is not None, f"{name} has no serialize_fn"
        assert stage.deserialize_fn is not None, f"{name} has no deserialize_fn"
|
||||||
|
|
||||||
|
|
||||||
|
def test_palette_groups():
    """The palette groups stages by category and covers each one exactly once."""
    palette = get_palette()
    assert len(palette) > 0

    collected = []
    for category, members in palette.items():
        assert isinstance(category, str)
        collected.extend(members)
    assert len(collected) == len(EXPECTED_STAGES)
|
||||||
|
|
||||||
|
|
||||||
|
def test_io_chain_valid():
    """Each stage's reads should be satisfied by previous stages' writes."""
    # Keys the pipeline seeds before any stage runs.
    available_keys = {"video_path", "job_id", "profile_name",
                      "source_asset_id", "session_brands", "stats",
                      "config_overrides"}

    for name in EXPECTED_STAGES:
        stage = get_stage(name)
        for read_key in stage.io.reads:
            assert read_key in available_keys, (
                f"Stage {stage.name} reads '{read_key}' but no previous stage writes it. "
                f"Available: {available_keys}"
            )
        available_keys.update(stage.io.writes)
|
||||||
Reference in New Issue
Block a user