From d58a90157a3a14d00c14e68f41ef3a31c0e68105 Mon Sep 17 00:00:00 2001
From: buenosairesam
Date: Thu, 26 Mar 2026 05:14:33 -0300
Subject: [PATCH] schema cleanup and refactor

---
 core/schema/models/detect_pipeline.py      |  97 +++++++++++
 core/schema/serializers/__init__.py        |  11 ++
 core/schema/serializers/_common.py         |  38 +++++
 core/schema/serializers/detect_pipeline.py | 108 ++++++++++++
 detect/checkpoint/serializer.py            | 185 +++++++++------------
 detect/models.py                           | 102 +++---------
 detect/stages/__init__.py                  |  21 +++
 detect/stages/base.py                      | 101 +++++++++++
 detect/stages/registry/__init__.py         |  28 ++++
 detect/stages/registry/_serializers.py     |  25 +++
 detect/stages/registry/detection.py        |  63 +++++++
 detect/stages/registry/escalation.py       |  63 +++++++
 detect/stages/registry/output.py           |  32 ++++
 detect/stages/registry/preprocessing.py    |  57 +++++++
 detect/stages/registry/resolution.py       |  45 +++++
 tests/detect/test_checkpoint.py            | 183 +++++++++----------
 tests/detect/test_stage_registry.py        |  58 ++++++
 17 files changed, 930 insertions(+), 287 deletions(-)
 create mode 100644 core/schema/models/detect_pipeline.py
 create mode 100644 core/schema/serializers/__init__.py
 create mode 100644 core/schema/serializers/_common.py
 create mode 100644 core/schema/serializers/detect_pipeline.py
 create mode 100644 detect/stages/base.py
 create mode 100644 detect/stages/registry/__init__.py
 create mode 100644 detect/stages/registry/_serializers.py
 create mode 100644 detect/stages/registry/detection.py
 create mode 100644 detect/stages/registry/escalation.py
 create mode 100644 detect/stages/registry/output.py
 create mode 100644 detect/stages/registry/preprocessing.py
 create mode 100644 detect/stages/registry/resolution.py
 create mode 100644 tests/detect/test_stage_registry.py

diff --git a/core/schema/models/detect_pipeline.py b/core/schema/models/detect_pipeline.py
new file mode 100644
index 0000000..e3673be
--- /dev/null
+++ b/core/schema/models/detect_pipeline.py
@@ -0,0 +1,97 @@
+"""
+Detection pipeline runtime models.
+
+These are the data structures that flow between LangGraph nodes.
+They contain runtime types (np.ndarray), so modelgen skips them;
+they are kept here anyway so that core/schema stays the complete
+map of the application's data.
+
+Wire-format models (SSE events) are in detect.py.
+DB models (jobs, checkpoints) are in detect_jobs.py.
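+
+A minimal construction example (values illustrative only):
+
+    frame = Frame(sequence=0, chunk_id=0, timestamp=0.0,
+                  image=np.zeros((720, 1280, 3), dtype=np.uint8))  # dummy image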
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Literal + +import numpy as np + + +@dataclass +class Frame: + sequence: int + chunk_id: int + timestamp: float # position in video (seconds) + image: np.ndarray + perceptual_hash: str = "" + + +@dataclass +class BoundingBox: + x: int + y: int + w: int + h: int + confidence: float + label: str + + +@dataclass +class TextCandidate: + frame: Frame + bbox: BoundingBox + text: str + ocr_confidence: float + + +@dataclass +class BrandDetection: + brand: str + timestamp: float + duration: float + confidence: float + source: Literal["ocr", "local_vlm", "cloud_llm", "logo_match", "auxiliary"] + bbox: BoundingBox | None = None + frame_ref: int | None = None + content_type: str = "" + + +@dataclass +class BrandStats: + total_appearances: int = 0 + total_screen_time: float = 0.0 + avg_confidence: float = 0.0 + first_seen: float = 0.0 + last_seen: float = 0.0 + + +@dataclass +class PipelineStats: + frames_extracted: int = 0 + frames_after_scene_filter: int = 0 + regions_detected: int = 0 + regions_resolved_by_ocr: int = 0 + regions_escalated_to_local_vlm: int = 0 + regions_escalated_to_cloud_llm: int = 0 + auxiliary_detections: int = 0 + cloud_llm_calls: int = 0 + processing_time_seconds: float = 0.0 + estimated_cloud_cost_usd: float = 0.0 + + +@dataclass +class DetectionReport: + video_source: str + content_type: str + duration_seconds: float + brands: dict[str, BrandStats] = field(default_factory=dict) + timeline: list[BrandDetection] = field(default_factory=list) + pipeline_stats: PipelineStats = field(default_factory=PipelineStats) + + +# Not in DATACLASSES — modelgen skips these (they contain np.ndarray) +RUNTIME_MODELS = [ + Frame, BoundingBox, TextCandidate, BrandDetection, + BrandStats, PipelineStats, DetectionReport, +] diff --git a/core/schema/serializers/__init__.py b/core/schema/serializers/__init__.py new file mode 100644 index 0000000..623b0e8 --- /dev/null +++ b/core/schema/serializers/__init__.py @@ -0,0 +1,11 @@ +""" +Model serializers — one module per model group, mirroring core/schema/models/. + + models/detect_pipeline.py → serializers/detect_pipeline.py + models/detect_jobs.py → serializers/detect_jobs.py + models/detect.py → serializers/detect.py (SSE events) + +Common utilities in _common.py. +""" + +from ._common import safe_construct, serialize_dataclass, serialize_dataclass_list diff --git a/core/schema/serializers/_common.py b/core/schema/serializers/_common.py new file mode 100644 index 0000000..ede26a9 --- /dev/null +++ b/core/schema/serializers/_common.py @@ -0,0 +1,38 @@ +"""Common serialization utilities.""" + +from __future__ import annotations + +import dataclasses +import logging + +logger = logging.getLogger(__name__) + + +def safe_construct(cls, data: dict): + """ + Construct a dataclass from a dict, tolerant of schema changes. 
+
+    - Ignores keys not in the dataclass (field was removed)
+    - Uses defaults for missing keys (field was added)
+    - Logs mismatches at debug level
+    """
+    field_names = {f.name for f in dataclasses.fields(cls)}
+
+    known = {}
+    for k, v in data.items():
+        if k in field_names:
+            known[k] = v
+        else:
+            logger.debug("Ignoring unknown field %s.%s", cls.__name__, k)
+
+    return cls(**known)
+
+
+def serialize_dataclass(obj) -> dict:
+    """Serialize any dataclass to a dict via dataclasses.asdict()."""
+    return dataclasses.asdict(obj)
+
+
+def serialize_dataclass_list(items) -> list[dict]:
+    """Serialize a list of dataclasses."""
+    return [dataclasses.asdict(item) for item in items]
diff --git a/core/schema/serializers/detect_pipeline.py b/core/schema/serializers/detect_pipeline.py
new file mode 100644
index 0000000..9738cb4
--- /dev/null
+++ b/core/schema/serializers/detect_pipeline.py
@@ -0,0 +1,108 @@
+"""
+Serializers for detection pipeline runtime models.
+
+Mirrors core/schema/models/detect_pipeline.py.
+
+Special handling:
+    - Frame.image (np.ndarray → S3, excluded from JSON)
+    - TextCandidate.frame (object ref → frame_sequence integer)
+Everything else round-trips via dataclasses.asdict() and safe_construct().
+"""
+
+from __future__ import annotations
+
+import dataclasses
+
+from core.schema.models.detect_pipeline import (
+    BoundingBox,
+    BrandDetection,
+    BrandStats,
+    DetectionReport,
+    Frame,
+    PipelineStats,
+    TextCandidate,
+)
+from ._common import safe_construct, serialize_dataclass, serialize_dataclass_list
+
+
+# ---------------------------------------------------------------------------
+# Frame — image goes to S3 separately
+# ---------------------------------------------------------------------------
+
+def serialize_frame_meta(frame: Frame) -> dict:
+    """Serialize Frame metadata only (no image, which asdict would deep-copy)."""
+    result = {f.name: getattr(frame, f.name) for f in dataclasses.fields(Frame)}
+    del result["image"]
+    return result
+
+
+def serialize_frames_with_upload(frames: list[Frame], job_id: str) -> tuple[list[dict], dict[int, str]]:
+    """Upload frame images to S3, return metadata + manifest."""
+    from detect.checkpoint.frames import save_frames
+
+    manifest = save_frames(job_id, frames)
+    meta = [serialize_frame_meta(f) for f in frames]
+    return meta, manifest
+
+
+def deserialize_frames_with_download(meta: list[dict], manifest: dict, job_id: str) -> list[Frame]:
+    """Load frames from S3 + metadata."""
+    from detect.checkpoint.frames import load_frames
+
+    int_manifest = {int(k): v for k, v in manifest.items()}
+    return load_frames(int_manifest, meta)
+
+
+# ---------------------------------------------------------------------------
+# TextCandidate — frame ref is an object, stored as sequence int
+# ---------------------------------------------------------------------------
+
+def serialize_text_candidate(tc: TextCandidate) -> dict:
+    bbox_dict = dataclasses.asdict(tc.bbox)
+    result = {
+        "frame_sequence": tc.frame.sequence,
+        "bbox": bbox_dict,
+        "text": tc.text,
+        "ocr_confidence": tc.ocr_confidence,
+    }
+    return result
+
+
+def serialize_text_candidates(candidates: list[TextCandidate]) -> list[dict]:
+    return [serialize_text_candidate(tc) for tc in candidates]
+
+
+def deserialize_text_candidate(data: dict, frame_map: dict[int, Frame]) -> TextCandidate:
+    frame = frame_map[data["frame_sequence"]]
+    bbox = safe_construct(BoundingBox, data["bbox"])
+    candidate = TextCandidate(
+        frame=frame,
+        bbox=bbox,
+        text=data["text"],
+        ocr_confidence=data["ocr_confidence"],
+    )
+    return candidate
+
+
+def deserialize_text_candidates(data: 
list[dict], frame_map: dict[int, Frame]) -> list[TextCandidate]: + return [deserialize_text_candidate(d, frame_map) for d in data] + + +# --------------------------------------------------------------------------- +# BoundingBox, BrandDetection, PipelineStats, etc — standard dataclasses +# --------------------------------------------------------------------------- + +def deserialize_bounding_box(data: dict) -> BoundingBox: + return safe_construct(BoundingBox, data) + + +def deserialize_brand_detection(data: dict) -> BrandDetection: + return safe_construct(BrandDetection, data) + + +def deserialize_pipeline_stats(data: dict) -> PipelineStats: + return safe_construct(PipelineStats, data) + + +def deserialize_detection_report(data: dict) -> DetectionReport: + return safe_construct(DetectionReport, data) diff --git a/detect/checkpoint/serializer.py b/detect/checkpoint/serializer.py index 35a9874..087ca9a 100644 --- a/detect/checkpoint/serializer.py +++ b/detect/checkpoint/serializer.py @@ -1,133 +1,108 @@ -"""State serialization — DetectState ↔ JSON-compatible dict.""" +""" +State serialization — DetectState ↔ JSON-compatible dict. + +Delegates to each stage's serialize_fn/deserialize_fn via the registry. +This file has no model-specific knowledge — stages own their data format. + +The only things serialized here are the "envelope" fields (job_id, video_path, etc.) +that don't belong to any stage. +""" from __future__ import annotations -import dataclasses - -from detect.models import ( - BoundingBox, - BrandDetection, - Frame, - PipelineStats, - TextCandidate, +from core.schema.serializers._common import serialize_dataclass +from core.schema.serializers.detect_pipeline import ( + deserialize_pipeline_stats, + deserialize_text_candidates, ) -# --------------------------------------------------------------------------- -# Serialize helpers -# --------------------------------------------------------------------------- - -def serialize_frame_meta(frame: Frame) -> dict: - meta = { - "sequence": frame.sequence, - "chunk_id": frame.chunk_id, - "timestamp": frame.timestamp, - "perceptual_hash": frame.perceptual_hash, - } - return meta - - -def serialize_text_candidate(tc: TextCandidate) -> dict: - bbox_dict = dataclasses.asdict(tc.bbox) - candidate = { - "frame_sequence": tc.frame.sequence, - "bbox": bbox_dict, - "text": tc.text, - "ocr_confidence": tc.ocr_confidence, - } - return candidate +# Envelope fields — not owned by any stage, always present +ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "config_overrides"] def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict: """ Serialize DetectState to a JSON-compatible dict. - Frame images are replaced with S3 key references. - TextCandidate.frame references become frame_sequence integers. + Calls each registered stage's serialize_fn for stage-owned data. + Envelope fields (job_id, etc.) are copied directly. 
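+
+    The resulting shape, illustratively (stage keys depend on the registry):
+
+        {
+            "job_id": "...", "video_path": "...",
+            "frames_manifest": {"0": "s3://..."},
+            "stats": {...},
+            "stage_detect_objects": {"boxes_by_frame": {...}},
+        }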
""" - frames = state.get("frames", []) - filtered = state.get("filtered_frames", []) + from detect.stages.base import _REGISTRY - manifest_strs = {str(k): v for k, v in frames_manifest.items()} - frames_meta = [serialize_frame_meta(f) for f in frames] - filtered_seqs = [f.sequence for f in filtered] + checkpoint = {} - boxes_serialized = {} - for seq, boxes in state.get("boxes_by_frame", {}).items(): - boxes_serialized[str(seq)] = [dataclasses.asdict(b) for b in boxes] + # Envelope + for key in ENVELOPE_KEYS: + default = {} if key == "config_overrides" else "" + checkpoint[key] = state.get(key, default) - text_candidates = [serialize_text_candidate(tc) for tc in state.get("text_candidates", [])] - unresolved = [serialize_text_candidate(tc) for tc in state.get("unresolved_candidates", [])] - detections = [dataclasses.asdict(d) for d in state.get("detections", [])] - stats = dataclasses.asdict(state.get("stats", PipelineStats())) + # Frames manifest (needed by frame-loading stages) + checkpoint["frames_manifest"] = {str(k): v for k, v in frames_manifest.items()} + + # Stats (shared across stages, not owned by one) + stats = state.get("stats") + if stats is not None: + checkpoint["stats"] = serialize_dataclass(stats) + else: + checkpoint["stats"] = {} + + # Per-stage data + for name, stage_def in _REGISTRY.items(): + if stage_def.serialize_fn is None: + continue + job_id = state.get("job_id", "") + stage_data = stage_def.serialize_fn(state, job_id) + checkpoint[f"stage_{name}"] = stage_data - checkpoint = { - "job_id": state.get("job_id", ""), - "video_path": state.get("video_path", ""), - "profile_name": state.get("profile_name", ""), - "config_overrides": state.get("config_overrides", {}), - "frames_manifest": manifest_strs, - "frames_meta": frames_meta, - "filtered_frame_sequences": filtered_seqs, - "boxes_by_frame": boxes_serialized, - "text_candidates": text_candidates, - "unresolved_candidates": unresolved, - "detections": detections, - "stats": stats, - } return checkpoint -# --------------------------------------------------------------------------- -# Deserialize helpers -# --------------------------------------------------------------------------- +def deserialize_state(checkpoint: dict, frames: list) -> dict: + """ + Reconstitute DetectState from a checkpoint dict + loaded frames. -def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate: - frame = frame_map[d["frame_sequence"]] - bbox = BoundingBox(**d["bbox"]) - candidate = TextCandidate( - frame=frame, - bbox=bbox, - text=d["text"], - ocr_confidence=d["ocr_confidence"], - ) - return candidate + Calls each stage's deserialize_fn to restore stage-owned data. 
+ """ + from detect.stages.base import _REGISTRY - -def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict: - """Reconstitute DetectState from a checkpoint dict + loaded frames.""" frame_map = {f.sequence: f for f in frames} - filtered_seqs = set(checkpoint.get("filtered_frame_sequences", [])) - filtered_frames = [f for f in frames if f.sequence in filtered_seqs] + state = {} - boxes_by_frame = {} - for seq_str, box_dicts in checkpoint.get("boxes_by_frame", {}).items(): - seq = int(seq_str) - boxes_by_frame[seq] = [BoundingBox(**b) for b in box_dicts] + # Envelope + for key in ENVELOPE_KEYS: + default = {} if key == "config_overrides" else "" + state[key] = checkpoint.get(key, default) - text_candidates = [ - deserialize_text_candidate(d, frame_map) - for d in checkpoint.get("text_candidates", []) - ] - unresolved_candidates = [ - deserialize_text_candidate(d, frame_map) - for d in checkpoint.get("unresolved_candidates", []) - ] - detections = [BrandDetection(**d) for d in checkpoint.get("detections", [])] - stats = PipelineStats(**checkpoint.get("stats", {})) + # Frames (always present, loaded externally) + state["frames"] = frames + + # Stats + state["stats"] = deserialize_pipeline_stats(checkpoint.get("stats", {})) + + # Per-stage data + for name, stage_def in _REGISTRY.items(): + if stage_def.deserialize_fn is None: + continue + + stage_key = f"stage_{name}" + if stage_key not in checkpoint: + continue + + job_id = state.get("job_id", "") + stage_data = stage_def.deserialize_fn(checkpoint[stage_key], job_id) + + for k, v in stage_data.items(): + if k == "_filtered_sequences": + # Reconnect filtered frames from sequence list + seq_set = set(v) + state["filtered_frames"] = [f for f in frames if f.sequence in seq_set] + elif k.endswith("_raw"): + # Raw text candidates need frame reference reconnection + real_key = k.removeprefix("_").removesuffix("_raw") + state[real_key] = deserialize_text_candidates(v, frame_map) + else: + state[k] = v - state = { - "job_id": checkpoint.get("job_id", ""), - "video_path": checkpoint.get("video_path", ""), - "profile_name": checkpoint.get("profile_name", ""), - "config_overrides": checkpoint.get("config_overrides", {}), - "frames": frames, - "filtered_frames": filtered_frames, - "boxes_by_frame": boxes_by_frame, - "text_candidates": text_candidates, - "unresolved_candidates": unresolved_candidates, - "detections": detections, - "stats": stats, - } return state diff --git a/detect/models.py b/detect/models.py index 694b22b..c1b6889 100644 --- a/detect/models.py +++ b/detect/models.py @@ -1,86 +1,26 @@ """ -Core domain models for the detection pipeline. +Re-export pipeline runtime models from core/schema/models/detect_pipeline.py. -These are pipeline-internal models — the data structures that flow -between LangGraph nodes. SSE event payloads (sse_contract.py) are -derived from these when emitting to the UI. +All models are defined in core/schema/ — this module exists for backward +compatibility so existing imports (from detect.models import Frame) keep working. 
""" -from __future__ import annotations +from core.schema.models.detect_pipeline import ( + BoundingBox, + BrandDetection, + BrandStats, + DetectionReport, + Frame, + PipelineStats, + TextCandidate, +) -from dataclasses import dataclass, field -from typing import Literal - -import numpy as np - - -@dataclass -class Frame: - sequence: int - chunk_id: int - timestamp: float # position in video (seconds) - image: np.ndarray - perceptual_hash: str = "" - - -@dataclass -class BoundingBox: - x: int - y: int - w: int - h: int - confidence: float - label: str - - -@dataclass -class TextCandidate: - frame: Frame - bbox: BoundingBox - text: str - ocr_confidence: float - - -@dataclass -class BrandDetection: - brand: str - timestamp: float - duration: float - confidence: float - source: Literal["ocr", "local_vlm", "cloud_llm", "logo_match", "auxiliary"] - bbox: BoundingBox | None = None - frame_ref: int | None = None - content_type: str = "" - - -@dataclass -class BrandStats: - total_appearances: int = 0 - total_screen_time: float = 0.0 - avg_confidence: float = 0.0 - first_seen: float = 0.0 - last_seen: float = 0.0 - - -@dataclass -class PipelineStats: - frames_extracted: int = 0 - frames_after_scene_filter: int = 0 - regions_detected: int = 0 - regions_resolved_by_ocr: int = 0 - regions_escalated_to_local_vlm: int = 0 - regions_escalated_to_cloud_llm: int = 0 - auxiliary_detections: int = 0 - cloud_llm_calls: int = 0 - processing_time_seconds: float = 0.0 - estimated_cloud_cost_usd: float = 0.0 - - -@dataclass -class DetectionReport: - video_source: str - content_type: str - duration_seconds: float - brands: dict[str, BrandStats] = field(default_factory=dict) - timeline: list[BrandDetection] = field(default_factory=list) - pipeline_stats: PipelineStats = field(default_factory=PipelineStats) +__all__ = [ + "BoundingBox", + "BrandDetection", + "BrandStats", + "DetectionReport", + "Frame", + "PipelineStats", + "TextCandidate", +] diff --git a/detect/stages/__init__.py b/detect/stages/__init__.py index e69de29..7c66a29 100644 --- a/detect/stages/__init__.py +++ b/detect/stages/__init__.py @@ -0,0 +1,21 @@ +""" +Pipeline stages. + +Each stage registers its StageDefinition on import, +declaring IO (what it reads/writes from state), +config fields (what's tunable from the editor), +and serialization (how to checkpoint its outputs). +""" + +from .base import ( + StageDefinition, + StageIO, + StageConfigField, + register_stage, + get_stage, + list_stages, + get_palette, +) + +# Populate registry with built-in stages +from . import registry # noqa: F401 diff --git a/detect/stages/base.py b/detect/stages/base.py new file mode 100644 index 0000000..13e109a --- /dev/null +++ b/detect/stages/base.py @@ -0,0 +1,101 @@ +""" +Stage protocol — common interface for all pipeline stages. + +Every stage declares: + - IO: what it reads/writes from DetectState + - Config: tunable parameters for the editor + - Serialization: how to persist/restore its own outputs + +The checkpoint layer is a black box — it asks each stage to serialize its +outputs and stores the result. Stages own their data format. Binary data +(frames, crops) goes to S3 via the stage itself. The checkpoint just +stores the JSON envelope. + +The graph builder uses StageIO to validate that a stage's inputs are +satisfied by previous stages' outputs. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable + + +@dataclass +class StageIO: + """Declares what a stage reads and writes from/to DetectState.""" + reads: list[str] + writes: list[str] + optional_reads: list[str] = field(default_factory=list) + + +@dataclass +class StageConfigField: + """A single tunable config parameter for the editor UI.""" + name: str + type: str # "float", "int", "str", "bool", "list[str]" + default: Any + description: str = "" + min: float | None = None + max: float | None = None + options: list[str] | None = None + + +@dataclass +class StageDefinition: + """ + Complete metadata for a pipeline stage. + + The profile editor uses this to build the palette, generate config + forms, and validate graph connections. The checkpoint uses serialize_fn + and deserialize_fn to persist stage outputs without knowing the internals. + """ + name: str + label: str + description: str + io: StageIO + config_fields: list[StageConfigField] = field(default_factory=list) + category: str = "detection" + + # The actual graph node function: (DetectState) → dict + fn: Callable | None = None + + # Stage-owned serialization for checkpointing. + # serialize_fn: (state: dict, job_id: str) → json-compatible dict + # Stage picks its writes from state, serializes them. + # Binary data (frames) → S3 via stage, returns refs. + # deserialize_fn: (data: dict, job_id: str) → state update dict + # Stage restores its writes from the persisted data. + serialize_fn: Callable | None = None + deserialize_fn: Callable | None = None + + +# --------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- + +_REGISTRY: dict[str, StageDefinition] = {} + + +def register_stage(definition: StageDefinition): + _REGISTRY[definition.name] = definition + + +def get_stage(name: str) -> StageDefinition: + if name not in _REGISTRY: + raise KeyError(f"Unknown stage: {name!r}. Registered: {list(_REGISTRY)}") + return _REGISTRY[name] + + +def list_stages() -> list[StageDefinition]: + return list(_REGISTRY.values()) + + +def get_palette() -> dict[str, list[StageDefinition]]: + """Group stages by category for the editor palette.""" + palette: dict[str, list[StageDefinition]] = {} + for stage in _REGISTRY.values(): + if stage.category not in palette: + palette[stage.category] = [] + palette[stage.category].append(stage) + return palette diff --git a/detect/stages/registry/__init__.py b/detect/stages/registry/__init__.py new file mode 100644 index 0000000..1e410df --- /dev/null +++ b/detect/stages/registry/__init__.py @@ -0,0 +1,28 @@ +""" +Stage registry — registers all built-in stages. + +Split by category: + preprocessing.py — extract_frames, filter_scenes + detection.py — detect_objects, run_ocr + resolution.py — match_brands + escalation.py — escalate_vlm, escalate_cloud + output.py — compile_report + _serializers.py — shared serialization helpers +""" + +from . import preprocessing +from . import detection +from . import resolution +from . import escalation +from . 
import output + + +def register_all(): + preprocessing.register() + detection.register() + resolution.register() + escalation.register() + output.register() + + +register_all() diff --git a/detect/stages/registry/_serializers.py b/detect/stages/registry/_serializers.py new file mode 100644 index 0000000..1375e6c --- /dev/null +++ b/detect/stages/registry/_serializers.py @@ -0,0 +1,25 @@ +""" +Re-export serializers from core/schema/serializers/. + +Stage registry modules import from here for convenience. +All serialization logic lives in core/schema/serializers/. +""" + +from core.schema.serializers._common import ( + safe_construct, + serialize_dataclass, + serialize_dataclass_list, +) +from core.schema.serializers.detect_pipeline import ( + serialize_frame_meta, + serialize_frames_with_upload as serialize_frames, + deserialize_frames_with_download as deserialize_frames, + serialize_text_candidate, + serialize_text_candidates, + deserialize_text_candidate, + deserialize_text_candidates, + deserialize_bounding_box, + deserialize_brand_detection, + deserialize_pipeline_stats, + deserialize_detection_report, +) diff --git a/detect/stages/registry/detection.py b/detect/stages/registry/detection.py new file mode 100644 index 0000000..14b671d --- /dev/null +++ b/detect/stages/registry/detection.py @@ -0,0 +1,63 @@ +"""Registration for detection stages: YOLO, OCR.""" + +from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage +from ._serializers import ( + serialize_dataclass_list, + serialize_text_candidates, + deserialize_bounding_box, +) + + +def _ser_detect(state: dict, job_id: str) -> dict: + boxes = state.get("boxes_by_frame", {}) + serialized = {str(seq): serialize_dataclass_list(bl) for seq, bl in boxes.items()} + return {"boxes_by_frame": serialized} + + +def _deser_detect(data: dict, job_id: str) -> dict: + boxes = {} + for seq_str, box_dicts in data.get("boxes_by_frame", {}).items(): + boxes[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts] + return {"boxes_by_frame": boxes} + + +def _ser_ocr(state: dict, job_id: str) -> dict: + candidates = state.get("text_candidates", []) + return {"text_candidates": serialize_text_candidates(candidates)} + + +def _deser_ocr(data: dict, job_id: str) -> dict: + return {"_text_candidates_raw": data["text_candidates"]} + + +def register(): + yolo = StageDefinition( + name="detect_objects", + label="Object Detection", + description="YOLO object detection on filtered frames", + category="detection", + io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]), + config_fields=[ + StageConfigField("model_name", "str", "yolov8n.pt", "YOLO model file"), + StageConfigField("confidence_threshold", "float", 0.3, "Min detection confidence", min=0.0, max=1.0), + StageConfigField("target_classes", "list[str]", [], "YOLO classes to detect (empty = all)"), + ], + serialize_fn=_ser_detect, + deserialize_fn=_deser_detect, + ) + register_stage(yolo) + + ocr = StageDefinition( + name="run_ocr", + label="OCR", + description="Extract text from detected regions", + category="detection", + io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]), + config_fields=[ + StageConfigField("languages", "list[str]", ["en"], "OCR languages"), + StageConfigField("min_confidence", "float", 0.5, "Min OCR confidence", min=0.0, max=1.0), + ], + serialize_fn=_ser_ocr, + deserialize_fn=_deser_ocr, + ) + register_stage(ocr) diff --git a/detect/stages/registry/escalation.py b/detect/stages/registry/escalation.py 
new file mode 100644
index 0000000..fca0222
--- /dev/null
+++ b/detect/stages/registry/escalation.py
@@ -0,0 +1,63 @@
+"""Registration for escalation stages: local VLM, cloud LLM."""
+
+from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
+from ._serializers import (
+    serialize_dataclass_list,
+    serialize_text_candidates,
+    deserialize_brand_detection,
+)
+
+
+def _ser_escalation(state: dict, job_id: str) -> dict:
+    matched = state.get("detections", [])
+    unresolved = state.get("unresolved_candidates", [])
+    return {
+        "detections": serialize_dataclass_list(matched),
+        "unresolved_candidates": serialize_text_candidates(unresolved),
+    }
+
+
+def _deser_escalation(data: dict, job_id: str) -> dict:
+    detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
+    return {
+        "detections": detections,
+        "_unresolved_candidates_raw": data.get("unresolved_candidates", []),
+    }
+
+
+def register():
+    vlm = StageDefinition(
+        name="escalate_vlm",
+        label="Local VLM",
+        description="Process unresolved crops with moondream2",
+        category="escalation",
+        io=StageIO(
+            reads=["unresolved_candidates"],
+            writes=["detections", "unresolved_candidates"],
+            optional_reads=["source_asset_id"],
+        ),
+        config_fields=[
+            StageConfigField("min_confidence", "float", 0.5, "Min VLM confidence", min=0.0, max=1.0),
+        ],
+        serialize_fn=_ser_escalation,
+        deserialize_fn=_deser_escalation,
+    )
+    register_stage(vlm)
+
+    cloud = StageDefinition(
+        name="escalate_cloud",
+        label="Cloud LLM",
+        description="Escalate remaining crops to cloud provider",
+        category="escalation",
+        io=StageIO(
+            reads=["unresolved_candidates"],
+            writes=["detections"],
+            optional_reads=["source_asset_id"],
+        ),
+        config_fields=[
+            StageConfigField("min_confidence", "float", 0.4, "Min cloud confidence", min=0.0, max=1.0),
+        ],
+        serialize_fn=_ser_escalation,
+        deserialize_fn=_deser_escalation,
+    )
+    register_stage(cloud)
diff --git a/detect/stages/registry/output.py b/detect/stages/registry/output.py
new file mode 100644
index 0000000..4efa4b0
--- /dev/null
+++ b/detect/stages/registry/output.py
@@ -0,0 +1,32 @@
+"""Registration for output stages: report compilation."""
+
+from detect.stages.base import StageDefinition, StageIO, register_stage
+from ._serializers import serialize_dataclass, deserialize_detection_report
+
+
+def _ser_report(state: dict, job_id: str) -> dict:
+    report = state.get("report")
+    if report is None:
+        return {"report": None}
+    return {"report": serialize_dataclass(report)}
+
+
+def _deser_report(data: dict, job_id: str) -> dict:
+    raw = data.get("report")
+    if raw is None:
+        return {"report": None}
+    return {"report": deserialize_detection_report(raw)}
+
+
+def register():
+    report = StageDefinition(
+        name="compile_report",
+        label="Report",
+        description="Merge detections and compile final report",
+        category="output",
+        io=StageIO(reads=["detections"], writes=["report"]),
+        config_fields=[],
+        serialize_fn=_ser_report,
+        deserialize_fn=_deser_report,
+    )
+    register_stage(report)
diff --git a/detect/stages/registry/preprocessing.py b/detect/stages/registry/preprocessing.py
new file mode 100644
index 0000000..11d40f5
--- /dev/null
+++ b/detect/stages/registry/preprocessing.py
@@ -0,0 +1,57 @@
+"""Registration for preprocessing stages: frame extraction, scene filter."""
+
+from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
+from ._serializers import serialize_frames, deserialize_frames
+
+
+def _ser_extract(state: dict, job_id: str) -> 
dict:
+    frames = state.get("frames", [])
+    meta, manifest = serialize_frames(frames, job_id)
+    return {"frames_meta": meta, "frames_manifest": manifest}
+
+
+def _deser_extract(data: dict, job_id: str) -> dict:
+    frames = deserialize_frames(data["frames_meta"], data["frames_manifest"], job_id)
+    return {"frames": frames}
+
+
+def _ser_filter(state: dict, job_id: str) -> dict:
+    filtered = state.get("filtered_frames", [])
+    seqs = [f.sequence for f in filtered]
+    return {"filtered_frame_sequences": seqs}
+
+
+def _deser_filter(data: dict, job_id: str) -> dict:
+    return {"_filtered_sequences": data["filtered_frame_sequences"]}
+
+
+def register():
+    extract = StageDefinition(
+        name="extract_frames",
+        label="Frame Extraction",
+        description="Extract frames from video at configurable FPS",
+        category="preprocessing",
+        io=StageIO(reads=["video_path"], writes=["frames"]),
+        config_fields=[
+            StageConfigField("fps", "float", 2.0, "Frames per second", min=0.1, max=30.0),
+            StageConfigField("max_frames", "int", 500, "Maximum frames to extract", min=1, max=10000),
+        ],
+        serialize_fn=_ser_extract,
+        deserialize_fn=_deser_extract,
+    )
+    register_stage(extract)
+
+    scene_filter = StageDefinition(
+        name="filter_scenes",
+        label="Scene Filter",
+        description="Deduplicate similar frames using perceptual hashing",
+        category="preprocessing",
+        io=StageIO(reads=["frames"], writes=["filtered_frames"]),
+        config_fields=[
+            StageConfigField("hamming_threshold", "int", 8, "Hamming distance threshold", min=0, max=64),
+            StageConfigField("enabled", "bool", True, "Enable scene filtering"),
+        ],
+        serialize_fn=_ser_filter,
+        deserialize_fn=_deser_filter,
+    )
+    register_stage(scene_filter)
diff --git a/detect/stages/registry/resolution.py b/detect/stages/registry/resolution.py
new file mode 100644
index 0000000..a2affe8
--- /dev/null
+++ b/detect/stages/registry/resolution.py
@@ -0,0 +1,45 @@
+"""Registration for resolution stages: brand resolver."""
+
+from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
+from ._serializers import (
+    serialize_dataclass_list,
+    serialize_text_candidates,
+    deserialize_brand_detection,
+)
+
+
+def _ser_brands(state: dict, job_id: str) -> dict:
+    matched = state.get("detections", [])
+    unresolved = state.get("unresolved_candidates", [])
+    return {
+        "detections": serialize_dataclass_list(matched),
+        "unresolved_candidates": serialize_text_candidates(unresolved),
+    }
+
+
+def _deser_brands(data: dict, job_id: str) -> dict:
+    detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
+    return {
+        "detections": detections,
+        "_unresolved_candidates_raw": data.get("unresolved_candidates", []),
+    }
+
+
+def register():
+    resolver = StageDefinition(
+        name="match_brands",
+        label="Brand Resolver",
+        description="Match OCR text against known brands (session + global DB)",
+        category="resolution",
+        io=StageIO(
+            reads=["text_candidates"],
+            writes=["detections", "unresolved_candidates"],
+            optional_reads=["session_brands", "source_asset_id"],
+        ),
+        config_fields=[
+            StageConfigField("fuzzy_threshold", "int", 75, "Fuzzy match threshold", min=0, max=100),
+        ],
+        serialize_fn=_ser_brands,
+        deserialize_fn=_deser_brands,
+    )
+    register_stage(resolver)
diff --git a/tests/detect/test_checkpoint.py b/tests/detect/test_checkpoint.py
index a147a5e..c55b96f 100644
--- a/tests/detect/test_checkpoint.py
+++ b/tests/detect/test_checkpoint.py
@@ -1,26 +1,30 @@
-"""Tests for 
checkpoint serialization — unit tests without S3.""" + +import json import numpy as np import pytest from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate -from detect.checkpoint.serializer import ( - serialize_state, - deserialize_state, +from core.schema.serializers._common import safe_construct +from core.schema.serializers.detect_pipeline import ( serialize_frame_meta, serialize_text_candidate, + serialize_text_candidates, + serialize_dataclass_list, deserialize_text_candidate, + deserialize_text_candidates, + deserialize_bounding_box, + deserialize_brand_detection, + deserialize_pipeline_stats, ) def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame: image = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8) return Frame( - sequence=seq, - chunk_id=0, - timestamp=float(seq) * 0.5, - image=image, - perceptual_hash=f"hash_{seq}", + sequence=seq, chunk_id=0, timestamp=float(seq) * 0.5, + image=image, perceptual_hash=f"hash_{seq}", ) @@ -35,13 +39,8 @@ def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate: def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection: return BrandDetection( - brand=brand, - timestamp=timestamp, - duration=0.5, - confidence=0.92, - source="ocr", - content_type="soccer_broadcast", - frame_ref=0, + brand=brand, timestamp=timestamp, duration=0.5, + confidence=0.92, source="ocr", content_type="soccer_broadcast", frame_ref=0, ) @@ -81,102 +80,84 @@ def test_deserialize_text_candidate(): assert restored.text == "ADIDAS" assert restored.ocr_confidence == 0.85 - assert restored.frame is frame # same object reference + assert restored.frame is frame assert restored.bbox.x == 10 -# --- Full state round-trip --- +# --- BoundingBox --- -def test_state_round_trip(): - frames = [_make_frame(seq=i) for i in range(3)] - filtered = frames[:2] +def test_bounding_box_round_trip(): + box = _make_box(x=15, y=25, w=40, h=30) + serialized = serialize_dataclass_list([box])[0] + restored = deserialize_bounding_box(serialized) - box = _make_box() - boxes_by_frame = {0: [box], 1: [box]} + assert restored.x == 15 + assert restored.w == 40 + assert restored.confidence == 0.9 - candidates = [_make_candidate(frames[0], "NIKE"), _make_candidate(frames[1], "EMIRATES")] - unresolved = [_make_candidate(frames[2], "unknown")] - detections = [_make_detection("Nike", 0.5), _make_detection("Emirates", 1.0)] +# --- BrandDetection --- + +def test_brand_detection_round_trip(): + det = _make_detection("Emirates", 3.5) + serialized = serialize_dataclass_list([det])[0] + restored = deserialize_brand_detection(serialized) + + assert restored.brand == "Emirates" + assert restored.timestamp == 3.5 + assert restored.source == "ocr" + + +# --- PipelineStats --- + +def test_pipeline_stats_round_trip(): stats = PipelineStats( - frames_extracted=3, - frames_after_scene_filter=2, - regions_detected=2, - regions_resolved_by_ocr=2, - cloud_llm_calls=1, - estimated_cloud_cost_usd=0.003, + frames_extracted=120, cloud_llm_calls=3, + estimated_cloud_cost_usd=0.005, ) + serialized = serialize_dataclass_list([stats])[0] + restored = deserialize_pipeline_stats(serialized) - state = { - "job_id": "test-123", - "video_path": "/tmp/test.mp4", - "profile_name": "soccer_broadcast", - "config_overrides": {"ocr": {"min_confidence": 0.3}}, - "frames": frames, - "filtered_frames": filtered, - "boxes_by_frame": boxes_by_frame, - "text_candidates": candidates, - "unresolved_candidates": unresolved, - "detections": detections, - "stats": 
stats, + assert restored.frames_extracted == 120 + assert restored.cloud_llm_calls == 3 + assert restored.estimated_cloud_cost_usd == 0.005 + + +# --- safe_construct tolerates schema changes --- + +def test_safe_construct_ignores_unknown_fields(): + data = {"x": 1, "y": 2, "w": 3, "h": 4, "confidence": 0.5, "label": "t", "extra_field": 99} + box = safe_construct(BoundingBox, data) + + assert box.x == 1 + assert not hasattr(box, "extra_field") + + +def test_safe_construct_uses_defaults(): + data = {"frames_extracted": 50} + stats = safe_construct(PipelineStats, data) + + assert stats.frames_extracted == 50 + assert stats.cloud_llm_calls == 0 # default + + +# --- JSON compatibility --- + +def test_all_serialized_is_json_compatible(): + frame = _make_frame() + candidate = _make_candidate(frame) + detection = _make_detection() + stats = PipelineStats(frames_extracted=10) + + all_data = { + "frame_meta": serialize_frame_meta(frame), + "candidate": serialize_text_candidate(candidate), + "detection": serialize_dataclass_list([detection])[0], + "stats": serialize_dataclass_list([stats])[0], } - manifest = {f.sequence: f"s3://fake/frames/{f.sequence}.jpg" for f in frames} - - # Serialize - serialized = serialize_state(state, manifest) - - # Verify JSON-compatible (no numpy, no Frame objects) - import json - json_str = json.dumps(serialized, default=str) + json_str = json.dumps(all_data, default=str) assert len(json_str) > 0 - # Deserialize with the original frames (simulating frame load from S3) - restored = deserialize_state(serialized, frames) - - # Verify round-trip - assert restored["job_id"] == "test-123" - assert restored["video_path"] == "/tmp/test.mp4" - assert restored["profile_name"] == "soccer_broadcast" - assert restored["config_overrides"] == {"ocr": {"min_confidence": 0.3}} - - assert len(restored["frames"]) == 3 - assert len(restored["filtered_frames"]) == 2 - assert len(restored["boxes_by_frame"]) == 2 - assert len(restored["text_candidates"]) == 2 - assert len(restored["unresolved_candidates"]) == 1 - assert len(restored["detections"]) == 2 - - restored_stats = restored["stats"] - assert restored_stats.frames_extracted == 3 - assert restored_stats.cloud_llm_calls == 1 - assert restored_stats.estimated_cloud_cost_usd == 0.003 - - # TextCandidate frame references should point to actual Frame objects - tc = restored["text_candidates"][0] - assert tc.frame is frames[0] - assert tc.text == "NIKE" - - -def test_state_round_trip_empty(): - """Empty state should serialize/deserialize cleanly.""" - state = { - "job_id": "empty-job", - "video_path": "", - "profile_name": "soccer_broadcast", - "frames": [], - "filtered_frames": [], - "boxes_by_frame": {}, - "text_candidates": [], - "unresolved_candidates": [], - "detections": [], - "stats": PipelineStats(), - } - - serialized = serialize_state(state, {}) - restored = deserialize_state(serialized, []) - - assert restored["job_id"] == "empty-job" - assert len(restored["frames"]) == 0 - assert len(restored["detections"]) == 0 - assert restored["stats"].frames_extracted == 0 + roundtrip = json.loads(json_str) + assert roundtrip["frame_meta"]["sequence"] == frame.sequence diff --git a/tests/detect/test_stage_registry.py b/tests/detect/test_stage_registry.py new file mode 100644 index 0000000..8f9919b --- /dev/null +++ b/tests/detect/test_stage_registry.py @@ -0,0 +1,58 @@ +"""Tests for the stage registry.""" + +from detect.stages import list_stages, get_stage, get_palette + + +EXPECTED_STAGES = [ + "extract_frames", "filter_scenes", 
"detect_objects", "run_ocr", + "match_brands", "escalate_vlm", "escalate_cloud", "compile_report", +] + + +def test_all_stages_registered(): + stages = list_stages() + names = [s.name for s in stages] + for expected in EXPECTED_STAGES: + assert expected in names, f"Missing stage: {expected}" + + +def test_stage_has_io(): + for name in EXPECTED_STAGES: + stage = get_stage(name) + assert len(stage.io.writes) > 0, f"{name} has no writes" + assert stage.label, f"{name} has no label" + assert stage.description, f"{name} has no description" + + +def test_stage_has_serialization(): + for name in EXPECTED_STAGES: + stage = get_stage(name) + assert stage.serialize_fn is not None, f"{name} has no serialize_fn" + assert stage.deserialize_fn is not None, f"{name} has no deserialize_fn" + + +def test_palette_groups(): + palette = get_palette() + assert len(palette) > 0 + all_stages = [] + for category, stages in palette.items(): + assert isinstance(category, str) + all_stages.extend(stages) + assert len(all_stages) == len(EXPECTED_STAGES) + + +def test_io_chain_valid(): + """Each stage's reads should be satisfied by previous stages' writes.""" + stages = [get_stage(name) for name in EXPECTED_STAGES] + available_keys = {"video_path", "job_id", "profile_name", + "source_asset_id", "session_brands", "stats", + "config_overrides"} + + for stage in stages: + for read_key in stage.io.reads: + assert read_key in available_keys, ( + f"Stage {stage.name} reads '{read_key}' but no previous stage writes it. " + f"Available: {available_keys}" + ) + for write_key in stage.io.writes: + available_keys.add(write_key)