schema clean up and refactor
This commit is contained in:
97
core/schema/models/detect_pipeline.py
Normal file
97
core/schema/models/detect_pipeline.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Detection pipeline runtime models.
|
||||
|
||||
These are the data structures that flow between LangGraph nodes.
|
||||
They contain runtime types (np.ndarray) so they are NOT generated
|
||||
by modelgen — they live here for the schema to be the complete
|
||||
map of the application, but modelgen skips them.
|
||||
|
||||
Wire-format models (SSE events) are in detect.py.
|
||||
DB models (jobs, checkpoints) are in detect_jobs.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
class Frame:
    """A single extracted video frame flowing between pipeline nodes."""
    sequence: int      # frame index within the job; used as the serialized reference
    chunk_id: int      # id of the video chunk this frame was extracted from
    timestamp: float   # position in video (seconds)
    image: np.ndarray  # decoded pixel data; runtime-only, excluded from JSON (stored to S3)
    perceptual_hash: str = ""  # optional; "" until a hash has been computed
|
||||
|
||||
|
||||
@dataclass
class BoundingBox:
    """Axis-aligned detection region (x, y, w, h — presumably pixels; confirm with detector)."""
    x: int
    y: int
    w: int
    h: int
    confidence: float  # detector confidence for this region
    label: str         # detector class label
|
||||
|
||||
|
||||
@dataclass
class TextCandidate:
    """Text read by OCR from one bounding box of one frame."""
    frame: Frame        # runtime object reference; serialized as frame.sequence
    bbox: BoundingBox   # region the text was read from
    text: str
    ocr_confidence: float
|
||||
|
||||
|
||||
@dataclass
class BrandDetection:
    """A single brand sighting attributed to a point in the video timeline."""
    brand: str
    timestamp: float   # seconds into the video
    duration: float    # presumably seconds the sighting spans — TODO confirm
    confidence: float
    source: Literal["ocr", "local_vlm", "cloud_llm", "logo_match", "auxiliary"]  # which stage produced it
    bbox: BoundingBox | None = None  # region, when the source localizes the brand
    frame_ref: int | None = None     # presumably a Frame.sequence reference — TODO confirm
    content_type: str = ""
|
||||
|
||||
|
||||
@dataclass
class BrandStats:
    """Aggregated per-brand statistics for the final report."""
    total_appearances: int = 0
    total_screen_time: float = 0.0  # presumably seconds — TODO confirm
    avg_confidence: float = 0.0
    first_seen: float = 0.0  # timestamp of first sighting
    last_seen: float = 0.0   # timestamp of last sighting
|
||||
|
||||
|
||||
@dataclass
class PipelineStats:
    """Counters and costs accumulated across a pipeline run."""
    frames_extracted: int = 0
    frames_after_scene_filter: int = 0
    regions_detected: int = 0
    regions_resolved_by_ocr: int = 0
    regions_escalated_to_local_vlm: int = 0
    regions_escalated_to_cloud_llm: int = 0
    auxiliary_detections: int = 0
    cloud_llm_calls: int = 0
    processing_time_seconds: float = 0.0
    estimated_cloud_cost_usd: float = 0.0
|
||||
|
||||
|
||||
@dataclass
class DetectionReport:
    """Final output of the pipeline: per-brand stats, timeline, run stats."""
    video_source: str
    content_type: str
    duration_seconds: float
    brands: dict[str, BrandStats] = field(default_factory=dict)       # keyed by brand name
    timeline: list[BrandDetection] = field(default_factory=list)      # individual sightings
    pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
|
||||
|
||||
|
||||
# Not in DATACLASSES — modelgen skips these (they contain np.ndarray)
# Listed so the schema package remains the complete map of the application.
RUNTIME_MODELS = [
    Frame, BoundingBox, TextCandidate, BrandDetection,
    BrandStats, PipelineStats, DetectionReport,
]
|
||||
11
core/schema/serializers/__init__.py
Normal file
11
core/schema/serializers/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Model serializers — one module per model group, mirroring core/schema/models/.
|
||||
|
||||
models/detect_pipeline.py → serializers/detect_pipeline.py
|
||||
models/detect_jobs.py → serializers/detect_jobs.py
|
||||
models/detect.py → serializers/detect.py (SSE events)
|
||||
|
||||
Common utilities in _common.py.
|
||||
"""
|
||||
|
||||
from ._common import safe_construct, serialize_dataclass, serialize_dataclass_list
|
||||
38
core/schema/serializers/_common.py
Normal file
38
core/schema/serializers/_common.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Common serialization utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def safe_construct(cls, data: dict):
    """
    Build *cls* from *data*, tolerating schema drift.

    Keys not declared on the dataclass are dropped (field was removed);
    keys absent from *data* fall back to the field defaults (field was
    added). Dropped keys are logged at debug level.
    """
    declared = {f.name for f in dataclasses.fields(cls)}

    accepted = {}
    for key, value in data.items():
        if key not in declared:
            logger.debug("Ignoring unknown field %s.%s", cls.__name__, key)
            continue
        accepted[key] = value

    return cls(**accepted)
|
||||
|
||||
|
||||
def serialize_dataclass(obj) -> dict:
    """Convert a dataclass instance to a plain dict (deep conversion via asdict)."""
    as_plain_dict = dataclasses.asdict(obj)
    return as_plain_dict
|
||||
|
||||
|
||||
def serialize_dataclass_list(items) -> list[dict]:
    """Convert a sequence of dataclass instances to a list of plain dicts."""
    return list(map(dataclasses.asdict, items))
|
||||
108
core/schema/serializers/detect_pipeline.py
Normal file
108
core/schema/serializers/detect_pipeline.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Serializers for detection pipeline runtime models.
|
||||
|
||||
Mirrors core/schema/models/detect_pipeline.py.
|
||||
|
||||
Special handling:
|
||||
- Frame.image (np.ndarray → S3, excluded from JSON)
|
||||
- TextCandidate.frame (object ref → frame_sequence integer)
|
||||
Everything else uses dataclasses.asdict() via safe_construct.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
|
||||
from core.schema.models.detect_pipeline import (
|
||||
BoundingBox,
|
||||
BrandDetection,
|
||||
BrandStats,
|
||||
DetectionReport,
|
||||
Frame,
|
||||
PipelineStats,
|
||||
TextCandidate,
|
||||
)
|
||||
from ._common import safe_construct, serialize_dataclass, serialize_dataclass_list
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frame — image goes to S3 separately
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def serialize_frame_meta(frame: Frame) -> dict:
    """
    Serialize Frame metadata only — the ``image`` array is excluded.

    Builds the dict field-by-field instead of ``dataclasses.asdict`` so the
    (potentially large) ndarray is never deep-copied just to be deleted.
    The remaining Frame fields are scalars, so no deep conversion is needed.
    """
    return {
        f.name: getattr(frame, f.name)
        for f in dataclasses.fields(frame)
        if f.name != "image"
    }
|
||||
|
||||
|
||||
def serialize_frames_with_upload(frames: list[Frame], job_id: str) -> tuple[list[dict], dict[int, str]]:
    """Upload frame images to S3; return (per-frame metadata, sequence→key manifest)."""
    # Imported lazily to keep the serializer module free of storage deps.
    from detect.checkpoint.frames import save_frames

    uploaded_manifest = save_frames(job_id, frames)
    metadata = [serialize_frame_meta(frame) for frame in frames]
    return metadata, uploaded_manifest
|
||||
|
||||
|
||||
def deserialize_frames_with_download(meta: list[dict], manifest: dict, job_id: str) -> list[Frame]:
    """Download frame images from S3 and rebuild Frame objects from metadata."""
    # Imported lazily to keep the serializer module free of storage deps.
    from detect.checkpoint.frames import load_frames

    # JSON round-trips turn the manifest keys into strings; restore int keys.
    keyed_by_sequence = {int(sequence): key for sequence, key in manifest.items()}
    return load_frames(keyed_by_sequence, meta)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TextCandidate — frame ref is an object, stored as sequence int
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def serialize_text_candidate(tc: TextCandidate) -> dict:
    """Serialize one candidate; the Frame reference collapses to its sequence number."""
    return {
        "frame_sequence": tc.frame.sequence,
        "bbox": dataclasses.asdict(tc.bbox),
        "text": tc.text,
        "ocr_confidence": tc.ocr_confidence,
    }
|
||||
|
||||
|
||||
def serialize_text_candidates(candidates: list[TextCandidate]) -> list[dict]:
    """Serialize every candidate, preserving order."""
    return list(map(serialize_text_candidate, candidates))
|
||||
|
||||
|
||||
def deserialize_text_candidate(data: dict, frame_map: dict[int, Frame]) -> TextCandidate:
    """Rebuild a TextCandidate, reconnecting its Frame via *frame_map* (keyed by sequence)."""
    return TextCandidate(
        frame=frame_map[data["frame_sequence"]],
        bbox=safe_construct(BoundingBox, data["bbox"]),
        text=data["text"],
        ocr_confidence=data["ocr_confidence"],
    )
|
||||
|
||||
|
||||
def deserialize_text_candidates(data: list[dict], frame_map: dict[int, Frame]) -> list[TextCandidate]:
    """Rebuild a list of TextCandidates, reconnecting each to its Frame."""
    return [deserialize_text_candidate(entry, frame_map) for entry in data]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BoundingBox, BrandDetection, PipelineStats, etc — standard dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def deserialize_bounding_box(data: dict) -> BoundingBox:
    """Rebuild a BoundingBox from a dict, tolerating added/removed fields."""
    return safe_construct(BoundingBox, data)
|
||||
|
||||
|
||||
def deserialize_brand_detection(data: dict) -> BrandDetection:
    """Rebuild a BrandDetection from a dict, tolerating added/removed fields."""
    return safe_construct(BrandDetection, data)
|
||||
|
||||
|
||||
def deserialize_pipeline_stats(data: dict) -> PipelineStats:
    """Rebuild PipelineStats from a dict, tolerating added/removed fields."""
    return safe_construct(PipelineStats, data)
|
||||
|
||||
|
||||
def deserialize_detection_report(data: dict) -> DetectionReport:
    """Rebuild a DetectionReport from a dict, tolerating added/removed fields."""
    return safe_construct(DetectionReport, data)
|
||||
@@ -1,133 +1,108 @@
|
||||
"""State serialization — DetectState ↔ JSON-compatible dict."""
|
||||
"""
|
||||
State serialization — DetectState ↔ JSON-compatible dict.
|
||||
|
||||
Delegates to each stage's serialize_fn/deserialize_fn via the registry.
|
||||
This file has no model-specific knowledge — stages own their data format.
|
||||
|
||||
The only things serialized here are the "envelope" fields (job_id, video_path, etc.)
|
||||
that don't belong to any stage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
|
||||
from detect.models import (
|
||||
BoundingBox,
|
||||
BrandDetection,
|
||||
Frame,
|
||||
PipelineStats,
|
||||
TextCandidate,
|
||||
from core.schema.serializers._common import serialize_dataclass
|
||||
from core.schema.serializers.detect_pipeline import (
|
||||
deserialize_pipeline_stats,
|
||||
deserialize_text_candidates,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialize helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def serialize_frame_meta(frame: Frame) -> dict:
|
||||
meta = {
|
||||
"sequence": frame.sequence,
|
||||
"chunk_id": frame.chunk_id,
|
||||
"timestamp": frame.timestamp,
|
||||
"perceptual_hash": frame.perceptual_hash,
|
||||
}
|
||||
return meta
|
||||
|
||||
|
||||
def serialize_text_candidate(tc: TextCandidate) -> dict:
|
||||
bbox_dict = dataclasses.asdict(tc.bbox)
|
||||
candidate = {
|
||||
"frame_sequence": tc.frame.sequence,
|
||||
"bbox": bbox_dict,
|
||||
"text": tc.text,
|
||||
"ocr_confidence": tc.ocr_confidence,
|
||||
}
|
||||
return candidate
|
||||
# Envelope fields — not owned by any stage, always present
|
||||
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "config_overrides"]
|
||||
|
||||
|
||||
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
|
||||
"""
|
||||
Serialize DetectState to a JSON-compatible dict.
|
||||
|
||||
Frame images are replaced with S3 key references.
|
||||
TextCandidate.frame references become frame_sequence integers.
|
||||
Calls each registered stage's serialize_fn for stage-owned data.
|
||||
Envelope fields (job_id, etc.) are copied directly.
|
||||
"""
|
||||
frames = state.get("frames", [])
|
||||
filtered = state.get("filtered_frames", [])
|
||||
from detect.stages.base import _REGISTRY
|
||||
|
||||
manifest_strs = {str(k): v for k, v in frames_manifest.items()}
|
||||
frames_meta = [serialize_frame_meta(f) for f in frames]
|
||||
filtered_seqs = [f.sequence for f in filtered]
|
||||
checkpoint = {}
|
||||
|
||||
boxes_serialized = {}
|
||||
for seq, boxes in state.get("boxes_by_frame", {}).items():
|
||||
boxes_serialized[str(seq)] = [dataclasses.asdict(b) for b in boxes]
|
||||
# Envelope
|
||||
for key in ENVELOPE_KEYS:
|
||||
default = {} if key == "config_overrides" else ""
|
||||
checkpoint[key] = state.get(key, default)
|
||||
|
||||
text_candidates = [serialize_text_candidate(tc) for tc in state.get("text_candidates", [])]
|
||||
unresolved = [serialize_text_candidate(tc) for tc in state.get("unresolved_candidates", [])]
|
||||
detections = [dataclasses.asdict(d) for d in state.get("detections", [])]
|
||||
stats = dataclasses.asdict(state.get("stats", PipelineStats()))
|
||||
# Frames manifest (needed by frame-loading stages)
|
||||
checkpoint["frames_manifest"] = {str(k): v for k, v in frames_manifest.items()}
|
||||
|
||||
# Stats (shared across stages, not owned by one)
|
||||
stats = state.get("stats")
|
||||
if stats is not None:
|
||||
checkpoint["stats"] = serialize_dataclass(stats)
|
||||
else:
|
||||
checkpoint["stats"] = {}
|
||||
|
||||
# Per-stage data
|
||||
for name, stage_def in _REGISTRY.items():
|
||||
if stage_def.serialize_fn is None:
|
||||
continue
|
||||
job_id = state.get("job_id", "")
|
||||
stage_data = stage_def.serialize_fn(state, job_id)
|
||||
checkpoint[f"stage_{name}"] = stage_data
|
||||
|
||||
checkpoint = {
|
||||
"job_id": state.get("job_id", ""),
|
||||
"video_path": state.get("video_path", ""),
|
||||
"profile_name": state.get("profile_name", ""),
|
||||
"config_overrides": state.get("config_overrides", {}),
|
||||
"frames_manifest": manifest_strs,
|
||||
"frames_meta": frames_meta,
|
||||
"filtered_frame_sequences": filtered_seqs,
|
||||
"boxes_by_frame": boxes_serialized,
|
||||
"text_candidates": text_candidates,
|
||||
"unresolved_candidates": unresolved,
|
||||
"detections": detections,
|
||||
"stats": stats,
|
||||
}
|
||||
return checkpoint
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deserialize helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
def deserialize_state(checkpoint: dict, frames: list) -> dict:
|
||||
"""
|
||||
Reconstitute DetectState from a checkpoint dict + loaded frames.
|
||||
|
||||
def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate:
|
||||
frame = frame_map[d["frame_sequence"]]
|
||||
bbox = BoundingBox(**d["bbox"])
|
||||
candidate = TextCandidate(
|
||||
frame=frame,
|
||||
bbox=bbox,
|
||||
text=d["text"],
|
||||
ocr_confidence=d["ocr_confidence"],
|
||||
)
|
||||
return candidate
|
||||
Calls each stage's deserialize_fn to restore stage-owned data.
|
||||
"""
|
||||
from detect.stages.base import _REGISTRY
|
||||
|
||||
|
||||
def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict:
|
||||
"""Reconstitute DetectState from a checkpoint dict + loaded frames."""
|
||||
frame_map = {f.sequence: f for f in frames}
|
||||
|
||||
filtered_seqs = set(checkpoint.get("filtered_frame_sequences", []))
|
||||
filtered_frames = [f for f in frames if f.sequence in filtered_seqs]
|
||||
state = {}
|
||||
|
||||
boxes_by_frame = {}
|
||||
for seq_str, box_dicts in checkpoint.get("boxes_by_frame", {}).items():
|
||||
seq = int(seq_str)
|
||||
boxes_by_frame[seq] = [BoundingBox(**b) for b in box_dicts]
|
||||
# Envelope
|
||||
for key in ENVELOPE_KEYS:
|
||||
default = {} if key == "config_overrides" else ""
|
||||
state[key] = checkpoint.get(key, default)
|
||||
|
||||
text_candidates = [
|
||||
deserialize_text_candidate(d, frame_map)
|
||||
for d in checkpoint.get("text_candidates", [])
|
||||
]
|
||||
unresolved_candidates = [
|
||||
deserialize_text_candidate(d, frame_map)
|
||||
for d in checkpoint.get("unresolved_candidates", [])
|
||||
]
|
||||
detections = [BrandDetection(**d) for d in checkpoint.get("detections", [])]
|
||||
stats = PipelineStats(**checkpoint.get("stats", {}))
|
||||
# Frames (always present, loaded externally)
|
||||
state["frames"] = frames
|
||||
|
||||
# Stats
|
||||
state["stats"] = deserialize_pipeline_stats(checkpoint.get("stats", {}))
|
||||
|
||||
# Per-stage data
|
||||
for name, stage_def in _REGISTRY.items():
|
||||
if stage_def.deserialize_fn is None:
|
||||
continue
|
||||
|
||||
stage_key = f"stage_{name}"
|
||||
if stage_key not in checkpoint:
|
||||
continue
|
||||
|
||||
job_id = state.get("job_id", "")
|
||||
stage_data = stage_def.deserialize_fn(checkpoint[stage_key], job_id)
|
||||
|
||||
for k, v in stage_data.items():
|
||||
if k == "_filtered_sequences":
|
||||
# Reconnect filtered frames from sequence list
|
||||
seq_set = set(v)
|
||||
state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
|
||||
elif k.endswith("_raw"):
|
||||
# Raw text candidates need frame reference reconnection
|
||||
real_key = k.removeprefix("_").removesuffix("_raw")
|
||||
state[real_key] = deserialize_text_candidates(v, frame_map)
|
||||
else:
|
||||
state[k] = v
|
||||
|
||||
state = {
|
||||
"job_id": checkpoint.get("job_id", ""),
|
||||
"video_path": checkpoint.get("video_path", ""),
|
||||
"profile_name": checkpoint.get("profile_name", ""),
|
||||
"config_overrides": checkpoint.get("config_overrides", {}),
|
||||
"frames": frames,
|
||||
"filtered_frames": filtered_frames,
|
||||
"boxes_by_frame": boxes_by_frame,
|
||||
"text_candidates": text_candidates,
|
||||
"unresolved_candidates": unresolved_candidates,
|
||||
"detections": detections,
|
||||
"stats": stats,
|
||||
}
|
||||
return state
|
||||
|
||||
102
detect/models.py
102
detect/models.py
@@ -1,86 +1,26 @@
|
||||
"""
|
||||
Core domain models for the detection pipeline.
|
||||
Re-export pipeline runtime models from core/schema/models/detect_pipeline.py.
|
||||
|
||||
These are pipeline-internal models — the data structures that flow
|
||||
between LangGraph nodes. SSE event payloads (sse_contract.py) are
|
||||
derived from these when emitting to the UI.
|
||||
All models are defined in core/schema/ — this module exists for backward
|
||||
compatibility so existing imports (from detect.models import Frame) keep working.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from core.schema.models.detect_pipeline import (
|
||||
BoundingBox,
|
||||
BrandDetection,
|
||||
BrandStats,
|
||||
DetectionReport,
|
||||
Frame,
|
||||
PipelineStats,
|
||||
TextCandidate,
|
||||
)
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class Frame:
|
||||
sequence: int
|
||||
chunk_id: int
|
||||
timestamp: float # position in video (seconds)
|
||||
image: np.ndarray
|
||||
perceptual_hash: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BoundingBox:
|
||||
x: int
|
||||
y: int
|
||||
w: int
|
||||
h: int
|
||||
confidence: float
|
||||
label: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextCandidate:
|
||||
frame: Frame
|
||||
bbox: BoundingBox
|
||||
text: str
|
||||
ocr_confidence: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrandDetection:
|
||||
brand: str
|
||||
timestamp: float
|
||||
duration: float
|
||||
confidence: float
|
||||
source: Literal["ocr", "local_vlm", "cloud_llm", "logo_match", "auxiliary"]
|
||||
bbox: BoundingBox | None = None
|
||||
frame_ref: int | None = None
|
||||
content_type: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrandStats:
|
||||
total_appearances: int = 0
|
||||
total_screen_time: float = 0.0
|
||||
avg_confidence: float = 0.0
|
||||
first_seen: float = 0.0
|
||||
last_seen: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineStats:
|
||||
frames_extracted: int = 0
|
||||
frames_after_scene_filter: int = 0
|
||||
regions_detected: int = 0
|
||||
regions_resolved_by_ocr: int = 0
|
||||
regions_escalated_to_local_vlm: int = 0
|
||||
regions_escalated_to_cloud_llm: int = 0
|
||||
auxiliary_detections: int = 0
|
||||
cloud_llm_calls: int = 0
|
||||
processing_time_seconds: float = 0.0
|
||||
estimated_cloud_cost_usd: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectionReport:
|
||||
video_source: str
|
||||
content_type: str
|
||||
duration_seconds: float
|
||||
brands: dict[str, BrandStats] = field(default_factory=dict)
|
||||
timeline: list[BrandDetection] = field(default_factory=list)
|
||||
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
|
||||
__all__ = [
|
||||
"BoundingBox",
|
||||
"BrandDetection",
|
||||
"BrandStats",
|
||||
"DetectionReport",
|
||||
"Frame",
|
||||
"PipelineStats",
|
||||
"TextCandidate",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Pipeline stages.
|
||||
|
||||
Each stage registers its StageDefinition on import,
|
||||
declaring IO (what it reads/writes from state),
|
||||
config fields (what's tunable from the editor),
|
||||
and serialization (how to checkpoint its outputs).
|
||||
"""
|
||||
|
||||
from .base import (
|
||||
StageDefinition,
|
||||
StageIO,
|
||||
StageConfigField,
|
||||
register_stage,
|
||||
get_stage,
|
||||
list_stages,
|
||||
get_palette,
|
||||
)
|
||||
|
||||
# Populate registry with built-in stages
|
||||
from . import registry # noqa: F401
|
||||
|
||||
101
detect/stages/base.py
Normal file
101
detect/stages/base.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Stage protocol — common interface for all pipeline stages.
|
||||
|
||||
Every stage declares:
|
||||
- IO: what it reads/writes from DetectState
|
||||
- Config: tunable parameters for the editor
|
||||
- Serialization: how to persist/restore its own outputs
|
||||
|
||||
The checkpoint layer is a black box — it asks each stage to serialize its
|
||||
outputs and stores the result. Stages own their data format. Binary data
|
||||
(frames, crops) goes to S3 via the stage itself. The checkpoint just
|
||||
stores the JSON envelope.
|
||||
|
||||
The graph builder uses StageIO to validate that a stage's inputs are
|
||||
satisfied by previous stages' outputs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
@dataclass
class StageIO:
    """Declares what a stage reads and writes from/to DetectState."""
    reads: list[str]    # state keys the stage requires as input
    writes: list[str]   # state keys the stage produces
    optional_reads: list[str] = field(default_factory=list)  # used when present, not required
|
||||
|
||||
|
||||
@dataclass
class StageConfigField:
    """A single tunable config parameter for the editor UI."""
    name: str
    type: str  # "float", "int", "str", "bool", "list[str]"
    default: Any
    description: str = ""
    min: float | None = None  # lower bound for numeric fields, when applicable
    max: float | None = None  # upper bound for numeric fields, when applicable
    options: list[str] | None = None  # presumably allowed values for enum-like fields — TODO confirm
|
||||
|
||||
|
||||
@dataclass
class StageDefinition:
    """
    Complete metadata for a pipeline stage.

    The profile editor uses this to build the palette, generate config
    forms, and validate graph connections. The checkpoint uses serialize_fn
    and deserialize_fn to persist stage outputs without knowing the internals.
    """
    name: str          # unique registry key (e.g. "detect_objects")
    label: str         # human-readable name for the editor UI
    description: str
    io: StageIO        # declared reads/writes, used for graph validation
    config_fields: list[StageConfigField] = field(default_factory=list)
    category: str = "detection"  # palette grouping (see get_palette)

    # The actual graph node function: (DetectState) → dict
    fn: Callable | None = None

    # Stage-owned serialization for checkpointing.
    # serialize_fn: (state: dict, job_id: str) → json-compatible dict
    #   Stage picks its writes from state, serializes them.
    #   Binary data (frames) → S3 via stage, returns refs.
    # deserialize_fn: (data: dict, job_id: str) → state update dict
    #   Stage restores its writes from the persisted data.
    serialize_fn: Callable | None = None
    deserialize_fn: Callable | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_REGISTRY: dict[str, StageDefinition] = {}
|
||||
|
||||
|
||||
def register_stage(definition: StageDefinition):
    """Add *definition* to the global registry, replacing any stage with the same name."""
    _REGISTRY[definition.name] = definition
|
||||
|
||||
|
||||
def get_stage(name: str) -> StageDefinition:
    """Look up a registered stage; raise KeyError listing known stages if absent."""
    stage = _REGISTRY.get(name)
    if stage is None:
        raise KeyError(f"Unknown stage: {name!r}. Registered: {list(_REGISTRY)}")
    return stage
|
||||
|
||||
|
||||
def list_stages() -> list[StageDefinition]:
    """Return all registered stage definitions, in registration order."""
    return list(_REGISTRY.values())
|
||||
|
||||
|
||||
def get_palette() -> dict[str, list[StageDefinition]]:
    """
    Group stages by category for the editor palette.

    Categories appear in first-encountered order; stages keep their
    registration order within each category. Uses dict.setdefault
    instead of the manual membership check.
    """
    palette: dict[str, list[StageDefinition]] = {}
    for stage in _REGISTRY.values():
        palette.setdefault(stage.category, []).append(stage)
    return palette
|
||||
28
detect/stages/registry/__init__.py
Normal file
28
detect/stages/registry/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Stage registry — registers all built-in stages.
|
||||
|
||||
Split by category:
|
||||
preprocessing.py — extract_frames, filter_scenes
|
||||
detection.py — detect_objects, run_ocr
|
||||
resolution.py — match_brands
|
||||
escalation.py — escalate_vlm, escalate_cloud
|
||||
output.py — compile_report
|
||||
_serializers.py — shared serialization helpers
|
||||
"""
|
||||
|
||||
from . import preprocessing
|
||||
from . import detection
|
||||
from . import resolution
|
||||
from . import escalation
|
||||
from . import output
|
||||
|
||||
|
||||
def register_all():
    """Register every built-in stage, one category module at a time."""
    preprocessing.register()
    detection.register()
    resolution.register()
    escalation.register()
    output.register()
||||
|
||||
|
||||
register_all()
|
||||
25
detect/stages/registry/_serializers.py
Normal file
25
detect/stages/registry/_serializers.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Re-export serializers from core/schema/serializers/.
|
||||
|
||||
Stage registry modules import from here for convenience.
|
||||
All serialization logic lives in core/schema/serializers/.
|
||||
"""
|
||||
|
||||
from core.schema.serializers._common import (
|
||||
safe_construct,
|
||||
serialize_dataclass,
|
||||
serialize_dataclass_list,
|
||||
)
|
||||
from core.schema.serializers.detect_pipeline import (
|
||||
serialize_frame_meta,
|
||||
serialize_frames_with_upload as serialize_frames,
|
||||
deserialize_frames_with_download as deserialize_frames,
|
||||
serialize_text_candidate,
|
||||
serialize_text_candidates,
|
||||
deserialize_text_candidate,
|
||||
deserialize_text_candidates,
|
||||
deserialize_bounding_box,
|
||||
deserialize_brand_detection,
|
||||
deserialize_pipeline_stats,
|
||||
deserialize_detection_report,
|
||||
)
|
||||
63
detect/stages/registry/detection.py
Normal file
63
detect/stages/registry/detection.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Registration for detection stages: YOLO, OCR."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||
from ._serializers import (
|
||||
serialize_dataclass_list,
|
||||
serialize_text_candidates,
|
||||
deserialize_bounding_box,
|
||||
)
|
||||
|
||||
|
||||
def _ser_detect(state: dict, job_id: str) -> dict:
    """Checkpoint YOLO output: per-frame box lists, keyed by sequence as str (JSON-safe)."""
    by_frame = state.get("boxes_by_frame", {})
    return {
        "boxes_by_frame": {
            str(sequence): serialize_dataclass_list(box_list)
            for sequence, box_list in by_frame.items()
        }
    }
|
||||
|
||||
|
||||
def _deser_detect(data: dict, job_id: str) -> dict:
    """Restore YOLO output, converting string keys back to int frame sequences."""
    restored = {
        int(sequence): [deserialize_bounding_box(entry) for entry in box_dicts]
        for sequence, box_dicts in data.get("boxes_by_frame", {}).items()
    }
    return {"boxes_by_frame": restored}
|
||||
|
||||
|
||||
def _ser_ocr(state: dict, job_id: str) -> dict:
    """Checkpoint OCR output; Frame references collapse to sequence ints."""
    return {"text_candidates": serialize_text_candidates(state.get("text_candidates", []))}
|
||||
|
||||
|
||||
def _deser_ocr(data: dict, job_id: str) -> dict:
|
||||
return {"_text_candidates_raw": data["text_candidates"]}
|
||||
|
||||
|
||||
def register():
    """Register the detection-category stages: YOLO object detection and OCR."""
    yolo = StageDefinition(
        name="detect_objects",
        label="Object Detection",
        description="YOLO object detection on filtered frames",
        category="detection",
        io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
        config_fields=[
            StageConfigField("model_name", "str", "yolov8n.pt", "YOLO model file"),
            StageConfigField("confidence_threshold", "float", 0.3, "Min detection confidence", min=0.0, max=1.0),
            StageConfigField("target_classes", "list[str]", [], "YOLO classes to detect (empty = all)"),
        ],
        serialize_fn=_ser_detect,
        deserialize_fn=_deser_detect,
    )
    register_stage(yolo)

    ocr = StageDefinition(
        name="run_ocr",
        label="OCR",
        description="Extract text from detected regions",
        category="detection",
        io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
        config_fields=[
            StageConfigField("languages", "list[str]", ["en"], "OCR languages"),
            StageConfigField("min_confidence", "float", 0.5, "Min OCR confidence", min=0.0, max=1.0),
        ],
        serialize_fn=_ser_ocr,
        deserialize_fn=_deser_ocr,
    )
    register_stage(ocr)
|
||||
63
detect/stages/registry/escalation.py
Normal file
63
detect/stages/registry/escalation.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Registration for escalation stages: local VLM, cloud LLM."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||
from ._serializers import (
|
||||
serialize_dataclass_list,
|
||||
serialize_text_candidates,
|
||||
deserialize_brand_detection,
|
||||
)
|
||||
|
||||
|
||||
def _ser_escalation(state: dict, job_id: str) -> dict:
    """Checkpoint escalation output: resolved detections plus still-unresolved candidates."""
    resolved = state.get("detections", [])
    still_unresolved = state.get("unresolved_candidates", [])
    return {
        "detections": serialize_dataclass_list(resolved),
        "unresolved_candidates": serialize_text_candidates(still_unresolved),
    }
|
||||
|
||||
|
||||
def _deser_escalation(data: dict, job_id: str) -> dict:
|
||||
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
|
||||
return {
|
||||
"detections": detections,
|
||||
"_unresolved_raw": data.get("unresolved_candidates", []),
|
||||
}
|
||||
|
||||
|
||||
def register():
    """Register the escalation-category stages: local VLM and cloud LLM."""
    vlm = StageDefinition(
        name="escalate_vlm",
        label="Local VLM",
        description="Process unresolved crops with moondream2",
        category="escalation",
        io=StageIO(
            reads=["unresolved_candidates"],
            writes=["detections", "unresolved_candidates"],
            optional_reads=["source_asset_id"],
        ),
        config_fields=[
            StageConfigField("min_confidence", "float", 0.5, "Min VLM confidence", min=0.0, max=1.0),
        ],
        serialize_fn=_ser_escalation,
        deserialize_fn=_deser_escalation,
    )
    register_stage(vlm)

    cloud = StageDefinition(
        name="escalate_cloud",
        label="Cloud LLM",
        description="Escalate remaining crops to cloud provider",
        category="escalation",
        io=StageIO(
            reads=["unresolved_candidates"],
            writes=["detections"],
            optional_reads=["source_asset_id"],
        ),
        config_fields=[
            StageConfigField("min_confidence", "float", 0.4, "Min cloud confidence", min=0.0, max=1.0),
        ],
        # Both escalation stages share the same (de)serializers.
        serialize_fn=_ser_escalation,
        deserialize_fn=_deser_escalation,
    )
    register_stage(cloud)
|
||||
32
detect/stages/registry/output.py
Normal file
32
detect/stages/registry/output.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Registration for output stages: report compilation."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, register_stage
|
||||
from ._serializers import serialize_dataclass, deserialize_detection_report
|
||||
|
||||
|
||||
def _ser_report(state: dict, job_id: str) -> dict:
|
||||
report = state.get("report")
|
||||
if report is None:
|
||||
return {"report": None}
|
||||
return {"report": serialize_dataclass(report)}
|
||||
|
||||
|
||||
def _deser_report(data: dict, job_id: str) -> dict:
|
||||
raw = data.get("report")
|
||||
if raw is None:
|
||||
return {"report": None}
|
||||
return {"report": deserialize_detection_report(raw)}
|
||||
|
||||
|
||||
def register():
    """Register the output stage: final report compilation."""
    register_stage(
        StageDefinition(
            name="compile_report",
            label="Report",
            description="Merge detections and compile final report",
            category="output",
            io=StageIO(reads=["detections"], writes=["report"]),
            config_fields=[],  # no tunables for report compilation
            serialize_fn=_ser_report,
            deserialize_fn=_deser_report,
        )
    )
|
||||
57
detect/stages/registry/preprocessing.py
Normal file
57
detect/stages/registry/preprocessing.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Registration for preprocessing stages: frame extraction, scene filter."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||
from ._serializers import serialize_frames, deserialize_frames
|
||||
|
||||
|
||||
def _ser_extract(state: dict, job_id: str) -> dict:
    """Checkpoint extracted frames: JSON metadata plus a pixel manifest.

    The heavy image arrays are externalized by ``serialize_frames``; only
    metadata and the manifest of storage locations go into the checkpoint.
    """
    meta, manifest = serialize_frames(state.get("frames", []), job_id)
    return {"frames_meta": meta, "frames_manifest": manifest}
|
||||
|
||||
|
||||
def _deser_extract(data: dict, job_id: str) -> dict:
    """Rehydrate Frame objects from checkpoint metadata and manifest.

    Both keys are required — a checkpoint written by ``_ser_extract``
    always contains them.
    """
    restored = deserialize_frames(
        data["frames_meta"], data["frames_manifest"], job_id
    )
    return {"frames": restored}
|
||||
|
||||
|
||||
def _ser_filter(state: dict, job_id: str) -> dict:
|
||||
filtered = state.get("filtered_frames", [])
|
||||
seqs = [f.sequence for f in filtered]
|
||||
return {"filtered_frame_sequences": seqs}
|
||||
|
||||
|
||||
def _deser_filter(data: dict, job_id: str) -> dict:
|
||||
return {"_filtered_sequences": data["filtered_frame_sequences"]}
|
||||
|
||||
|
||||
def register():
    """Register the preprocessing stages: frame extraction and scene filter."""
    register_stage(
        StageDefinition(
            name="extract_frames",
            label="Frame Extraction",
            description="Extract frames from video at configurable FPS",
            category="preprocessing",
            io=StageIO(reads=["video_path"], writes=["frames"]),
            config_fields=[
                StageConfigField(
                    "fps", "float", 2.0, "Frames per second",
                    min=0.1, max=30.0,
                ),
                StageConfigField(
                    "max_frames", "int", 500, "Maximum frames to extract",
                    min=1, max=10000,
                ),
            ],
            serialize_fn=_ser_extract,
            deserialize_fn=_deser_extract,
        )
    )
    register_stage(
        StageDefinition(
            name="filter_scenes",
            label="Scene Filter",
            description="Deduplicate similar frames using perceptual hashing",
            category="preprocessing",
            io=StageIO(reads=["frames"], writes=["filtered_frames"]),
            config_fields=[
                StageConfigField(
                    "hamming_threshold", "int", 8, "Hamming distance threshold",
                    min=0, max=64,
                ),
                StageConfigField("enabled", "bool", True, "Enable scene filtering"),
            ],
            serialize_fn=_ser_filter,
            deserialize_fn=_deser_filter,
        )
    )
|
||||
45
detect/stages/registry/resolution.py
Normal file
45
detect/stages/registry/resolution.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Registration for resolution stages: brand resolver."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||
from ._serializers import (
|
||||
serialize_dataclass_list,
|
||||
serialize_text_candidates,
|
||||
deserialize_brand_detection,
|
||||
)
|
||||
|
||||
|
||||
def _ser_brands(state: dict, job_id: str) -> dict:
    """Serialize resolver output: matched detections and leftover candidates.

    Detections are plain dataclasses; unresolved candidates carry Frame
    references and need the dedicated candidate serializer.
    """
    return {
        "detections": serialize_dataclass_list(state.get("detections", [])),
        "unresolved_candidates": serialize_text_candidates(
            state.get("unresolved_candidates", [])
        ),
    }
|
||||
|
||||
|
||||
def _deser_brands(data: dict, job_id: str) -> dict:
|
||||
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
|
||||
return {
|
||||
"detections": detections,
|
||||
"_unresolved_raw": data.get("unresolved_candidates", []),
|
||||
}
|
||||
|
||||
|
||||
def register():
    """Register the resolution stage: OCR-text-to-brand matching."""
    register_stage(
        StageDefinition(
            name="match_brands",
            label="Brand Resolver",
            description="Match OCR text against known brands (session + global DB)",
            category="resolution",
            io=StageIO(
                reads=["text_candidates"],
                writes=["detections", "unresolved_candidates"],
                optional_reads=["session_brands", "source_asset_id"],
            ),
            config_fields=[
                StageConfigField(
                    "fuzzy_threshold", "int", 75, "Fuzzy match threshold",
                    min=0, max=100,
                ),
            ],
            serialize_fn=_ser_brands,
            deserialize_fn=_deser_brands,
        )
    )
|
||||
@@ -1,26 +1,30 @@
|
||||
"""Tests for checkpoint serialization — round-trip without S3."""
|
||||
"""Tests for checkpoint serialization — unit tests without S3."""
|
||||
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
|
||||
from detect.checkpoint.serializer import (
|
||||
serialize_state,
|
||||
deserialize_state,
|
||||
from core.schema.serializers._common import safe_construct
|
||||
from core.schema.serializers.detect_pipeline import (
|
||||
serialize_frame_meta,
|
||||
serialize_text_candidate,
|
||||
serialize_text_candidates,
|
||||
serialize_dataclass_list,
|
||||
deserialize_text_candidate,
|
||||
deserialize_text_candidates,
|
||||
deserialize_bounding_box,
|
||||
deserialize_brand_detection,
|
||||
deserialize_pipeline_stats,
|
||||
)
|
||||
|
||||
|
||||
def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
|
||||
image = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
|
||||
return Frame(
|
||||
sequence=seq,
|
||||
chunk_id=0,
|
||||
timestamp=float(seq) * 0.5,
|
||||
image=image,
|
||||
perceptual_hash=f"hash_{seq}",
|
||||
sequence=seq, chunk_id=0, timestamp=float(seq) * 0.5,
|
||||
image=image, perceptual_hash=f"hash_{seq}",
|
||||
)
|
||||
|
||||
|
||||
@@ -35,13 +39,8 @@ def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate:
|
||||
|
||||
def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection:
|
||||
return BrandDetection(
|
||||
brand=brand,
|
||||
timestamp=timestamp,
|
||||
duration=0.5,
|
||||
confidence=0.92,
|
||||
source="ocr",
|
||||
content_type="soccer_broadcast",
|
||||
frame_ref=0,
|
||||
brand=brand, timestamp=timestamp, duration=0.5,
|
||||
confidence=0.92, source="ocr", content_type="soccer_broadcast", frame_ref=0,
|
||||
)
|
||||
|
||||
|
||||
@@ -81,102 +80,84 @@ def test_deserialize_text_candidate():
|
||||
|
||||
assert restored.text == "ADIDAS"
|
||||
assert restored.ocr_confidence == 0.85
|
||||
assert restored.frame is frame # same object reference
|
||||
assert restored.frame is frame
|
||||
assert restored.bbox.x == 10
|
||||
|
||||
|
||||
# --- Full state round-trip ---
|
||||
# --- BoundingBox ---
|
||||
|
||||
def test_state_round_trip():
|
||||
frames = [_make_frame(seq=i) for i in range(3)]
|
||||
filtered = frames[:2]
|
||||
def test_bounding_box_round_trip():
|
||||
box = _make_box(x=15, y=25, w=40, h=30)
|
||||
serialized = serialize_dataclass_list([box])[0]
|
||||
restored = deserialize_bounding_box(serialized)
|
||||
|
||||
box = _make_box()
|
||||
boxes_by_frame = {0: [box], 1: [box]}
|
||||
assert restored.x == 15
|
||||
assert restored.w == 40
|
||||
assert restored.confidence == 0.9
|
||||
|
||||
candidates = [_make_candidate(frames[0], "NIKE"), _make_candidate(frames[1], "EMIRATES")]
|
||||
unresolved = [_make_candidate(frames[2], "unknown")]
|
||||
detections = [_make_detection("Nike", 0.5), _make_detection("Emirates", 1.0)]
|
||||
|
||||
# --- BrandDetection ---
|
||||
|
||||
def test_brand_detection_round_trip():
|
||||
det = _make_detection("Emirates", 3.5)
|
||||
serialized = serialize_dataclass_list([det])[0]
|
||||
restored = deserialize_brand_detection(serialized)
|
||||
|
||||
assert restored.brand == "Emirates"
|
||||
assert restored.timestamp == 3.5
|
||||
assert restored.source == "ocr"
|
||||
|
||||
|
||||
# --- PipelineStats ---
|
||||
|
||||
def test_pipeline_stats_round_trip():
|
||||
stats = PipelineStats(
|
||||
frames_extracted=3,
|
||||
frames_after_scene_filter=2,
|
||||
regions_detected=2,
|
||||
regions_resolved_by_ocr=2,
|
||||
cloud_llm_calls=1,
|
||||
estimated_cloud_cost_usd=0.003,
|
||||
frames_extracted=120, cloud_llm_calls=3,
|
||||
estimated_cloud_cost_usd=0.005,
|
||||
)
|
||||
serialized = serialize_dataclass_list([stats])[0]
|
||||
restored = deserialize_pipeline_stats(serialized)
|
||||
|
||||
state = {
|
||||
"job_id": "test-123",
|
||||
"video_path": "/tmp/test.mp4",
|
||||
"profile_name": "soccer_broadcast",
|
||||
"config_overrides": {"ocr": {"min_confidence": 0.3}},
|
||||
"frames": frames,
|
||||
"filtered_frames": filtered,
|
||||
"boxes_by_frame": boxes_by_frame,
|
||||
"text_candidates": candidates,
|
||||
"unresolved_candidates": unresolved,
|
||||
"detections": detections,
|
||||
"stats": stats,
|
||||
assert restored.frames_extracted == 120
|
||||
assert restored.cloud_llm_calls == 3
|
||||
assert restored.estimated_cloud_cost_usd == 0.005
|
||||
|
||||
|
||||
# --- safe_construct tolerates schema changes ---
|
||||
|
||||
def test_safe_construct_ignores_unknown_fields():
|
||||
data = {"x": 1, "y": 2, "w": 3, "h": 4, "confidence": 0.5, "label": "t", "extra_field": 99}
|
||||
box = safe_construct(BoundingBox, data)
|
||||
|
||||
assert box.x == 1
|
||||
assert not hasattr(box, "extra_field")
|
||||
|
||||
|
||||
def test_safe_construct_uses_defaults():
|
||||
data = {"frames_extracted": 50}
|
||||
stats = safe_construct(PipelineStats, data)
|
||||
|
||||
assert stats.frames_extracted == 50
|
||||
assert stats.cloud_llm_calls == 0 # default
|
||||
|
||||
|
||||
# --- JSON compatibility ---
|
||||
|
||||
def test_all_serialized_is_json_compatible():
|
||||
frame = _make_frame()
|
||||
candidate = _make_candidate(frame)
|
||||
detection = _make_detection()
|
||||
stats = PipelineStats(frames_extracted=10)
|
||||
|
||||
all_data = {
|
||||
"frame_meta": serialize_frame_meta(frame),
|
||||
"candidate": serialize_text_candidate(candidate),
|
||||
"detection": serialize_dataclass_list([detection])[0],
|
||||
"stats": serialize_dataclass_list([stats])[0],
|
||||
}
|
||||
|
||||
manifest = {f.sequence: f"s3://fake/frames/{f.sequence}.jpg" for f in frames}
|
||||
|
||||
# Serialize
|
||||
serialized = serialize_state(state, manifest)
|
||||
|
||||
# Verify JSON-compatible (no numpy, no Frame objects)
|
||||
import json
|
||||
json_str = json.dumps(serialized, default=str)
|
||||
json_str = json.dumps(all_data, default=str)
|
||||
assert len(json_str) > 0
|
||||
|
||||
# Deserialize with the original frames (simulating frame load from S3)
|
||||
restored = deserialize_state(serialized, frames)
|
||||
|
||||
# Verify round-trip
|
||||
assert restored["job_id"] == "test-123"
|
||||
assert restored["video_path"] == "/tmp/test.mp4"
|
||||
assert restored["profile_name"] == "soccer_broadcast"
|
||||
assert restored["config_overrides"] == {"ocr": {"min_confidence": 0.3}}
|
||||
|
||||
assert len(restored["frames"]) == 3
|
||||
assert len(restored["filtered_frames"]) == 2
|
||||
assert len(restored["boxes_by_frame"]) == 2
|
||||
assert len(restored["text_candidates"]) == 2
|
||||
assert len(restored["unresolved_candidates"]) == 1
|
||||
assert len(restored["detections"]) == 2
|
||||
|
||||
restored_stats = restored["stats"]
|
||||
assert restored_stats.frames_extracted == 3
|
||||
assert restored_stats.cloud_llm_calls == 1
|
||||
assert restored_stats.estimated_cloud_cost_usd == 0.003
|
||||
|
||||
# TextCandidate frame references should point to actual Frame objects
|
||||
tc = restored["text_candidates"][0]
|
||||
assert tc.frame is frames[0]
|
||||
assert tc.text == "NIKE"
|
||||
|
||||
|
||||
def test_state_round_trip_empty():
|
||||
"""Empty state should serialize/deserialize cleanly."""
|
||||
state = {
|
||||
"job_id": "empty-job",
|
||||
"video_path": "",
|
||||
"profile_name": "soccer_broadcast",
|
||||
"frames": [],
|
||||
"filtered_frames": [],
|
||||
"boxes_by_frame": {},
|
||||
"text_candidates": [],
|
||||
"unresolved_candidates": [],
|
||||
"detections": [],
|
||||
"stats": PipelineStats(),
|
||||
}
|
||||
|
||||
serialized = serialize_state(state, {})
|
||||
restored = deserialize_state(serialized, [])
|
||||
|
||||
assert restored["job_id"] == "empty-job"
|
||||
assert len(restored["frames"]) == 0
|
||||
assert len(restored["detections"]) == 0
|
||||
assert restored["stats"].frames_extracted == 0
|
||||
roundtrip = json.loads(json_str)
|
||||
assert roundtrip["frame_meta"]["sequence"] == frame.sequence
|
||||
|
||||
58
tests/detect/test_stage_registry.py
Normal file
58
tests/detect/test_stage_registry.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Tests for the stage registry."""
|
||||
|
||||
from detect.stages import list_stages, get_stage, get_palette
|
||||
|
||||
|
||||
# Canonical pipeline order; registry tests below assert against this list.
EXPECTED_STAGES = [
    "extract_frames",
    "filter_scenes",
    "detect_objects",
    "run_ocr",
    "match_brands",
    "escalate_vlm",
    "escalate_cloud",
    "compile_report",
]
|
||||
|
||||
|
||||
def test_all_stages_registered():
    """Every expected stage name must appear in the registry."""
    registered = {stage.name for stage in list_stages()}
    for expected in EXPECTED_STAGES:
        assert expected in registered, f"Missing stage: {expected}"
|
||||
|
||||
|
||||
def test_stage_has_io():
    """Each stage declares at least one write plus a label and description."""
    for name in EXPECTED_STAGES:
        definition = get_stage(name)
        assert definition.io.writes, f"{name} has no writes"
        assert definition.label, f"{name} has no label"
        assert definition.description, f"{name} has no description"
|
||||
|
||||
|
||||
def test_stage_has_serialization():
    """Each stage must supply both checkpoint serializer hooks."""
    for name in EXPECTED_STAGES:
        definition = get_stage(name)
        assert definition.serialize_fn is not None, f"{name} has no serialize_fn"
        assert definition.deserialize_fn is not None, f"{name} has no deserialize_fn"
|
||||
|
||||
|
||||
def test_palette_groups():
    """The palette groups every registered stage under string categories."""
    palette = get_palette()
    assert palette
    collected = []
    for category, members in palette.items():
        assert isinstance(category, str)
        collected.extend(members)
    # No stage missing from, or duplicated across, the palette groups.
    assert len(collected) == len(EXPECTED_STAGES)
|
||||
|
||||
|
||||
def test_io_chain_valid():
    """Each stage's reads should be satisfied by previous stages' writes."""
    # Keys present in the initial pipeline state before any stage runs.
    available_keys = {
        "video_path", "job_id", "profile_name",
        "source_asset_id", "session_brands", "stats",
        "config_overrides",
    }

    for name in EXPECTED_STAGES:
        stage = get_stage(name)
        for read_key in stage.io.reads:
            assert read_key in available_keys, (
                f"Stage {stage.name} reads '{read_key}' but no previous stage writes it. "
                f"Available: {available_keys}"
            )
        available_keys.update(stage.io.writes)
|
||||
Reference in New Issue
Block a user