schema clean up and refactor

This commit is contained in:
2026-03-26 05:14:33 -03:00
parent 08c58a6a9d
commit d58a90157a
17 changed files with 930 additions and 287 deletions

View File

@@ -0,0 +1,21 @@
"""
Pipeline stages.
Each stage registers its StageDefinition on import,
declaring IO (what it reads/writes from state),
config fields (what's tunable from the editor),
and serialization (how to checkpoint its outputs).
"""
from .base import (
StageDefinition,
StageIO,
StageConfigField,
register_stage,
get_stage,
list_stages,
get_palette,
)
# Populate registry with built-in stages
from . import registry # noqa: F401

101
detect/stages/base.py Normal file
View File

@@ -0,0 +1,101 @@
"""
Stage protocol — common interface for all pipeline stages.
Every stage declares:
- IO: what it reads/writes from DetectState
- Config: tunable parameters for the editor
- Serialization: how to persist/restore its own outputs
The checkpoint layer is a black box — it asks each stage to serialize its
outputs and stores the result. Stages own their data format. Binary data
(frames, crops) goes to S3 via the stage itself. The checkpoint just
stores the JSON envelope.
The graph builder uses StageIO to validate that a stage's inputs are
satisfied by previous stages' outputs.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable
@dataclass
class StageIO:
    """State keys a stage consumes and produces.

    The key names refer to entries of DetectState. Only ``optional_reads``
    has a default (a fresh empty list per instance).
    """

    # Keys the stage reads from state.
    reads: list[str]
    # Keys the stage writes to state.
    writes: list[str]
    # Additional keys declared as optional inputs.
    optional_reads: list[str] = field(default_factory=list)
@dataclass
class StageConfigField:
    """Description of one tunable stage parameter shown in the editor UI."""

    # Parameter identifier.
    name: str
    # Type tag as a string: "float", "int", "str", "bool" or "list[str]".
    type: str
    # Value used when the parameter is not overridden.
    default: Any
    # Human-readable help text.
    description: str = ""
    # Optional numeric bounds; None means no bound is declared.
    min: float | None = None
    max: float | None = None
    # Optional list of allowed choices; None means none are declared.
    options: list[str] | None = None
@dataclass
class StageDefinition:
    """
    Full description of one pipeline stage.

    The profile editor reads it to build the palette, generate config forms
    and validate graph connections. The checkpoint layer calls serialize_fn
    and deserialize_fn to persist stage outputs without knowing the stage's
    internals.
    """

    name: str
    label: str
    description: str
    io: StageIO
    config_fields: list[StageConfigField] = field(default_factory=list)
    category: str = "detection"

    # Graph node function: takes DetectState, returns a dict.
    fn: Callable | None = None

    # Checkpoint hooks owned by the stage itself.
    #   serialize_fn(state: dict, job_id: str) -> JSON-compatible dict
    #       The stage picks its own writes out of state and serializes them.
    #       Binary data (frames) is sent to S3 by the stage; refs are returned.
    #   deserialize_fn(data: dict, job_id: str) -> state update dict
    #       The stage rebuilds its writes from the persisted data.
    serialize_fn: Callable | None = None
    deserialize_fn: Callable | None = None
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
_REGISTRY: dict[str, StageDefinition] = {}
def register_stage(definition: StageDefinition):
    """Store *definition* in the module registry, keyed by its ``name``.

    Registering a second definition with the same name replaces the first.
    """
    key = definition.name
    _REGISTRY[key] = definition
def get_stage(name: str) -> StageDefinition:
    """Return the registered definition called *name*.

    Raises:
        KeyError: if no stage with that name is registered; the message
            lists the names that are.
    """
    if name in _REGISTRY:
        return _REGISTRY[name]
    known = list(_REGISTRY)
    raise KeyError(f"Unknown stage: {name!r}. Registered: {known}")
def list_stages() -> list[StageDefinition]:
    """Return a new list of all registered definitions, in registry order."""
    return [*_REGISTRY.values()]
def get_palette() -> dict[str, list[StageDefinition]]:
    """Return registered stages grouped by category, for the editor palette.

    Categories appear in the order they are first seen in the registry, and
    stages keep their registry order within each category.
    """
    grouped: dict[str, list[StageDefinition]] = {}
    for definition in _REGISTRY.values():
        # setdefault creates the bucket the first time a category is seen.
        grouped.setdefault(definition.category, []).append(definition)
    return grouped

View File

@@ -0,0 +1,28 @@
"""
Stage registry — registers all built-in stages.
Split by category:
preprocessing.py — extract_frames, filter_scenes
detection.py — detect_objects, run_ocr
resolution.py — match_brands
escalation.py — escalate_vlm, escalate_cloud
output.py — compile_report
_serializers.py — shared serialization helpers
"""
from . import preprocessing
from . import detection
from . import resolution
from . import escalation
from . import output
def register_all():
    """Call ``register()`` on each built-in stage module.

    The modules are visited in a fixed order: preprocessing, detection,
    resolution, escalation, output.
    """
    for stage_module in (preprocessing, detection, resolution, escalation, output):
        stage_module.register()
register_all()

View File

@@ -0,0 +1,25 @@
"""
Re-export serializers from core/schema/serializers/.
Stage registry modules import from here for convenience.
All serialization logic lives in core/schema/serializers/.
"""
from core.schema.serializers._common import (
safe_construct,
serialize_dataclass,
serialize_dataclass_list,
)
from core.schema.serializers.detect_pipeline import (
serialize_frame_meta,
serialize_frames_with_upload as serialize_frames,
deserialize_frames_with_download as deserialize_frames,
serialize_text_candidate,
serialize_text_candidates,
deserialize_text_candidate,
deserialize_text_candidates,
deserialize_bounding_box,
deserialize_brand_detection,
deserialize_pipeline_stats,
deserialize_detection_report,
)

View File

@@ -0,0 +1,63 @@
"""Registration for detection stages: YOLO, OCR."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_bounding_box,
)
def _ser_detect(state: dict, job_id: str) -> dict:
    """Serialize ``boxes_by_frame`` from *state* for checkpointing.

    Each key is converted to a string and each value is passed through
    ``serialize_dataclass_list``. A missing ``boxes_by_frame`` entry yields
    an empty mapping. *job_id* is not used here.
    """
    payload = {}
    for sequence, frame_boxes in state.get("boxes_by_frame", {}).items():
        payload[str(sequence)] = serialize_dataclass_list(frame_boxes)
    return {"boxes_by_frame": payload}
def _deser_detect(data: dict, job_id: str) -> dict:
    """Rebuild ``boxes_by_frame`` from checkpointed *data*.

    String keys are converted back to ``int`` and every stored box dict is
    passed through ``deserialize_bounding_box``. A missing entry yields an
    empty mapping. *job_id* is not used here.
    """
    restored = {
        int(key): [deserialize_bounding_box(raw_box) for raw_box in raw_boxes]
        for key, raw_boxes in data.get("boxes_by_frame", {}).items()
    }
    return {"boxes_by_frame": restored}
def _ser_ocr(state: dict, job_id: str) -> dict:
    """Serialize ``text_candidates`` from *state* for checkpointing.

    A missing entry is treated as an empty list. *job_id* is not used here.
    """
    payload = serialize_text_candidates(state.get("text_candidates", []))
    return {"text_candidates": payload}
def _deser_ocr(data: dict, job_id: str) -> dict:
    """Return the persisted text candidates unchanged under a private key.

    The stored payload is passed through as-is as ``_text_candidates_raw``.
    Presumably a later step rebuilds the objects (TODO confirm against the
    checkpoint loader).

    Raises:
        KeyError: if *data* has no ``text_candidates`` entry.
    """
    raw_candidates = data["text_candidates"]
    return {"_text_candidates_raw": raw_candidates}
def register():
    """Register the detection stages: ``detect_objects``, then ``run_ocr``."""
    # Object detection: reads filtered_frames, writes boxes_by_frame.
    register_stage(
        StageDefinition(
            name="detect_objects",
            label="Object Detection",
            description="YOLO object detection on filtered frames",
            category="detection",
            io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
            config_fields=[
                StageConfigField(
                    name="model_name",
                    type="str",
                    default="yolov8n.pt",
                    description="YOLO model file",
                ),
                StageConfigField(
                    name="confidence_threshold",
                    type="float",
                    default=0.3,
                    description="Min detection confidence",
                    min=0.0,
                    max=1.0,
                ),
                StageConfigField(
                    name="target_classes",
                    type="list[str]",
                    default=[],
                    description="YOLO classes to detect (empty = all)",
                ),
            ],
            serialize_fn=_ser_detect,
            deserialize_fn=_deser_detect,
        )
    )

    # OCR: reads filtered_frames and boxes_by_frame, writes text_candidates.
    register_stage(
        StageDefinition(
            name="run_ocr",
            label="OCR",
            description="Extract text from detected regions",
            category="detection",
            io=StageIO(
                reads=["filtered_frames", "boxes_by_frame"],
                writes=["text_candidates"],
            ),
            config_fields=[
                StageConfigField(
                    name="languages",
                    type="list[str]",
                    default=["en"],
                    description="OCR languages",
                ),
                StageConfigField(
                    name="min_confidence",
                    type="float",
                    default=0.5,
                    description="Min OCR confidence",
                    min=0.0,
                    max=1.0,
                ),
            ],
            serialize_fn=_ser_ocr,
            deserialize_fn=_deser_ocr,
        )
    )

View File

@@ -0,0 +1,63 @@
"""Registration for escalation stages: local VLM, cloud LLM."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_brand_detection,
)
def _ser_escalation(state: dict, job_id: str) -> dict:
    """Serialize ``detections`` and ``unresolved_candidates`` from *state*.

    Missing entries are treated as empty lists. *job_id* is not used here.
    """
    detections = serialize_dataclass_list(state.get("detections", []))
    leftovers = serialize_text_candidates(state.get("unresolved_candidates", []))
    return {"detections": detections, "unresolved_candidates": leftovers}
def _deser_escalation(data: dict, job_id: str) -> dict:
    """Rebuild escalation outputs from checkpointed *data*.

    Stored detections go through ``deserialize_brand_detection``. The
    unresolved candidates are returned untouched as ``_unresolved_raw``.
    Missing entries default to empty lists. *job_id* is not used here.
    """
    restored = []
    for raw_detection in data.get("detections", []):
        restored.append(deserialize_brand_detection(raw_detection))
    unresolved_raw = data.get("unresolved_candidates", [])
    return {"detections": restored, "_unresolved_raw": unresolved_raw}
def register():
    """Register the escalation stages: ``escalate_vlm``, then ``escalate_cloud``.

    Both stages use the same serialize/deserialize pair.
    """
    # Local VLM: reads unresolved_candidates, writes detections and the
    # remaining unresolved_candidates.
    register_stage(
        StageDefinition(
            name="escalate_vlm",
            label="Local VLM",
            description="Process unresolved crops with moondream2",
            category="escalation",
            io=StageIO(
                reads=["unresolved_candidates"],
                writes=["detections", "unresolved_candidates"],
                optional_reads=["source_asset_id"],
            ),
            config_fields=[
                StageConfigField(
                    name="min_confidence",
                    type="float",
                    default=0.5,
                    description="Min VLM confidence",
                    min=0.0,
                    max=1.0,
                ),
            ],
            serialize_fn=_ser_escalation,
            deserialize_fn=_deser_escalation,
        )
    )

    # Cloud LLM: reads unresolved_candidates, declares only detections as output.
    register_stage(
        StageDefinition(
            name="escalate_cloud",
            label="Cloud LLM",
            description="Escalate remaining crops to cloud provider",
            category="escalation",
            io=StageIO(
                reads=["unresolved_candidates"],
                writes=["detections"],
                optional_reads=["source_asset_id"],
            ),
            config_fields=[
                StageConfigField(
                    name="min_confidence",
                    type="float",
                    default=0.4,
                    description="Min cloud confidence",
                    min=0.0,
                    max=1.0,
                ),
            ],
            serialize_fn=_ser_escalation,
            deserialize_fn=_deser_escalation,
        )
    )

View File

@@ -0,0 +1,32 @@
"""Registration for output stages: report compilation."""
from detect.stages.base import StageDefinition, StageIO, register_stage
from ._serializers import serialize_dataclass, deserialize_detection_report
def _ser_report(state: dict, job_id: str) -> dict:
    """Serialize ``report`` from *state* for checkpointing.

    An absent or ``None`` report is stored as ``None``; anything else goes
    through ``serialize_dataclass``. *job_id* is not used here.
    """
    current = state.get("report")
    payload = None if current is None else serialize_dataclass(current)
    return {"report": payload}
def _deser_report(data: dict, job_id: str) -> dict:
    """Rebuild ``report`` from checkpointed *data*.

    An absent or ``None`` payload restores ``None``; anything else goes
    through ``deserialize_detection_report``. *job_id* is not used here.
    """
    stored = data.get("report")
    restored = None if stored is None else deserialize_detection_report(stored)
    return {"report": restored}
def register():
    """Register the output stage ``compile_report``."""
    # Report compilation: reads detections, writes report; no config fields.
    register_stage(
        StageDefinition(
            name="compile_report",
            label="Report",
            description="Merge detections and compile final report",
            category="output",
            io=StageIO(reads=["detections"], writes=["report"]),
            config_fields=[],
            serialize_fn=_ser_report,
            deserialize_fn=_deser_report,
        )
    )

View File

@@ -0,0 +1,57 @@
"""Registration for preprocessing stages: frame extraction, scene filter."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import serialize_frames, deserialize_frames
def _ser_extract(state: dict, job_id: str) -> dict:
    """Serialize ``frames`` from *state* for checkpointing.

    Passes the frames (an empty list when missing) and *job_id* to
    ``serialize_frames``, and stores the two values it returns as
    ``frames_meta`` and ``frames_manifest``.
    """
    frames_meta, frames_manifest = serialize_frames(state.get("frames", []), job_id)
    return {"frames_meta": frames_meta, "frames_manifest": frames_manifest}
def _deser_extract(data: dict, job_id: str) -> dict:
    """Rebuild ``frames`` from checkpointed *data*.

    Hands the stored ``frames_meta`` and ``frames_manifest`` plus *job_id*
    to ``deserialize_frames``.

    Raises:
        KeyError: if either stored entry is missing from *data*.
    """
    return {
        "frames": deserialize_frames(
            data["frames_meta"],
            data["frames_manifest"],
            job_id,
        )
    }
def _ser_filter(state: dict, job_id: str) -> dict:
    """Serialize ``filtered_frames`` as the list of their ``sequence`` values.

    Only the sequence values are stored, not the frames themselves. A
    missing entry yields an empty list. *job_id* is not used here.
    """
    return {
        "filtered_frame_sequences": [
            frame.sequence for frame in state.get("filtered_frames", [])
        ]
    }
def _deser_filter(data: dict, job_id: str) -> dict:
    """Return the stored frame sequences under ``_filtered_sequences``.

    The list is passed through unchanged; no frames are rebuilt here.
    Presumably a later step maps the sequences back to frames (TODO
    confirm against the checkpoint loader). *job_id* is not used here.

    Raises:
        KeyError: if *data* has no ``filtered_frame_sequences`` entry.
    """
    sequences = data["filtered_frame_sequences"]
    return {"_filtered_sequences": sequences}
def register():
    """Register the preprocessing stages: ``extract_frames``, then ``filter_scenes``."""
    # Frame extraction: reads video_path, writes frames.
    register_stage(
        StageDefinition(
            name="extract_frames",
            label="Frame Extraction",
            description="Extract frames from video at configurable FPS",
            category="preprocessing",
            io=StageIO(reads=["video_path"], writes=["frames"]),
            config_fields=[
                StageConfigField(
                    name="fps",
                    type="float",
                    default=2.0,
                    description="Frames per second",
                    min=0.1,
                    max=30.0,
                ),
                StageConfigField(
                    name="max_frames",
                    type="int",
                    default=500,
                    description="Maximum frames to extract",
                    min=1,
                    max=10000,
                ),
            ],
            serialize_fn=_ser_extract,
            deserialize_fn=_deser_extract,
        )
    )

    # Scene filter: reads frames, writes filtered_frames.
    register_stage(
        StageDefinition(
            name="filter_scenes",
            label="Scene Filter",
            description="Deduplicate similar frames using perceptual hashing",
            category="preprocessing",
            io=StageIO(reads=["frames"], writes=["filtered_frames"]),
            config_fields=[
                StageConfigField(
                    name="hamming_threshold",
                    type="int",
                    default=8,
                    description="Hamming distance threshold",
                    min=0,
                    max=64,
                ),
                StageConfigField(
                    name="enabled",
                    type="bool",
                    default=True,
                    description="Enable scene filtering",
                ),
            ],
            serialize_fn=_ser_filter,
            deserialize_fn=_deser_filter,
        )
    )

View File

@@ -0,0 +1,45 @@
"""Registration for resolution stages: brand resolver."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_brand_detection,
)
def _ser_brands(state: dict, job_id: str) -> dict:
    """Serialize the brand resolver outputs from *state*.

    ``detections`` goes through ``serialize_dataclass_list`` and
    ``unresolved_candidates`` through ``serialize_text_candidates``.
    Missing entries are treated as empty lists. *job_id* is not used here.
    """
    return {
        "detections": serialize_dataclass_list(state.get("detections", [])),
        "unresolved_candidates": serialize_text_candidates(
            state.get("unresolved_candidates", [])
        ),
    }
def _deser_brands(data: dict, job_id: str) -> dict:
    """Rebuild the brand resolver outputs from checkpointed *data*.

    Stored detections go through ``deserialize_brand_detection``. The
    unresolved candidates are passed through unchanged as
    ``_unresolved_raw``. Missing entries default to empty lists.
    *job_id* is not used here.
    """
    stored_detections = data.get("detections", [])
    return {
        "detections": [deserialize_brand_detection(item) for item in stored_detections],
        "_unresolved_raw": data.get("unresolved_candidates", []),
    }
def register():
    """Register the resolution stage ``match_brands``."""
    # Brand resolver: reads text_candidates, writes detections and
    # unresolved_candidates.
    register_stage(
        StageDefinition(
            name="match_brands",
            label="Brand Resolver",
            description="Match OCR text against known brands (session + global DB)",
            category="resolution",
            io=StageIO(
                reads=["text_candidates"],
                writes=["detections", "unresolved_candidates"],
                optional_reads=["session_brands", "source_asset_id"],
            ),
            config_fields=[
                StageConfigField(
                    name="fuzzy_threshold",
                    type="int",
                    default=75,
                    description="Fuzzy match threshold",
                    min=0,
                    max=100,
                ),
            ],
            serialize_fn=_ser_brands,
            deserialize_fn=_deser_brands,
        )
    )