schema clean up and refactor

This commit is contained in:
2026-03-26 05:14:33 -03:00
parent 08c58a6a9d
commit d58a90157a
17 changed files with 930 additions and 287 deletions

View File

@@ -0,0 +1,28 @@
"""
Stage registry — registers all built-in stages.
Split by category:
preprocessing.py — extract_frames, filter_scenes
detection.py — detect_objects, run_ocr
resolution.py — match_brands
escalation.py — escalate_vlm, escalate_cloud
output.py — compile_report
_serializers.py — shared serialization helpers
"""
from . import preprocessing
from . import detection
from . import resolution
from . import escalation
from . import output
def register_all():
preprocessing.register()
detection.register()
resolution.register()
escalation.register()
output.register()
register_all()

View File

@@ -0,0 +1,25 @@
"""
Re-export serializers from core/schema/serializers/.
Stage registry modules import from here for convenience.
All serialization logic lives in core/schema/serializers/.
"""
from core.schema.serializers._common import (
safe_construct,
serialize_dataclass,
serialize_dataclass_list,
)
from core.schema.serializers.detect_pipeline import (
serialize_frame_meta,
serialize_frames_with_upload as serialize_frames,
deserialize_frames_with_download as deserialize_frames,
serialize_text_candidate,
serialize_text_candidates,
deserialize_text_candidate,
deserialize_text_candidates,
deserialize_bounding_box,
deserialize_brand_detection,
deserialize_pipeline_stats,
deserialize_detection_report,
)

View File

@@ -0,0 +1,63 @@
"""Registration for detection stages: YOLO, OCR."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_bounding_box,
)
def _ser_detect(state: dict, job_id: str) -> dict:
boxes = state.get("boxes_by_frame", {})
serialized = {str(seq): serialize_dataclass_list(bl) for seq, bl in boxes.items()}
return {"boxes_by_frame": serialized}
def _deser_detect(data: dict, job_id: str) -> dict:
boxes = {}
for seq_str, box_dicts in data.get("boxes_by_frame", {}).items():
boxes[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
return {"boxes_by_frame": boxes}
def _ser_ocr(state: dict, job_id: str) -> dict:
candidates = state.get("text_candidates", [])
return {"text_candidates": serialize_text_candidates(candidates)}
def _deser_ocr(data: dict, job_id: str) -> dict:
return {"_text_candidates_raw": data["text_candidates"]}
def register():
yolo = StageDefinition(
name="detect_objects",
label="Object Detection",
description="YOLO object detection on filtered frames",
category="detection",
io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
config_fields=[
StageConfigField("model_name", "str", "yolov8n.pt", "YOLO model file"),
StageConfigField("confidence_threshold", "float", 0.3, "Min detection confidence", min=0.0, max=1.0),
StageConfigField("target_classes", "list[str]", [], "YOLO classes to detect (empty = all)"),
],
serialize_fn=_ser_detect,
deserialize_fn=_deser_detect,
)
register_stage(yolo)
ocr = StageDefinition(
name="run_ocr",
label="OCR",
description="Extract text from detected regions",
category="detection",
io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
config_fields=[
StageConfigField("languages", "list[str]", ["en"], "OCR languages"),
StageConfigField("min_confidence", "float", 0.5, "Min OCR confidence", min=0.0, max=1.0),
],
serialize_fn=_ser_ocr,
deserialize_fn=_deser_ocr,
)
register_stage(ocr)

View File

@@ -0,0 +1,63 @@
"""Registration for escalation stages: local VLM, cloud LLM."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_brand_detection,
)
def _ser_escalation(state: dict, job_id: str) -> dict:
matched = state.get("detections", [])
unresolved = state.get("unresolved_candidates", [])
return {
"detections": serialize_dataclass_list(matched),
"unresolved_candidates": serialize_text_candidates(unresolved),
}
def _deser_escalation(data: dict, job_id: str) -> dict:
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
return {
"detections": detections,
"_unresolved_raw": data.get("unresolved_candidates", []),
}
def register():
vlm = StageDefinition(
name="escalate_vlm",
label="Local VLM",
description="Process unresolved crops with moondream2",
category="escalation",
io=StageIO(
reads=["unresolved_candidates"],
writes=["detections", "unresolved_candidates"],
optional_reads=["source_asset_id"],
),
config_fields=[
StageConfigField("min_confidence", "float", 0.5, "Min VLM confidence", min=0.0, max=1.0),
],
serialize_fn=_ser_escalation,
deserialize_fn=_deser_escalation,
)
register_stage(vlm)
cloud = StageDefinition(
name="escalate_cloud",
label="Cloud LLM",
description="Escalate remaining crops to cloud provider",
category="escalation",
io=StageIO(
reads=["unresolved_candidates"],
writes=["detections"],
optional_reads=["source_asset_id"],
),
config_fields=[
StageConfigField("min_confidence", "float", 0.4, "Min cloud confidence", min=0.0, max=1.0),
],
serialize_fn=_ser_escalation,
deserialize_fn=_deser_escalation,
)
register_stage(cloud)

View File

@@ -0,0 +1,32 @@
"""Registration for output stages: report compilation."""
from detect.stages.base import StageDefinition, StageIO, register_stage
from ._serializers import serialize_dataclass, deserialize_detection_report
def _ser_report(state: dict, job_id: str) -> dict:
report = state.get("report")
if report is None:
return {"report": None}
return {"report": serialize_dataclass(report)}
def _deser_report(data: dict, job_id: str) -> dict:
raw = data.get("report")
if raw is None:
return {"report": None}
return {"report": deserialize_detection_report(raw)}
def register():
report = StageDefinition(
name="compile_report",
label="Report",
description="Merge detections and compile final report",
category="output",
io=StageIO(reads=["detections"], writes=["report"]),
config_fields=[],
serialize_fn=_ser_report,
deserialize_fn=_deser_report,
)
register_stage(report)

View File

@@ -0,0 +1,57 @@
"""Registration for preprocessing stages: frame extraction, scene filter."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import serialize_frames, deserialize_frames
def _ser_extract(state: dict, job_id: str) -> dict:
frames = state.get("frames", [])
meta, manifest = serialize_frames(frames, job_id)
return {"frames_meta": meta, "frames_manifest": manifest}
def _deser_extract(data: dict, job_id: str) -> dict:
frames = deserialize_frames(data["frames_meta"], data["frames_manifest"], job_id)
return {"frames": frames}
def _ser_filter(state: dict, job_id: str) -> dict:
filtered = state.get("filtered_frames", [])
seqs = [f.sequence for f in filtered]
return {"filtered_frame_sequences": seqs}
def _deser_filter(data: dict, job_id: str) -> dict:
return {"_filtered_sequences": data["filtered_frame_sequences"]}
def register():
extract = StageDefinition(
name="extract_frames",
label="Frame Extraction",
description="Extract frames from video at configurable FPS",
category="preprocessing",
io=StageIO(reads=["video_path"], writes=["frames"]),
config_fields=[
StageConfigField("fps", "float", 2.0, "Frames per second", min=0.1, max=30.0),
StageConfigField("max_frames", "int", 500, "Maximum frames to extract", min=1, max=10000),
],
serialize_fn=_ser_extract,
deserialize_fn=_deser_extract,
)
register_stage(extract)
scene_filter = StageDefinition(
name="filter_scenes",
label="Scene Filter",
description="Deduplicate similar frames using perceptual hashing",
category="preprocessing",
io=StageIO(reads=["frames"], writes=["filtered_frames"]),
config_fields=[
StageConfigField("hamming_threshold", "int", 8, "Hamming distance threshold", min=0, max=64),
StageConfigField("enabled", "bool", True, "Enable scene filtering"),
],
serialize_fn=_ser_filter,
deserialize_fn=_deser_filter,
)
register_stage(scene_filter)

View File

@@ -0,0 +1,45 @@
"""Registration for resolution stages: brand resolver."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_brand_detection,
)
def _ser_brands(state: dict, job_id: str) -> dict:
matched = state.get("detections", [])
unresolved = state.get("unresolved_candidates", [])
return {
"detections": serialize_dataclass_list(matched),
"unresolved_candidates": serialize_text_candidates(unresolved),
}
def _deser_brands(data: dict, job_id: str) -> dict:
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
return {
"detections": detections,
"_unresolved_raw": data.get("unresolved_candidates", []),
}
def register():
resolver = StageDefinition(
name="match_brands",
label="Brand Resolver",
description="Match OCR text against known brands (session + global DB)",
category="resolution",
io=StageIO(
reads=["text_candidates"],
writes=["detections", "unresolved_candidates"],
optional_reads=["session_brands", "source_asset_id"],
),
config_fields=[
StageConfigField("fuzzy_threshold", "int", 75, "Fuzzy match threshold", min=0, max=100),
],
serialize_fn=_ser_brands,
deserialize_fn=_deser_brands,
)
register_stage(resolver)