schema clean up and refactor

This commit is contained in:
2026-03-26 05:14:33 -03:00
parent 08c58a6a9d
commit d58a90157a
17 changed files with 930 additions and 287 deletions

View File

@@ -0,0 +1,97 @@
"""
Detection pipeline runtime models.
These are the data structures that flow between LangGraph nodes.
They contain runtime types (np.ndarray) so they are NOT generated
by modelgen — they live here for the schema to be the complete
map of the application, but modelgen skips them.
Wire-format models (SSE events) are in detect.py.
DB models (jobs, checkpoints) are in detect_jobs.py.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Literal
import numpy as np
@dataclass
class Frame:
    """One extracted video frame; `image` is runtime-only and never serialized to JSON."""

    sequence: int  # frame index; used as the cross-reference key by serializers (frame_sequence)
    chunk_id: int
    timestamp: float  # position in video (seconds)
    image: np.ndarray  # decoded pixels — persisted to S3 by the serializers, not JSON
    perceptual_hash: str = ""  # consumed by the scene filter to deduplicate similar frames
@dataclass
class BoundingBox:
    """A detected region within a frame (x/y top-left, w/h extent — presumably pixels; confirm against detector)."""

    x: int
    y: int
    w: int
    h: int
    confidence: float  # detector confidence for this region
    label: str  # detector class label
@dataclass
class TextCandidate:
    """OCR text read from one bounding box of one frame.

    Holds a live Frame reference; serializers store only frame.sequence
    and reconnect the object on load.
    """

    frame: Frame
    bbox: BoundingBox
    text: str
    ocr_confidence: float
@dataclass
class BrandDetection:
    """One brand sighting on the timeline, produced by any detection source."""

    brand: str
    timestamp: float  # position in video (seconds)
    duration: float  # seconds the brand is considered visible
    confidence: float
    source: Literal["ocr", "local_vlm", "cloud_llm", "logo_match", "auxiliary"]
    bbox: BoundingBox | None = None  # set when the source localized the brand in the frame
    frame_ref: int | None = None  # presumably a Frame.sequence — confirm against emitters
    content_type: str = ""
@dataclass
class BrandStats:
    """Aggregated per-brand metrics for the final report (all default to zero)."""

    total_appearances: int = 0
    total_screen_time: float = 0.0  # seconds
    avg_confidence: float = 0.0
    first_seen: float = 0.0  # video timestamp (seconds) — presumably; confirm in report compiler
    last_seen: float = 0.0  # video timestamp (seconds)
@dataclass
class PipelineStats:
    """Counters accumulated over a single pipeline run (all default to zero)."""

    frames_extracted: int = 0
    frames_after_scene_filter: int = 0
    regions_detected: int = 0
    regions_resolved_by_ocr: int = 0
    regions_escalated_to_local_vlm: int = 0
    regions_escalated_to_cloud_llm: int = 0
    auxiliary_detections: int = 0
    cloud_llm_calls: int = 0
    processing_time_seconds: float = 0.0
    estimated_cloud_cost_usd: float = 0.0
@dataclass
class DetectionReport:
    """Final pipeline output: per-brand stats, the detection timeline, and run stats."""

    video_source: str
    content_type: str
    duration_seconds: float
    brands: dict[str, BrandStats] = field(default_factory=dict)  # keyed by brand name
    timeline: list[BrandDetection] = field(default_factory=list)
    pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
# Not in DATACLASSES — modelgen skips these (they contain np.ndarray)
# NOTE(review): presumably consumed by schema tooling as the complete list of
# runtime models (see module docstring) — confirm against modelgen.
RUNTIME_MODELS = [
    Frame, BoundingBox, TextCandidate, BrandDetection,
    BrandStats, PipelineStats, DetectionReport,
]

View File

@@ -0,0 +1,11 @@
"""
Model serializers — one module per model group, mirroring core/schema/models/.
models/detect_pipeline.py → serializers/detect_pipeline.py
models/detect_jobs.py → serializers/detect_jobs.py
models/detect.py → serializers/detect.py (SSE events)
Common utilities in _common.py.
"""
from ._common import safe_construct, serialize_dataclass, serialize_dataclass_list

View File

@@ -0,0 +1,38 @@
"""Common serialization utilities."""
from __future__ import annotations
import dataclasses
import logging
logger = logging.getLogger(__name__)


def safe_construct(cls, data: dict):
    """
    Construct a dataclass instance from a dict, tolerant of schema drift.
    - Keys the dataclass doesn't declare are dropped (field was removed)
    - Missing keys fall back to the dataclass defaults (field was added)
    - Every dropped key is logged at debug level
    """
    valid_names = {f.name for f in dataclasses.fields(cls)}
    accepted = {}
    for key, value in data.items():
        if key not in valid_names:
            logger.debug("Ignoring unknown field %s.%s", cls.__name__, key)
            continue
        accepted[key] = value
    return cls(**accepted)
def serialize_dataclass(obj) -> dict:
    """Serialize any dataclass to a plain dict via dataclasses.asdict() (recursive; deep-copies field values)."""
    return dataclasses.asdict(obj)
def serialize_dataclass_list(items) -> list[dict]:
    """Serialize each dataclass in *items* to a dict, preserving order."""
    return [dataclasses.asdict(entry) for entry in items]

View File

@@ -0,0 +1,108 @@
"""
Serializers for detection pipeline runtime models.
Mirrors core/schema/models/detect_pipeline.py.
Special handling:
- Frame.image (np.ndarray → S3, excluded from JSON)
- TextCandidate.frame (object ref → frame_sequence integer)
Everything else uses dataclasses.asdict() via safe_construct.
"""
from __future__ import annotations
import dataclasses
from core.schema.models.detect_pipeline import (
BoundingBox,
BrandDetection,
BrandStats,
DetectionReport,
Frame,
PipelineStats,
TextCandidate,
)
from ._common import safe_construct, serialize_dataclass, serialize_dataclass_list
# ---------------------------------------------------------------------------
# Frame — image goes to S3 separately
# ---------------------------------------------------------------------------
def serialize_frame_meta(frame: Frame) -> dict:
    """Serialize Frame metadata only (no image).

    Builds the dict field-by-field instead of dataclasses.asdict(), which
    would deep-copy the (potentially large) np.ndarray image just to
    delete it. The remaining fields are scalars, so the result is
    identical to asdict() minus "image".
    """
    return {
        f.name: getattr(frame, f.name)
        for f in dataclasses.fields(frame)
        if f.name != "image"
    }
def serialize_frames_with_upload(frames: list[Frame], job_id: str) -> tuple[list[dict], dict[int, str]]:
    """Upload frame images to S3, return metadata + manifest.

    Returns (per-frame metadata dicts without images, {sequence: storage ref}
    manifest from save_frames). Local import — presumably avoids an import
    cycle with the checkpoint package; confirm.
    """
    from detect.checkpoint.frames import save_frames
    manifest = save_frames(job_id, frames)
    meta = [serialize_frame_meta(f) for f in frames]
    return meta, manifest
def deserialize_frames_with_download(meta: list[dict], manifest: dict, job_id: str) -> list[Frame]:
    """Load frames from S3 + metadata.

    Manifest keys arrive as JSON strings and are coerced back to ints.
    NOTE(review): `job_id` is unused here — kept for symmetry with the
    upload side; confirm before removing.
    """
    from detect.checkpoint.frames import load_frames
    int_manifest = {int(k): v for k, v in manifest.items()}
    return load_frames(int_manifest, meta)
# ---------------------------------------------------------------------------
# TextCandidate — frame ref is an object, stored as sequence int
# ---------------------------------------------------------------------------
def serialize_text_candidate(tc: TextCandidate) -> dict:
    """Flatten a TextCandidate for JSON: the Frame reference becomes its sequence int."""
    return {
        "frame_sequence": tc.frame.sequence,
        "bbox": dataclasses.asdict(tc.bbox),
        "text": tc.text,
        "ocr_confidence": tc.ocr_confidence,
    }
def serialize_text_candidates(candidates: list[TextCandidate]) -> list[dict]:
    """Serialize every candidate via serialize_text_candidate, preserving order."""
    return [serialize_text_candidate(candidate) for candidate in candidates]
def deserialize_text_candidate(data: dict, frame_map: dict[int, Frame]) -> TextCandidate:
    """Rebuild a TextCandidate, reattaching the live Frame via frame_map."""
    return TextCandidate(
        frame=frame_map[data["frame_sequence"]],
        bbox=safe_construct(BoundingBox, data["bbox"]),
        text=data["text"],
        ocr_confidence=data["ocr_confidence"],
    )
def deserialize_text_candidates(data: list[dict], frame_map: dict[int, Frame]) -> list[TextCandidate]:
    """Rebuild a list of TextCandidates against the same frame map."""
    return [deserialize_text_candidate(item, frame_map) for item in data]
# ---------------------------------------------------------------------------
# BoundingBox, BrandDetection, PipelineStats, etc — standard dataclasses
# ---------------------------------------------------------------------------
def deserialize_bounding_box(data: dict) -> BoundingBox:
    """Rebuild a BoundingBox; tolerant of added/removed fields via safe_construct."""
    return safe_construct(BoundingBox, data)
def deserialize_brand_detection(data: dict) -> BrandDetection:
    """Rebuild a BrandDetection; tolerant of added/removed fields via safe_construct."""
    return safe_construct(BrandDetection, data)
def deserialize_pipeline_stats(data: dict) -> PipelineStats:
    """Rebuild PipelineStats; missing counters fall back to their zero defaults."""
    return safe_construct(PipelineStats, data)
def deserialize_detection_report(data: dict) -> DetectionReport:
    """Rebuild a DetectionReport; tolerant of added/removed fields via safe_construct."""
    return safe_construct(DetectionReport, data)

View File

@@ -1,133 +1,108 @@
"""State serialization — DetectState ↔ JSON-compatible dict."""
"""
State serialization — DetectState ↔ JSON-compatible dict.
Delegates to each stage's serialize_fn/deserialize_fn via the registry.
This file has no model-specific knowledge — stages own their data format.
The only things serialized here are the "envelope" fields (job_id, video_path, etc.)
that don't belong to any stage.
"""
from __future__ import annotations
import dataclasses
from detect.models import (
BoundingBox,
BrandDetection,
Frame,
PipelineStats,
TextCandidate,
from core.schema.serializers._common import serialize_dataclass
from core.schema.serializers.detect_pipeline import (
deserialize_pipeline_stats,
deserialize_text_candidates,
)
# ---------------------------------------------------------------------------
# Serialize helpers
# ---------------------------------------------------------------------------
def serialize_frame_meta(frame: Frame) -> dict:
meta = {
"sequence": frame.sequence,
"chunk_id": frame.chunk_id,
"timestamp": frame.timestamp,
"perceptual_hash": frame.perceptual_hash,
}
return meta
def serialize_text_candidate(tc: TextCandidate) -> dict:
bbox_dict = dataclasses.asdict(tc.bbox)
candidate = {
"frame_sequence": tc.frame.sequence,
"bbox": bbox_dict,
"text": tc.text,
"ocr_confidence": tc.ocr_confidence,
}
return candidate
# Envelope fields — not owned by any stage, always present
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "config_overrides"]
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
"""
Serialize DetectState to a JSON-compatible dict.
Frame images are replaced with S3 key references.
TextCandidate.frame references become frame_sequence integers.
Calls each registered stage's serialize_fn for stage-owned data.
Envelope fields (job_id, etc.) are copied directly.
"""
frames = state.get("frames", [])
filtered = state.get("filtered_frames", [])
from detect.stages.base import _REGISTRY
manifest_strs = {str(k): v for k, v in frames_manifest.items()}
frames_meta = [serialize_frame_meta(f) for f in frames]
filtered_seqs = [f.sequence for f in filtered]
checkpoint = {}
boxes_serialized = {}
for seq, boxes in state.get("boxes_by_frame", {}).items():
boxes_serialized[str(seq)] = [dataclasses.asdict(b) for b in boxes]
# Envelope
for key in ENVELOPE_KEYS:
default = {} if key == "config_overrides" else ""
checkpoint[key] = state.get(key, default)
text_candidates = [serialize_text_candidate(tc) for tc in state.get("text_candidates", [])]
unresolved = [serialize_text_candidate(tc) for tc in state.get("unresolved_candidates", [])]
detections = [dataclasses.asdict(d) for d in state.get("detections", [])]
stats = dataclasses.asdict(state.get("stats", PipelineStats()))
# Frames manifest (needed by frame-loading stages)
checkpoint["frames_manifest"] = {str(k): v for k, v in frames_manifest.items()}
# Stats (shared across stages, not owned by one)
stats = state.get("stats")
if stats is not None:
checkpoint["stats"] = serialize_dataclass(stats)
else:
checkpoint["stats"] = {}
# Per-stage data
for name, stage_def in _REGISTRY.items():
if stage_def.serialize_fn is None:
continue
job_id = state.get("job_id", "")
stage_data = stage_def.serialize_fn(state, job_id)
checkpoint[f"stage_{name}"] = stage_data
checkpoint = {
"job_id": state.get("job_id", ""),
"video_path": state.get("video_path", ""),
"profile_name": state.get("profile_name", ""),
"config_overrides": state.get("config_overrides", {}),
"frames_manifest": manifest_strs,
"frames_meta": frames_meta,
"filtered_frame_sequences": filtered_seqs,
"boxes_by_frame": boxes_serialized,
"text_candidates": text_candidates,
"unresolved_candidates": unresolved,
"detections": detections,
"stats": stats,
}
return checkpoint
# ---------------------------------------------------------------------------
# Deserialize helpers
# ---------------------------------------------------------------------------
def deserialize_state(checkpoint: dict, frames: list) -> dict:
"""
Reconstitute DetectState from a checkpoint dict + loaded frames.
def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate:
frame = frame_map[d["frame_sequence"]]
bbox = BoundingBox(**d["bbox"])
candidate = TextCandidate(
frame=frame,
bbox=bbox,
text=d["text"],
ocr_confidence=d["ocr_confidence"],
)
return candidate
Calls each stage's deserialize_fn to restore stage-owned data.
"""
from detect.stages.base import _REGISTRY
def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict:
"""Reconstitute DetectState from a checkpoint dict + loaded frames."""
frame_map = {f.sequence: f for f in frames}
filtered_seqs = set(checkpoint.get("filtered_frame_sequences", []))
filtered_frames = [f for f in frames if f.sequence in filtered_seqs]
state = {}
boxes_by_frame = {}
for seq_str, box_dicts in checkpoint.get("boxes_by_frame", {}).items():
seq = int(seq_str)
boxes_by_frame[seq] = [BoundingBox(**b) for b in box_dicts]
# Envelope
for key in ENVELOPE_KEYS:
default = {} if key == "config_overrides" else ""
state[key] = checkpoint.get(key, default)
text_candidates = [
deserialize_text_candidate(d, frame_map)
for d in checkpoint.get("text_candidates", [])
]
unresolved_candidates = [
deserialize_text_candidate(d, frame_map)
for d in checkpoint.get("unresolved_candidates", [])
]
detections = [BrandDetection(**d) for d in checkpoint.get("detections", [])]
stats = PipelineStats(**checkpoint.get("stats", {}))
# Frames (always present, loaded externally)
state["frames"] = frames
# Stats
state["stats"] = deserialize_pipeline_stats(checkpoint.get("stats", {}))
# Per-stage data
for name, stage_def in _REGISTRY.items():
if stage_def.deserialize_fn is None:
continue
stage_key = f"stage_{name}"
if stage_key not in checkpoint:
continue
job_id = state.get("job_id", "")
stage_data = stage_def.deserialize_fn(checkpoint[stage_key], job_id)
for k, v in stage_data.items():
if k == "_filtered_sequences":
# Reconnect filtered frames from sequence list
seq_set = set(v)
state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
elif k.endswith("_raw"):
# Raw text candidates need frame reference reconnection
real_key = k.removeprefix("_").removesuffix("_raw")
state[real_key] = deserialize_text_candidates(v, frame_map)
else:
state[k] = v
state = {
"job_id": checkpoint.get("job_id", ""),
"video_path": checkpoint.get("video_path", ""),
"profile_name": checkpoint.get("profile_name", ""),
"config_overrides": checkpoint.get("config_overrides", {}),
"frames": frames,
"filtered_frames": filtered_frames,
"boxes_by_frame": boxes_by_frame,
"text_candidates": text_candidates,
"unresolved_candidates": unresolved_candidates,
"detections": detections,
"stats": stats,
}
return state

View File

@@ -1,86 +1,26 @@
"""
Core domain models for the detection pipeline.
Re-export pipeline runtime models from core/schema/models/detect_pipeline.py.
These are pipeline-internal models — the data structures that flow
between LangGraph nodes. SSE event payloads (sse_contract.py) are
derived from these when emitting to the UI.
All models are defined in core/schema/ — this module exists for backward
compatibility so existing imports (from detect.models import Frame) keep working.
"""
from __future__ import annotations
from core.schema.models.detect_pipeline import (
BoundingBox,
BrandDetection,
BrandStats,
DetectionReport,
Frame,
PipelineStats,
TextCandidate,
)
from dataclasses import dataclass, field
from typing import Literal
import numpy as np
@dataclass
class Frame:
sequence: int
chunk_id: int
timestamp: float # position in video (seconds)
image: np.ndarray
perceptual_hash: str = ""
@dataclass
class BoundingBox:
x: int
y: int
w: int
h: int
confidence: float
label: str
@dataclass
class TextCandidate:
frame: Frame
bbox: BoundingBox
text: str
ocr_confidence: float
@dataclass
class BrandDetection:
brand: str
timestamp: float
duration: float
confidence: float
source: Literal["ocr", "local_vlm", "cloud_llm", "logo_match", "auxiliary"]
bbox: BoundingBox | None = None
frame_ref: int | None = None
content_type: str = ""
@dataclass
class BrandStats:
total_appearances: int = 0
total_screen_time: float = 0.0
avg_confidence: float = 0.0
first_seen: float = 0.0
last_seen: float = 0.0
@dataclass
class PipelineStats:
frames_extracted: int = 0
frames_after_scene_filter: int = 0
regions_detected: int = 0
regions_resolved_by_ocr: int = 0
regions_escalated_to_local_vlm: int = 0
regions_escalated_to_cloud_llm: int = 0
auxiliary_detections: int = 0
cloud_llm_calls: int = 0
processing_time_seconds: float = 0.0
estimated_cloud_cost_usd: float = 0.0
@dataclass
class DetectionReport:
video_source: str
content_type: str
duration_seconds: float
brands: dict[str, BrandStats] = field(default_factory=dict)
timeline: list[BrandDetection] = field(default_factory=list)
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
__all__ = [
"BoundingBox",
"BrandDetection",
"BrandStats",
"DetectionReport",
"Frame",
"PipelineStats",
"TextCandidate",
]

View File

@@ -0,0 +1,21 @@
"""
Pipeline stages.
Each stage registers its StageDefinition on import,
declaring IO (what it reads/writes from state),
config fields (what's tunable from the editor),
and serialization (how to checkpoint its outputs).
"""
from .base import (
StageDefinition,
StageIO,
StageConfigField,
register_stage,
get_stage,
list_stages,
get_palette,
)
# Populate registry with built-in stages
from . import registry # noqa: F401

101
detect/stages/base.py Normal file
View File

@@ -0,0 +1,101 @@
"""
Stage protocol — common interface for all pipeline stages.
Every stage declares:
- IO: what it reads/writes from DetectState
- Config: tunable parameters for the editor
- Serialization: how to persist/restore its own outputs
The checkpoint layer is a black box — it asks each stage to serialize its
outputs and stores the result. Stages own their data format. Binary data
(frames, crops) goes to S3 via the stage itself. The checkpoint just
stores the JSON envelope.
The graph builder uses StageIO to validate that a stage's inputs are
satisfied by previous stages' outputs.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable
@dataclass
class StageIO:
    """Declares what a stage reads and writes from/to DetectState.

    The graph builder uses these keys to validate that a stage's `reads`
    are satisfied by earlier stages' `writes` (see module docstring).
    """

    reads: list[str]  # state keys this stage requires
    writes: list[str]  # state keys this stage produces
    optional_reads: list[str] = field(default_factory=list)  # used if present, never required
@dataclass
class StageConfigField:
    """A single tunable config parameter for the editor UI."""

    name: str
    type: str  # "float", "int", "str", "bool", "list[str]"
    default: Any
    description: str = ""
    # NOTE: `min`/`max` shadow builtins, but as dataclass field names only
    # attribute access is affected.
    min: float | None = None  # lower bound for numeric fields, if any
    max: float | None = None  # upper bound for numeric fields, if any
    options: list[str] | None = None  # presumably the allowed values for choice fields — confirm in editor
@dataclass
class StageDefinition:
    """
    Complete metadata for a pipeline stage.
    The profile editor uses this to build the palette, generate config
    forms, and validate graph connections. The checkpoint uses serialize_fn
    and deserialize_fn to persist stage outputs without knowing the internals.
    """

    name: str  # unique registry key
    label: str  # human-readable name for the editor
    description: str
    io: StageIO
    config_fields: list[StageConfigField] = field(default_factory=list)
    category: str = "detection"  # palette grouping key (see get_palette)
    # The actual graph node function: (DetectState) → dict
    fn: Callable | None = None
    # Stage-owned serialization for checkpointing.
    #   serialize_fn: (state: dict, job_id: str) → json-compatible dict
    #     Stage picks its writes from state, serializes them.
    #     Binary data (frames) → S3 via stage, returns refs.
    #   deserialize_fn: (data: dict, job_id: str) → state update dict
    #     Stage restores its writes from the persisted data.
    serialize_fn: Callable | None = None
    deserialize_fn: Callable | None = None
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
_REGISTRY: dict[str, StageDefinition] = {}


def register_stage(definition: StageDefinition):
    """Add (or replace) a stage definition, keyed by its unique name."""
    _REGISTRY[definition.name] = definition


def get_stage(name: str) -> StageDefinition:
    """Look up a registered stage; raise KeyError listing known names if absent."""
    try:
        return _REGISTRY[name]
    except KeyError:
        raise KeyError(f"Unknown stage: {name!r}. Registered: {list(_REGISTRY)}") from None


def list_stages() -> list[StageDefinition]:
    """Return every registered stage, in registration order."""
    return list(_REGISTRY.values())


def get_palette() -> dict[str, list[StageDefinition]]:
    """Group stages by category for the editor palette."""
    palette: dict[str, list[StageDefinition]] = {}
    for stage in _REGISTRY.values():
        palette.setdefault(stage.category, []).append(stage)
    return palette

View File

@@ -0,0 +1,28 @@
"""
Stage registry — registers all built-in stages.
Split by category:
preprocessing.py — extract_frames, filter_scenes
detection.py — detect_objects, run_ocr
resolution.py — match_brands
escalation.py — escalate_vlm, escalate_cloud
output.py — compile_report
_serializers.py — shared serialization helpers
"""
from . import preprocessing
from . import detection
from . import resolution
from . import escalation
from . import output
def register_all():
    """Register every built-in stage; safe to call again (registry entries are overwritten by name)."""
    preprocessing.register()
    detection.register()
    resolution.register()
    escalation.register()
    output.register()


# Populate the registry at import time so importing this module is enough.
register_all()

View File

@@ -0,0 +1,25 @@
"""
Re-export serializers from core/schema/serializers/.
Stage registry modules import from here for convenience.
All serialization logic lives in core/schema/serializers/.
"""
from core.schema.serializers._common import (
safe_construct,
serialize_dataclass,
serialize_dataclass_list,
)
from core.schema.serializers.detect_pipeline import (
serialize_frame_meta,
serialize_frames_with_upload as serialize_frames,
deserialize_frames_with_download as deserialize_frames,
serialize_text_candidate,
serialize_text_candidates,
deserialize_text_candidate,
deserialize_text_candidates,
deserialize_bounding_box,
deserialize_brand_detection,
deserialize_pipeline_stats,
deserialize_detection_report,
)

View File

@@ -0,0 +1,63 @@
"""Registration for detection stages: YOLO, OCR."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_bounding_box,
)
def _ser_detect(state: dict, job_id: str) -> dict:
boxes = state.get("boxes_by_frame", {})
serialized = {str(seq): serialize_dataclass_list(bl) for seq, bl in boxes.items()}
return {"boxes_by_frame": serialized}
def _deser_detect(data: dict, job_id: str) -> dict:
boxes = {}
for seq_str, box_dicts in data.get("boxes_by_frame", {}).items():
boxes[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
return {"boxes_by_frame": boxes}
def _ser_ocr(state: dict, job_id: str) -> dict:
    """Serialize OCR candidates (frame references become sequence ints)."""
    return {
        "text_candidates": serialize_text_candidates(state.get("text_candidates", [])),
    }
def _deser_ocr(data: dict, job_id: str) -> dict:
return {"_text_candidates_raw": data["text_candidates"]}
def register():
    """Register the two built-in detection stages: YOLO boxes, then OCR over those boxes."""
    yolo = StageDefinition(
        name="detect_objects",
        label="Object Detection",
        description="YOLO object detection on filtered frames",
        category="detection",
        io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
        config_fields=[
            StageConfigField("model_name", "str", "yolov8n.pt", "YOLO model file"),
            StageConfigField("confidence_threshold", "float", 0.3, "Min detection confidence", min=0.0, max=1.0),
            StageConfigField("target_classes", "list[str]", [], "YOLO classes to detect (empty = all)"),
        ],
        serialize_fn=_ser_detect,
        deserialize_fn=_deser_detect,
    )
    register_stage(yolo)
    # OCR reads boxes_by_frame, so the graph builder can only place it
    # after detect_objects (or another stage writing that key).
    ocr = StageDefinition(
        name="run_ocr",
        label="OCR",
        description="Extract text from detected regions",
        category="detection",
        io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
        config_fields=[
            StageConfigField("languages", "list[str]", ["en"], "OCR languages"),
            StageConfigField("min_confidence", "float", 0.5, "Min OCR confidence", min=0.0, max=1.0),
        ],
        serialize_fn=_ser_ocr,
        deserialize_fn=_deser_ocr,
    )
    register_stage(ocr)

View File

@@ -0,0 +1,63 @@
"""Registration for escalation stages: local VLM, cloud LLM."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_brand_detection,
)
def _ser_escalation(state: dict, job_id: str) -> dict:
    """Checkpoint escalation output: resolved detections plus still-unresolved candidates."""
    return {
        "detections": serialize_dataclass_list(state.get("detections", [])),
        "unresolved_candidates": serialize_text_candidates(state.get("unresolved_candidates", [])),
    }
def _deser_escalation(data: dict, job_id: str) -> dict:
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
return {
"detections": detections,
"_unresolved_raw": data.get("unresolved_candidates", []),
}
def register():
    """Register both escalation stages; they share the serializer pair since they write the same keys."""
    vlm = StageDefinition(
        name="escalate_vlm",
        label="Local VLM",
        description="Process unresolved crops with moondream2",
        category="escalation",
        io=StageIO(
            reads=["unresolved_candidates"],
            writes=["detections", "unresolved_candidates"],
            optional_reads=["source_asset_id"],
        ),
        config_fields=[
            StageConfigField("min_confidence", "float", 0.5, "Min VLM confidence", min=0.0, max=1.0),
        ],
        serialize_fn=_ser_escalation,
        deserialize_fn=_deser_escalation,
    )
    register_stage(vlm)
    # Cloud escalation writes only detections — per its IO declaration it
    # does not emit a further unresolved tier.
    cloud = StageDefinition(
        name="escalate_cloud",
        label="Cloud LLM",
        description="Escalate remaining crops to cloud provider",
        category="escalation",
        io=StageIO(
            reads=["unresolved_candidates"],
            writes=["detections"],
            optional_reads=["source_asset_id"],
        ),
        config_fields=[
            StageConfigField("min_confidence", "float", 0.4, "Min cloud confidence", min=0.0, max=1.0),
        ],
        serialize_fn=_ser_escalation,
        deserialize_fn=_deser_escalation,
    )
    register_stage(cloud)

View File

@@ -0,0 +1,32 @@
"""Registration for output stages: report compilation."""
from detect.stages.base import StageDefinition, StageIO, register_stage
from ._serializers import serialize_dataclass, deserialize_detection_report
def _ser_report(state: dict, job_id: str) -> dict:
report = state.get("report")
if report is None:
return {"report": None}
return {"report": serialize_dataclass(report)}
def _deser_report(data: dict, job_id: str) -> dict:
raw = data.get("report")
if raw is None:
return {"report": None}
return {"report": deserialize_detection_report(raw)}
def register():
    """Register the report-compilation output stage (no tunable config fields)."""
    report = StageDefinition(
        name="compile_report",
        label="Report",
        description="Merge detections and compile final report",
        category="output",
        io=StageIO(reads=["detections"], writes=["report"]),
        config_fields=[],
        serialize_fn=_ser_report,
        deserialize_fn=_deser_report,
    )
    register_stage(report)

View File

@@ -0,0 +1,57 @@
"""Registration for preprocessing stages: frame extraction, scene filter."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import serialize_frames, deserialize_frames
def _ser_extract(state: dict, job_id: str) -> dict:
    """Checkpoint extracted frames: images go to S3 via serialize_frames; metadata + manifest go into the JSON envelope."""
    frames = state.get("frames", [])
    meta, manifest = serialize_frames(frames, job_id)
    return {"frames_meta": meta, "frames_manifest": manifest}
def _deser_extract(data: dict, job_id: str) -> dict:
    """Reload frames (S3 images + metadata) from the stage checkpoint.

    Both keys are required here — a checkpoint for this stage without them
    is treated as corrupt rather than silently empty.
    """
    frames = deserialize_frames(data["frames_meta"], data["frames_manifest"], job_id)
    return {"frames": frames}
def _ser_filter(state: dict, job_id: str) -> dict:
filtered = state.get("filtered_frames", [])
seqs = [f.sequence for f in filtered]
return {"filtered_frame_sequences": seqs}
def _deser_filter(data: dict, job_id: str) -> dict:
return {"_filtered_sequences": data["filtered_frame_sequences"]}
def register():
    """Register the two preprocessing stages: frame extraction, then perceptual-hash scene filtering."""
    extract = StageDefinition(
        name="extract_frames",
        label="Frame Extraction",
        description="Extract frames from video at configurable FPS",
        category="preprocessing",
        io=StageIO(reads=["video_path"], writes=["frames"]),
        config_fields=[
            StageConfigField("fps", "float", 2.0, "Frames per second", min=0.1, max=30.0),
            StageConfigField("max_frames", "int", 500, "Maximum frames to extract", min=1, max=10000),
        ],
        serialize_fn=_ser_extract,
        deserialize_fn=_deser_extract,
    )
    register_stage(extract)
    # Scene filter reads "frames" from extraction and narrows them down to
    # "filtered_frames" — its checkpoint stores only sequence numbers.
    scene_filter = StageDefinition(
        name="filter_scenes",
        label="Scene Filter",
        description="Deduplicate similar frames using perceptual hashing",
        category="preprocessing",
        io=StageIO(reads=["frames"], writes=["filtered_frames"]),
        config_fields=[
            StageConfigField("hamming_threshold", "int", 8, "Hamming distance threshold", min=0, max=64),
            StageConfigField("enabled", "bool", True, "Enable scene filtering"),
        ],
        serialize_fn=_ser_filter,
        deserialize_fn=_deser_filter,
    )
    register_stage(scene_filter)

View File

@@ -0,0 +1,45 @@
"""Registration for resolution stages: brand resolver."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_brand_detection,
)
def _ser_brands(state: dict, job_id: str) -> dict:
    """Checkpoint brand-resolution output: matched detections plus unresolved candidates."""
    return {
        "detections": serialize_dataclass_list(state.get("detections", [])),
        "unresolved_candidates": serialize_text_candidates(state.get("unresolved_candidates", [])),
    }
def _deser_brands(data: dict, job_id: str) -> dict:
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
return {
"detections": detections,
"_unresolved_raw": data.get("unresolved_candidates", []),
}
def register():
    """Register the brand-resolver stage (fuzzy text matching against known brands)."""
    resolver = StageDefinition(
        name="match_brands",
        label="Brand Resolver",
        description="Match OCR text against known brands (session + global DB)",
        category="resolution",
        io=StageIO(
            reads=["text_candidates"],
            writes=["detections", "unresolved_candidates"],
            optional_reads=["session_brands", "source_asset_id"],
        ),
        config_fields=[
            StageConfigField("fuzzy_threshold", "int", 75, "Fuzzy match threshold", min=0, max=100),
        ],
        serialize_fn=_ser_brands,
        deserialize_fn=_deser_brands,
    )
    register_stage(resolver)

View File

@@ -1,26 +1,30 @@
"""Tests for checkpoint serialization — round-trip without S3."""
"""Tests for checkpoint serialization — unit tests without S3."""
import json
import numpy as np
import pytest
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
from detect.checkpoint.serializer import (
serialize_state,
deserialize_state,
from core.schema.serializers._common import safe_construct
from core.schema.serializers.detect_pipeline import (
serialize_frame_meta,
serialize_text_candidate,
serialize_text_candidates,
serialize_dataclass_list,
deserialize_text_candidate,
deserialize_text_candidates,
deserialize_bounding_box,
deserialize_brand_detection,
deserialize_pipeline_stats,
)
def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
image = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
return Frame(
sequence=seq,
chunk_id=0,
timestamp=float(seq) * 0.5,
image=image,
perceptual_hash=f"hash_{seq}",
sequence=seq, chunk_id=0, timestamp=float(seq) * 0.5,
image=image, perceptual_hash=f"hash_{seq}",
)
@@ -35,13 +39,8 @@ def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate:
def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection:
return BrandDetection(
brand=brand,
timestamp=timestamp,
duration=0.5,
confidence=0.92,
source="ocr",
content_type="soccer_broadcast",
frame_ref=0,
brand=brand, timestamp=timestamp, duration=0.5,
confidence=0.92, source="ocr", content_type="soccer_broadcast", frame_ref=0,
)
@@ -81,102 +80,84 @@ def test_deserialize_text_candidate():
assert restored.text == "ADIDAS"
assert restored.ocr_confidence == 0.85
assert restored.frame is frame # same object reference
assert restored.frame is frame
assert restored.bbox.x == 10
# --- BoundingBox ---
def test_bounding_box_round_trip():
    """A BoundingBox survives serialize -> deserialize with fields intact."""
    box = _make_box(x=15, y=25, w=40, h=30)
    serialized = serialize_dataclass_list([box])[0]
    restored = deserialize_bounding_box(serialized)
    assert restored.x == 15
    assert restored.w == 40
    assert restored.confidence == 0.9  # presumably _make_box's default confidence — TODO confirm


# --- BrandDetection ---
def test_brand_detection_round_trip():
    """A BrandDetection survives serialize -> deserialize with fields intact."""
    det = _make_detection("Emirates", 3.5)
    serialized = serialize_dataclass_list([det])[0]
    restored = deserialize_brand_detection(serialized)
    assert restored.brand == "Emirates"
    assert restored.timestamp == 3.5
    assert restored.source == "ocr"  # presumably _make_detection's default source — TODO confirm


# --- PipelineStats ---
def test_pipeline_stats_round_trip():
    """PipelineStats counters and cost survive serialize -> deserialize."""
    stats = PipelineStats(
        frames_extracted=120, cloud_llm_calls=3,
        estimated_cloud_cost_usd=0.005,
    )
    serialized = serialize_dataclass_list([stats])[0]
    restored = deserialize_pipeline_stats(serialized)
    assert restored.frames_extracted == 120
    assert restored.cloud_llm_calls == 3
    assert restored.estimated_cloud_cost_usd == 0.005


# --- safe_construct tolerates schema changes ---
def test_safe_construct_ignores_unknown_fields():
    """Fields not declared on the dataclass are dropped, not attached."""
    data = {"x": 1, "y": 2, "w": 3, "h": 4, "confidence": 0.5, "label": "t", "extra_field": 99}
    box = safe_construct(BoundingBox, data)
    assert box.x == 1
    assert not hasattr(box, "extra_field")


def test_safe_construct_uses_defaults():
    """Missing fields fall back to the dataclass defaults."""
    data = {"frames_extracted": 50}
    stats = safe_construct(PipelineStats, data)
    assert stats.frames_extracted == 50
    assert stats.cloud_llm_calls == 0  # default


# --- Full state round-trip ---
def test_state_round_trip():
    """A fully-populated pipeline state serializes to JSON-safe data and
    deserializes back, with TextCandidate frame refs re-linked to the
    supplied Frame objects."""
    frames = [_make_frame(seq=i) for i in range(3)]
    filtered = frames[:2]
    box = _make_box()
    boxes_by_frame = {0: [box], 1: [box]}
    candidates = [_make_candidate(frames[0], "NIKE"), _make_candidate(frames[1], "EMIRATES")]
    unresolved = [_make_candidate(frames[2], "unknown")]
    detections = [_make_detection("Nike", 0.5), _make_detection("Emirates", 1.0)]
    stats = PipelineStats(
        frames_extracted=3,
        frames_after_scene_filter=2,
        regions_detected=2,
        regions_resolved_by_ocr=2,
        cloud_llm_calls=1,
        estimated_cloud_cost_usd=0.003,
    )
    state = {
        "job_id": "test-123",
        "video_path": "/tmp/test.mp4",
        "profile_name": "soccer_broadcast",
        "config_overrides": {"ocr": {"min_confidence": 0.3}},
        "frames": frames,
        "filtered_frames": filtered,
        "boxes_by_frame": boxes_by_frame,
        "text_candidates": candidates,
        "unresolved_candidates": unresolved,
        "detections": detections,
        "stats": stats,
    }
    manifest = {f.sequence: f"s3://fake/frames/{f.sequence}.jpg" for f in frames}
    # Serialize
    serialized = serialize_state(state, manifest)
    # Verify JSON-compatible (no numpy, no Frame objects)
    import json
    json_str = json.dumps(serialized, default=str)
    assert len(json_str) > 0
    # Deserialize with the original frames (simulating frame load from S3)
    restored = deserialize_state(serialized, frames)
    # Verify round-trip
    assert restored["job_id"] == "test-123"
    assert restored["video_path"] == "/tmp/test.mp4"
    assert restored["profile_name"] == "soccer_broadcast"
    assert restored["config_overrides"] == {"ocr": {"min_confidence": 0.3}}
    assert len(restored["frames"]) == 3
    assert len(restored["filtered_frames"]) == 2
    assert len(restored["boxes_by_frame"]) == 2
    assert len(restored["text_candidates"]) == 2
    assert len(restored["unresolved_candidates"]) == 1
    assert len(restored["detections"]) == 2
    restored_stats = restored["stats"]
    assert restored_stats.frames_extracted == 3
    assert restored_stats.cloud_llm_calls == 1
    assert restored_stats.estimated_cloud_cost_usd == 0.003
    # TextCandidate frame references should point to actual Frame objects
    tc = restored["text_candidates"][0]
    assert tc.frame is frames[0]
    assert tc.text == "NIKE"


def test_state_round_trip_empty():
    """Empty state should serialize/deserialize cleanly."""
    state = {
        "job_id": "empty-job",
        "video_path": "",
        "profile_name": "soccer_broadcast",
        "frames": [],
        "filtered_frames": [],
        "boxes_by_frame": {},
        "text_candidates": [],
        "unresolved_candidates": [],
        "detections": [],
        "stats": PipelineStats(),
    }
    serialized = serialize_state(state, {})
    restored = deserialize_state(serialized, [])
    assert restored["job_id"] == "empty-job"
    assert len(restored["frames"]) == 0
    assert len(restored["detections"]) == 0
    assert restored["stats"].frames_extracted == 0


# --- JSON compatibility ---
def test_all_serialized_is_json_compatible():
    """Every serializer's output must survive a json dumps/loads round-trip."""
    frame = _make_frame()
    candidate = _make_candidate(frame)
    detection = _make_detection()
    stats = PipelineStats(frames_extracted=10)
    all_data = {
        "frame_meta": serialize_frame_meta(frame),
        "candidate": serialize_text_candidate(candidate),
        "detection": serialize_dataclass_list([detection])[0],
        "stats": serialize_dataclass_list([stats])[0],
    }
    import json
    json_str = json.dumps(all_data, default=str)
    roundtrip = json.loads(json_str)
    assert roundtrip["frame_meta"]["sequence"] == frame.sequence

# ---------------------------------------------------------------------------
# Tests for the stage registry (a separate file in the original commit;
# the "View File" / "@@ -0,0 +1,58 @@" diff residue was removed here).
# ---------------------------------------------------------------------------
"""Tests for the stage registry."""
from detect.stages import list_stages, get_stage, get_palette
# Every stage the pipeline registry is expected to expose, in pipeline order.
EXPECTED_STAGES = [
    "extract_frames",
    "filter_scenes",
    "detect_objects",
    "run_ocr",
    "match_brands",
    "escalate_vlm",
    "escalate_cloud",
    "compile_report",
]
def test_all_stages_registered():
    """Every expected stage name must be present in the registry."""
    registered = {stage.name for stage in list_stages()}
    for expected in EXPECTED_STAGES:
        assert expected in registered, f"Missing stage: {expected}"
def test_stage_has_io():
    """Each stage declares at least one output plus label/description text."""
    lookup = {name: get_stage(name) for name in EXPECTED_STAGES}
    for name, stage in lookup.items():
        assert len(stage.io.writes) > 0, f"{name} has no writes"
        assert stage.label, f"{name} has no label"
        assert stage.description, f"{name} has no description"
def test_stage_has_serialization():
    """Each stage must provide both serialize and deserialize hooks."""
    lookup = {name: get_stage(name) for name in EXPECTED_STAGES}
    for name, stage in lookup.items():
        assert stage.serialize_fn is not None, f"{name} has no serialize_fn"
        assert stage.deserialize_fn is not None, f"{name} has no deserialize_fn"
def test_palette_groups():
    """The palette groups stages under string categories and, flattened,
    covers exactly the expected stage count."""
    palette = get_palette()
    assert len(palette) > 0
    for category in palette:
        assert isinstance(category, str)
    flattened = [stage for group in palette.values() for stage in group]
    assert len(flattened) == len(EXPECTED_STAGES)
def test_io_chain_valid():
    """Each stage's reads should be satisfied by previous stages' writes."""
    # Keys present in the state before the first stage runs.
    available_keys = {"video_path", "job_id", "profile_name",
                      "source_asset_id", "session_brands", "stats",
                      "config_overrides"}
    # Walk the stages in pipeline order, accumulating what has been written.
    for name in EXPECTED_STAGES:
        stage = get_stage(name)
        for read_key in stage.io.reads:
            assert read_key in available_keys, (
                f"Stage {stage.name} reads '{read_key}' but no previous stage writes it. "
                f"Available: {available_keys}"
            )
        available_keys.update(stage.io.writes)