Schema cleanup and refactor

This commit is contained in:
2026-03-26 05:14:33 -03:00
parent 08c58a6a9d
commit d58a90157a
17 changed files with 930 additions and 287 deletions

View File

@@ -1,133 +1,108 @@
"""State serialization — DetectState ↔ JSON-compatible dict."""
"""
State serialization — DetectState ↔ JSON-compatible dict.
Delegates to each stage's serialize_fn/deserialize_fn via the registry.
This file has no model-specific knowledge — stages own their data format.
The only things serialized here are the "envelope" fields (job_id, video_path, etc.)
that don't belong to any stage.
"""
from __future__ import annotations
import dataclasses
from detect.models import (
BoundingBox,
BrandDetection,
Frame,
PipelineStats,
TextCandidate,
from core.schema.serializers._common import serialize_dataclass
from core.schema.serializers.detect_pipeline import (
deserialize_pipeline_stats,
deserialize_text_candidates,
)
# ---------------------------------------------------------------------------
# Serialize helpers
# ---------------------------------------------------------------------------
def serialize_frame_meta(frame: Frame) -> dict:
    """Return the JSON-safe metadata for one frame (the image itself is excluded)."""
    return {
        "sequence": frame.sequence,
        "chunk_id": frame.chunk_id,
        "timestamp": frame.timestamp,
        "perceptual_hash": frame.perceptual_hash,
    }
def serialize_text_candidate(tc: TextCandidate) -> dict:
    """Serialize an OCR candidate; the Frame reference collapses to its sequence number."""
    return {
        "frame_sequence": tc.frame.sequence,
        "bbox": dataclasses.asdict(tc.bbox),
        "text": tc.text,
        "ocr_confidence": tc.ocr_confidence,
    }
# Envelope fields — not owned by any stage, always present
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "config_overrides"]
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
"""
Serialize DetectState to a JSON-compatible dict.
Frame images are replaced with S3 key references.
TextCandidate.frame references become frame_sequence integers.
Calls each registered stage's serialize_fn for stage-owned data.
Envelope fields (job_id, etc.) are copied directly.
"""
frames = state.get("frames", [])
filtered = state.get("filtered_frames", [])
from detect.stages.base import _REGISTRY
manifest_strs = {str(k): v for k, v in frames_manifest.items()}
frames_meta = [serialize_frame_meta(f) for f in frames]
filtered_seqs = [f.sequence for f in filtered]
checkpoint = {}
boxes_serialized = {}
for seq, boxes in state.get("boxes_by_frame", {}).items():
boxes_serialized[str(seq)] = [dataclasses.asdict(b) for b in boxes]
# Envelope
for key in ENVELOPE_KEYS:
default = {} if key == "config_overrides" else ""
checkpoint[key] = state.get(key, default)
text_candidates = [serialize_text_candidate(tc) for tc in state.get("text_candidates", [])]
unresolved = [serialize_text_candidate(tc) for tc in state.get("unresolved_candidates", [])]
detections = [dataclasses.asdict(d) for d in state.get("detections", [])]
stats = dataclasses.asdict(state.get("stats", PipelineStats()))
# Frames manifest (needed by frame-loading stages)
checkpoint["frames_manifest"] = {str(k): v for k, v in frames_manifest.items()}
# Stats (shared across stages, not owned by one)
stats = state.get("stats")
if stats is not None:
checkpoint["stats"] = serialize_dataclass(stats)
else:
checkpoint["stats"] = {}
# Per-stage data
for name, stage_def in _REGISTRY.items():
if stage_def.serialize_fn is None:
continue
job_id = state.get("job_id", "")
stage_data = stage_def.serialize_fn(state, job_id)
checkpoint[f"stage_{name}"] = stage_data
checkpoint = {
"job_id": state.get("job_id", ""),
"video_path": state.get("video_path", ""),
"profile_name": state.get("profile_name", ""),
"config_overrides": state.get("config_overrides", {}),
"frames_manifest": manifest_strs,
"frames_meta": frames_meta,
"filtered_frame_sequences": filtered_seqs,
"boxes_by_frame": boxes_serialized,
"text_candidates": text_candidates,
"unresolved_candidates": unresolved,
"detections": detections,
"stats": stats,
}
return checkpoint
# ---------------------------------------------------------------------------
# Deserialize helpers
# ---------------------------------------------------------------------------
def deserialize_state(checkpoint: dict, frames: list) -> dict:
"""
Reconstitute DetectState from a checkpoint dict + loaded frames.
def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate:
    """Rebuild a TextCandidate, reattaching the live Frame via its sequence number."""
    return TextCandidate(
        frame=frame_map[d["frame_sequence"]],
        bbox=BoundingBox(**d["bbox"]),
        text=d["text"],
        ocr_confidence=d["ocr_confidence"],
    )
Calls each stage's deserialize_fn to restore stage-owned data.
"""
from detect.stages.base import _REGISTRY
def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict:
"""Reconstitute DetectState from a checkpoint dict + loaded frames."""
frame_map = {f.sequence: f for f in frames}
filtered_seqs = set(checkpoint.get("filtered_frame_sequences", []))
filtered_frames = [f for f in frames if f.sequence in filtered_seqs]
state = {}
boxes_by_frame = {}
for seq_str, box_dicts in checkpoint.get("boxes_by_frame", {}).items():
seq = int(seq_str)
boxes_by_frame[seq] = [BoundingBox(**b) for b in box_dicts]
# Envelope
for key in ENVELOPE_KEYS:
default = {} if key == "config_overrides" else ""
state[key] = checkpoint.get(key, default)
text_candidates = [
deserialize_text_candidate(d, frame_map)
for d in checkpoint.get("text_candidates", [])
]
unresolved_candidates = [
deserialize_text_candidate(d, frame_map)
for d in checkpoint.get("unresolved_candidates", [])
]
detections = [BrandDetection(**d) for d in checkpoint.get("detections", [])]
stats = PipelineStats(**checkpoint.get("stats", {}))
# Frames (always present, loaded externally)
state["frames"] = frames
# Stats
state["stats"] = deserialize_pipeline_stats(checkpoint.get("stats", {}))
# Per-stage data
for name, stage_def in _REGISTRY.items():
if stage_def.deserialize_fn is None:
continue
stage_key = f"stage_{name}"
if stage_key not in checkpoint:
continue
job_id = state.get("job_id", "")
stage_data = stage_def.deserialize_fn(checkpoint[stage_key], job_id)
for k, v in stage_data.items():
if k == "_filtered_sequences":
# Reconnect filtered frames from sequence list
seq_set = set(v)
state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
elif k.endswith("_raw"):
# Raw text candidates need frame reference reconnection
real_key = k.removeprefix("_").removesuffix("_raw")
state[real_key] = deserialize_text_candidates(v, frame_map)
else:
state[k] = v
state = {
"job_id": checkpoint.get("job_id", ""),
"video_path": checkpoint.get("video_path", ""),
"profile_name": checkpoint.get("profile_name", ""),
"config_overrides": checkpoint.get("config_overrides", {}),
"frames": frames,
"filtered_frames": filtered_frames,
"boxes_by_frame": boxes_by_frame,
"text_candidates": text_candidates,
"unresolved_candidates": unresolved_candidates,
"detections": detections,
"stats": stats,
}
return state

View File

@@ -1,86 +1,26 @@
"""
Core domain models for the detection pipeline.
Re-export pipeline runtime models from core/schema/models/detect_pipeline.py.
These are pipeline-internal models — the data structures that flow
between LangGraph nodes. SSE event payloads (sse_contract.py) are
derived from these when emitting to the UI.
All models are defined in core/schema/ — this module exists for backward
compatibility so existing imports (from detect.models import Frame) keep working.
"""
from __future__ import annotations
from core.schema.models.detect_pipeline import (
BoundingBox,
BrandDetection,
BrandStats,
DetectionReport,
Frame,
PipelineStats,
TextCandidate,
)
from dataclasses import dataclass, field
from typing import Literal
import numpy as np
@dataclass
class Frame:
    """A single extracted video frame plus its position metadata."""
    sequence: int  # unique frame index; used as the key when reconnecting serialized references
    chunk_id: int  # id of the source chunk this frame came from (per the name — TODO confirm semantics)
    timestamp: float  # position in video (seconds)
    image: np.ndarray  # decoded pixels; shape/dtype set by the extractor, not specified here
    perceptual_hash: str = ""  # filled by the scene-filter stage for dedup; "" until computed
@dataclass
class BoundingBox:
    """Axis-aligned region within a frame (presumably pixel units — confirm against detector)."""
    x: int
    y: int
    w: int
    h: int
    confidence: float  # detector confidence for this box
    label: str  # detector class label
@dataclass
class TextCandidate:
    """Text found by OCR inside a detected region, pending brand resolution."""
    frame: Frame  # source frame; checkpoints serialize this as frame.sequence
    bbox: BoundingBox  # region the text was read from
    text: str  # raw OCR string
    ocr_confidence: float
@dataclass
class BrandDetection:
    """One resolved brand appearance on the video timeline."""
    brand: str
    timestamp: float  # seconds into the video
    duration: float  # seconds — NOTE(review): how duration is computed isn't visible here
    confidence: float
    source: Literal["ocr", "local_vlm", "cloud_llm", "logo_match", "auxiliary"]  # which stage produced it
    bbox: BoundingBox | None = None  # region, when a single box applies
    frame_ref: int | None = None  # frame sequence number, when tied to one frame
    content_type: str = ""
@dataclass
class BrandStats:
    """Aggregate statistics for one brand across the whole video."""
    total_appearances: int = 0
    total_screen_time: float = 0.0  # seconds
    avg_confidence: float = 0.0
    first_seen: float = 0.0  # timestamp (seconds) of first appearance
    last_seen: float = 0.0  # timestamp (seconds) of last appearance
@dataclass
class PipelineStats:
    """Counters describing one pipeline run, for reporting and cost tracking."""
    frames_extracted: int = 0
    frames_after_scene_filter: int = 0
    regions_detected: int = 0
    regions_resolved_by_ocr: int = 0
    regions_escalated_to_local_vlm: int = 0
    regions_escalated_to_cloud_llm: int = 0
    auxiliary_detections: int = 0
    cloud_llm_calls: int = 0
    processing_time_seconds: float = 0.0
    estimated_cloud_cost_usd: float = 0.0  # rough estimate; pricing source not visible here
@dataclass
class DetectionReport:
    """Final pipeline output: per-brand stats, detection timeline, and run statistics."""
    video_source: str
    content_type: str
    duration_seconds: float
    brands: dict[str, BrandStats] = field(default_factory=dict)  # keyed by brand name
    timeline: list[BrandDetection] = field(default_factory=list)  # individual appearances
    pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
__all__ = [
"BoundingBox",
"BrandDetection",
"BrandStats",
"DetectionReport",
"Frame",
"PipelineStats",
"TextCandidate",
]

View File

@@ -0,0 +1,21 @@
"""
Pipeline stages.
Each stage registers its StageDefinition on import,
declaring IO (what it reads/writes from state),
config fields (what's tunable from the editor),
and serialization (how to checkpoint its outputs).
"""
from .base import (
StageDefinition,
StageIO,
StageConfigField,
register_stage,
get_stage,
list_stages,
get_palette,
)
# Populate registry with built-in stages
from . import registry # noqa: F401

101
detect/stages/base.py Normal file
View File

@@ -0,0 +1,101 @@
"""
Stage protocol — common interface for all pipeline stages.
Every stage declares:
- IO: what it reads/writes from DetectState
- Config: tunable parameters for the editor
- Serialization: how to persist/restore its own outputs
The checkpoint layer is a black box — it asks each stage to serialize its
outputs and stores the result. Stages own their data format. Binary data
(frames, crops) goes to S3 via the stage itself. The checkpoint just
stores the JSON envelope.
The graph builder uses StageIO to validate that a stage's inputs are
satisfied by previous stages' outputs.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable
@dataclass
class StageIO:
    """Declares what a stage reads and writes from/to DetectState."""
    reads: list[str]  # state keys the stage requires; graph builder checks they're produced upstream
    writes: list[str]  # state keys the stage produces
    optional_reads: list[str] = field(default_factory=list)  # used when present, not required (per the name)
@dataclass
class StageConfigField:
    """A single tunable config parameter for the editor UI."""
    name: str  # parameter key
    type: str  # "float", "int", "str", "bool", "list[str]"
    default: Any  # default value; should match `type`
    description: str = ""  # human-readable help text
    min: float | None = None  # numeric lower bound (None = unbounded)
    max: float | None = None  # numeric upper bound (None = unbounded)
    options: list[str] | None = None  # fixed choices for select-style fields, if any
@dataclass
class StageDefinition:
    """
    Complete metadata for a pipeline stage.
    The profile editor uses this to build the palette, generate config
    forms, and validate graph connections. The checkpoint uses serialize_fn
    and deserialize_fn to persist stage outputs without knowing the internals.
    """
    name: str  # stable identifier; registry key
    label: str  # display name for the editor UI
    description: str
    io: StageIO  # reads/writes contract used for graph validation
    config_fields: list[StageConfigField] = field(default_factory=list)
    category: str = "detection"  # palette grouping (see get_palette)
    # The actual graph node function: (DetectState) → dict
    fn: Callable | None = None
    # Stage-owned serialization for checkpointing.
    # serialize_fn: (state: dict, job_id: str) → json-compatible dict
    #   Stage picks its writes from state, serializes them.
    #   Binary data (frames) → S3 via stage, returns refs.
    # deserialize_fn: (data: dict, job_id: str) → state update dict
    #   Stage restores its writes from the persisted data.
    serialize_fn: Callable | None = None
    deserialize_fn: Callable | None = None
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
_REGISTRY: dict[str, StageDefinition] = {}


def register_stage(definition: StageDefinition):
    """Add a stage to the global registry, keyed by its name.

    NOTE(review): a duplicate name silently overwrites the earlier
    registration — presumably intended (re-registration); confirm.
    """
    _REGISTRY[definition.name] = definition


def get_stage(name: str) -> StageDefinition:
    """Return the stage registered under *name*.

    Raises:
        KeyError: if no such stage exists; the message lists known names.
    """
    if name not in _REGISTRY:
        raise KeyError(f"Unknown stage: {name!r}. Registered: {list(_REGISTRY)}")
    return _REGISTRY[name]


def list_stages() -> list[StageDefinition]:
    """Return all registered stages in registration order."""
    return list(_REGISTRY.values())


def get_palette() -> dict[str, list[StageDefinition]]:
    """Group stages by category for the editor palette."""
    palette: dict[str, list[StageDefinition]] = {}
    for stage in _REGISTRY.values():
        # setdefault replaces the manual "if category not in palette" dance
        palette.setdefault(stage.category, []).append(stage)
    return palette

View File

@@ -0,0 +1,28 @@
"""
Stage registry — registers all built-in stages.
Split by category:
preprocessing.py — extract_frames, filter_scenes
detection.py — detect_objects, run_ocr
resolution.py — match_brands
escalation.py — escalate_vlm, escalate_cloud
output.py — compile_report
_serializers.py — shared serialization helpers
"""
from . import preprocessing
from . import detection
from . import resolution
from . import escalation
from . import output
def register_all():
    """Register every built-in stage module's stages, in pipeline order."""
    for module in (preprocessing, detection, resolution, escalation, output):
        module.register()


register_all()

View File

@@ -0,0 +1,25 @@
"""
Re-export serializers from core/schema/serializers/.
Stage registry modules import from here for convenience.
All serialization logic lives in core/schema/serializers/.
"""
from core.schema.serializers._common import (
safe_construct,
serialize_dataclass,
serialize_dataclass_list,
)
from core.schema.serializers.detect_pipeline import (
serialize_frame_meta,
serialize_frames_with_upload as serialize_frames,
deserialize_frames_with_download as deserialize_frames,
serialize_text_candidate,
serialize_text_candidates,
deserialize_text_candidate,
deserialize_text_candidates,
deserialize_bounding_box,
deserialize_brand_detection,
deserialize_pipeline_stats,
deserialize_detection_report,
)

View File

@@ -0,0 +1,63 @@
"""Registration for detection stages: YOLO, OCR."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_bounding_box,
)
def _ser_detect(state: dict, job_id: str) -> dict:
boxes = state.get("boxes_by_frame", {})
serialized = {str(seq): serialize_dataclass_list(bl) for seq, bl in boxes.items()}
return {"boxes_by_frame": serialized}
def _deser_detect(data: dict, job_id: str) -> dict:
boxes = {}
for seq_str, box_dicts in data.get("boxes_by_frame", {}).items():
boxes[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
return {"boxes_by_frame": boxes}
def _ser_ocr(state: dict, job_id: str) -> dict:
    """Checkpoint OCR output as JSON-safe candidate dicts."""
    found = state.get("text_candidates", [])
    return {"text_candidates": serialize_text_candidates(found)}
def _deser_ocr(data: dict, job_id: str) -> dict:
return {"_text_candidates_raw": data["text_candidates"]}
def register():
    """Register the built-in detection-category stages: YOLO then OCR."""
    register_stage(
        StageDefinition(
            name="detect_objects",
            label="Object Detection",
            description="YOLO object detection on filtered frames",
            category="detection",
            io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
            config_fields=[
                StageConfigField("model_name", "str", "yolov8n.pt", "YOLO model file"),
                StageConfigField("confidence_threshold", "float", 0.3, "Min detection confidence", min=0.0, max=1.0),
                StageConfigField("target_classes", "list[str]", [], "YOLO classes to detect (empty = all)"),
            ],
            serialize_fn=_ser_detect,
            deserialize_fn=_deser_detect,
        )
    )
    register_stage(
        StageDefinition(
            name="run_ocr",
            label="OCR",
            description="Extract text from detected regions",
            category="detection",
            io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
            config_fields=[
                StageConfigField("languages", "list[str]", ["en"], "OCR languages"),
                StageConfigField("min_confidence", "float", 0.5, "Min OCR confidence", min=0.0, max=1.0),
            ],
            serialize_fn=_ser_ocr,
            deserialize_fn=_deser_ocr,
        )
    )

View File

@@ -0,0 +1,63 @@
"""Registration for escalation stages: local VLM, cloud LLM."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_brand_detection,
)
def _ser_escalation(state: dict, job_id: str) -> dict:
    """Checkpoint escalation output: resolved detections plus still-unresolved candidates."""
    return {
        "detections": serialize_dataclass_list(state.get("detections", [])),
        "unresolved_candidates": serialize_text_candidates(state.get("unresolved_candidates", [])),
    }
def _deser_escalation(data: dict, job_id: str) -> dict:
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
return {
"detections": detections,
"_unresolved_raw": data.get("unresolved_candidates", []),
}
def register():
    """Register the built-in escalation-category stages: local VLM then cloud LLM."""
    register_stage(
        StageDefinition(
            name="escalate_vlm",
            label="Local VLM",
            description="Process unresolved crops with moondream2",
            category="escalation",
            io=StageIO(
                reads=["unresolved_candidates"],
                writes=["detections", "unresolved_candidates"],
                optional_reads=["source_asset_id"],
            ),
            config_fields=[
                StageConfigField("min_confidence", "float", 0.5, "Min VLM confidence", min=0.0, max=1.0),
            ],
            serialize_fn=_ser_escalation,
            deserialize_fn=_deser_escalation,
        )
    )
    register_stage(
        StageDefinition(
            name="escalate_cloud",
            label="Cloud LLM",
            description="Escalate remaining crops to cloud provider",
            category="escalation",
            io=StageIO(
                reads=["unresolved_candidates"],
                writes=["detections"],
                optional_reads=["source_asset_id"],
            ),
            config_fields=[
                StageConfigField("min_confidence", "float", 0.4, "Min cloud confidence", min=0.0, max=1.0),
            ],
            serialize_fn=_ser_escalation,
            deserialize_fn=_deser_escalation,
        )
    )

View File

@@ -0,0 +1,32 @@
"""Registration for output stages: report compilation."""
from detect.stages.base import StageDefinition, StageIO, register_stage
from ._serializers import serialize_dataclass, deserialize_detection_report
def _ser_report(state: dict, job_id: str) -> dict:
report = state.get("report")
if report is None:
return {"report": None}
return {"report": serialize_dataclass(report)}
def _deser_report(data: dict, job_id: str) -> dict:
raw = data.get("report")
if raw is None:
return {"report": None}
return {"report": deserialize_detection_report(raw)}
def register():
    """Register the built-in output-category stage: report compilation."""
    register_stage(
        StageDefinition(
            name="compile_report",
            label="Report",
            description="Merge detections and compile final report",
            category="output",
            io=StageIO(reads=["detections"], writes=["report"]),
            config_fields=[],
            serialize_fn=_ser_report,
            deserialize_fn=_deser_report,
        )
    )

View File

@@ -0,0 +1,57 @@
"""Registration for preprocessing stages: frame extraction, scene filter."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import serialize_frames, deserialize_frames
def _ser_extract(state: dict, job_id: str) -> dict:
    """Checkpoint extracted frames: the serializer uploads images and returns meta + manifest."""
    meta, manifest = serialize_frames(state.get("frames", []), job_id)
    return {"frames_meta": meta, "frames_manifest": manifest}
def _deser_extract(data: dict, job_id: str) -> dict:
    """Restore frames by downloading the images referenced in the manifest."""
    restored = deserialize_frames(
        data["frames_meta"], data["frames_manifest"], job_id
    )
    return {"frames": restored}
def _ser_filter(state: dict, job_id: str) -> dict:
filtered = state.get("filtered_frames", [])
seqs = [f.sequence for f in filtered]
return {"filtered_frame_sequences": seqs}
def _deser_filter(data: dict, job_id: str) -> dict:
return {"_filtered_sequences": data["filtered_frame_sequences"]}
def register():
    """Register the built-in preprocessing stages: frame extraction then scene filter."""
    register_stage(
        StageDefinition(
            name="extract_frames",
            label="Frame Extraction",
            description="Extract frames from video at configurable FPS",
            category="preprocessing",
            io=StageIO(reads=["video_path"], writes=["frames"]),
            config_fields=[
                StageConfigField("fps", "float", 2.0, "Frames per second", min=0.1, max=30.0),
                StageConfigField("max_frames", "int", 500, "Maximum frames to extract", min=1, max=10000),
            ],
            serialize_fn=_ser_extract,
            deserialize_fn=_deser_extract,
        )
    )
    register_stage(
        StageDefinition(
            name="filter_scenes",
            label="Scene Filter",
            description="Deduplicate similar frames using perceptual hashing",
            category="preprocessing",
            io=StageIO(reads=["frames"], writes=["filtered_frames"]),
            config_fields=[
                StageConfigField("hamming_threshold", "int", 8, "Hamming distance threshold", min=0, max=64),
                StageConfigField("enabled", "bool", True, "Enable scene filtering"),
            ],
            serialize_fn=_ser_filter,
            deserialize_fn=_deser_filter,
        )
    )

View File

@@ -0,0 +1,45 @@
"""Registration for resolution stages: brand resolver."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_brand_detection,
)
def _ser_brands(state: dict, job_id: str) -> dict:
    """Checkpoint resolver output: matched detections plus still-unresolved candidates."""
    return {
        "detections": serialize_dataclass_list(state.get("detections", [])),
        "unresolved_candidates": serialize_text_candidates(state.get("unresolved_candidates", [])),
    }
def _deser_brands(data: dict, job_id: str) -> dict:
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
return {
"detections": detections,
"_unresolved_raw": data.get("unresolved_candidates", []),
}
def register():
    """Register the built-in resolution-category stage: the brand resolver."""
    register_stage(
        StageDefinition(
            name="match_brands",
            label="Brand Resolver",
            description="Match OCR text against known brands (session + global DB)",
            category="resolution",
            io=StageIO(
                reads=["text_candidates"],
                writes=["detections", "unresolved_candidates"],
                optional_reads=["session_brands", "source_asset_id"],
            ),
            config_fields=[
                StageConfigField("fuzzy_threshold", "int", 75, "Fuzzy match threshold", min=0, max=100),
            ],
            serialize_fn=_ser_brands,
            deserialize_fn=_deser_brands,
        )
    )