schema clean up and refactor
This commit is contained in:
28
detect/stages/registry/__init__.py
Normal file
28
detect/stages/registry/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Stage registry — registers all built-in stages.
|
||||
|
||||
Split by category:
|
||||
preprocessing.py — extract_frames, filter_scenes
|
||||
detection.py — detect_objects, run_ocr
|
||||
resolution.py — match_brands
|
||||
escalation.py — escalate_vlm, escalate_cloud
|
||||
output.py — compile_report
|
||||
_serializers.py — shared serialization helpers
|
||||
"""
|
||||
|
||||
from . import preprocessing
|
||||
from . import detection
|
||||
from . import resolution
|
||||
from . import escalation
|
||||
from . import output
|
||||
|
||||
|
||||
def register_all():
    """Invoke ``register()`` on every built-in stage category module."""
    for category_module in (preprocessing, detection, resolution, escalation, output):
        category_module.register()


# Registration happens as a side effect of importing this package.
register_all()
|
||||
25
detect/stages/registry/_serializers.py
Normal file
25
detect/stages/registry/_serializers.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Re-export serializers from core/schema/serializers/.
|
||||
|
||||
Stage registry modules import from here for convenience.
|
||||
All serialization logic lives in core/schema/serializers/.
|
||||
"""
|
||||
|
||||
from core.schema.serializers._common import (
|
||||
safe_construct,
|
||||
serialize_dataclass,
|
||||
serialize_dataclass_list,
|
||||
)
|
||||
from core.schema.serializers.detect_pipeline import (
|
||||
serialize_frame_meta,
|
||||
serialize_frames_with_upload as serialize_frames,
|
||||
deserialize_frames_with_download as deserialize_frames,
|
||||
serialize_text_candidate,
|
||||
serialize_text_candidates,
|
||||
deserialize_text_candidate,
|
||||
deserialize_text_candidates,
|
||||
deserialize_bounding_box,
|
||||
deserialize_brand_detection,
|
||||
deserialize_pipeline_stats,
|
||||
deserialize_detection_report,
|
||||
)
|
||||
63
detect/stages/registry/detection.py
Normal file
63
detect/stages/registry/detection.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Registration for detection stages: YOLO, OCR."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||
from ._serializers import (
|
||||
serialize_dataclass_list,
|
||||
serialize_text_candidates,
|
||||
deserialize_bounding_box,
|
||||
)
|
||||
|
||||
|
||||
def _ser_detect(state: dict, job_id: str) -> dict:
|
||||
boxes = state.get("boxes_by_frame", {})
|
||||
serialized = {str(seq): serialize_dataclass_list(bl) for seq, bl in boxes.items()}
|
||||
return {"boxes_by_frame": serialized}
|
||||
|
||||
|
||||
def _deser_detect(data: dict, job_id: str) -> dict:
|
||||
boxes = {}
|
||||
for seq_str, box_dicts in data.get("boxes_by_frame", {}).items():
|
||||
boxes[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
|
||||
return {"boxes_by_frame": boxes}
|
||||
|
||||
|
||||
def _ser_ocr(state: dict, job_id: str) -> dict:
    """Serialize OCR text candidates for checkpointing."""
    return {"text_candidates": serialize_text_candidates(state.get("text_candidates", []))}
|
||||
|
||||
|
||||
def _deser_ocr(data: dict, job_id: str) -> dict:
|
||||
return {"_text_candidates_raw": data["text_candidates"]}
|
||||
|
||||
|
||||
def register():
    """Register the detection-category stages: YOLO detection and OCR."""
    # Stage 1: YOLO inference over the scene-filtered frame set.
    register_stage(
        StageDefinition(
            name="detect_objects",
            label="Object Detection",
            description="YOLO object detection on filtered frames",
            category="detection",
            io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
            config_fields=[
                StageConfigField("model_name", "str", "yolov8n.pt", "YOLO model file"),
                StageConfigField("confidence_threshold", "float", 0.3, "Min detection confidence", min=0.0, max=1.0),
                StageConfigField("target_classes", "list[str]", [], "YOLO classes to detect (empty = all)"),
            ],
            serialize_fn=_ser_detect,
            deserialize_fn=_deser_detect,
        )
    )

    # Stage 2: OCR over the regions found by object detection.
    register_stage(
        StageDefinition(
            name="run_ocr",
            label="OCR",
            description="Extract text from detected regions",
            category="detection",
            io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
            config_fields=[
                StageConfigField("languages", "list[str]", ["en"], "OCR languages"),
                StageConfigField("min_confidence", "float", 0.5, "Min OCR confidence", min=0.0, max=1.0),
            ],
            serialize_fn=_ser_ocr,
            deserialize_fn=_deser_ocr,
        )
    )
|
||||
63
detect/stages/registry/escalation.py
Normal file
63
detect/stages/registry/escalation.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Registration for escalation stages: local VLM, cloud LLM."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||
from ._serializers import (
|
||||
serialize_dataclass_list,
|
||||
serialize_text_candidates,
|
||||
deserialize_brand_detection,
|
||||
)
|
||||
|
||||
|
||||
def _ser_escalation(state: dict, job_id: str) -> dict:
    """Serialize resolved detections plus candidates still awaiting resolution."""
    return {
        "detections": serialize_dataclass_list(state.get("detections", [])),
        "unresolved_candidates": serialize_text_candidates(state.get("unresolved_candidates", [])),
    }
|
||||
|
||||
|
||||
def _deser_escalation(data: dict, job_id: str) -> dict:
|
||||
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
|
||||
return {
|
||||
"detections": detections,
|
||||
"_unresolved_raw": data.get("unresolved_candidates", []),
|
||||
}
|
||||
|
||||
|
||||
def register():
    """Register the escalation-category stages: local VLM, then cloud LLM."""
    # Local VLM pass: cheaper, runs first on everything unresolved.
    register_stage(
        StageDefinition(
            name="escalate_vlm",
            label="Local VLM",
            description="Process unresolved crops with moondream2",
            category="escalation",
            io=StageIO(
                reads=["unresolved_candidates"],
                writes=["detections", "unresolved_candidates"],
                optional_reads=["source_asset_id"],
            ),
            config_fields=[
                StageConfigField("min_confidence", "float", 0.5, "Min VLM confidence", min=0.0, max=1.0),
            ],
            serialize_fn=_ser_escalation,
            deserialize_fn=_deser_escalation,
        )
    )

    # Cloud pass: final resort for whatever the local VLM left unresolved.
    register_stage(
        StageDefinition(
            name="escalate_cloud",
            label="Cloud LLM",
            description="Escalate remaining crops to cloud provider",
            category="escalation",
            io=StageIO(
                reads=["unresolved_candidates"],
                writes=["detections"],
                optional_reads=["source_asset_id"],
            ),
            config_fields=[
                StageConfigField("min_confidence", "float", 0.4, "Min cloud confidence", min=0.0, max=1.0),
            ],
            serialize_fn=_ser_escalation,
            deserialize_fn=_deser_escalation,
        )
    )
|
||||
32
detect/stages/registry/output.py
Normal file
32
detect/stages/registry/output.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Registration for output stages: report compilation."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, register_stage
|
||||
from ._serializers import serialize_dataclass, deserialize_detection_report
|
||||
|
||||
|
||||
def _ser_report(state: dict, job_id: str) -> dict:
|
||||
report = state.get("report")
|
||||
if report is None:
|
||||
return {"report": None}
|
||||
return {"report": serialize_dataclass(report)}
|
||||
|
||||
|
||||
def _deser_report(data: dict, job_id: str) -> dict:
|
||||
raw = data.get("report")
|
||||
if raw is None:
|
||||
return {"report": None}
|
||||
return {"report": deserialize_detection_report(raw)}
|
||||
|
||||
|
||||
def register():
    """Register the output-category stage that assembles the final report."""
    register_stage(
        StageDefinition(
            name="compile_report",
            label="Report",
            description="Merge detections and compile final report",
            category="output",
            io=StageIO(reads=["detections"], writes=["report"]),
            # No tunables: report compilation is fully determined by its inputs.
            config_fields=[],
            serialize_fn=_ser_report,
            deserialize_fn=_deser_report,
        )
    )
|
||||
57
detect/stages/registry/preprocessing.py
Normal file
57
detect/stages/registry/preprocessing.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Registration for preprocessing stages: frame extraction, scene filter."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||
from ._serializers import serialize_frames, deserialize_frames
|
||||
|
||||
|
||||
def _ser_extract(state: dict, job_id: str) -> dict:
    """Serialize extracted frames into metadata plus an upload manifest."""
    meta, manifest = serialize_frames(state.get("frames", []), job_id)
    return {"frames_meta": meta, "frames_manifest": manifest}
|
||||
|
||||
|
||||
def _deser_extract(data: dict, job_id: str) -> dict:
    """Rebuild frame objects by downloading them per the stored manifest."""
    return {
        "frames": deserialize_frames(data["frames_meta"], data["frames_manifest"], job_id)
    }
|
||||
|
||||
|
||||
def _ser_filter(state: dict, job_id: str) -> dict:
|
||||
filtered = state.get("filtered_frames", [])
|
||||
seqs = [f.sequence for f in filtered]
|
||||
return {"filtered_frame_sequences": seqs}
|
||||
|
||||
|
||||
def _deser_filter(data: dict, job_id: str) -> dict:
|
||||
return {"_filtered_sequences": data["filtered_frame_sequences"]}
|
||||
|
||||
|
||||
def register():
    """Register preprocessing stages: frame extraction and scene filtering."""
    # Stage 1: sample frames from the source video.
    register_stage(
        StageDefinition(
            name="extract_frames",
            label="Frame Extraction",
            description="Extract frames from video at configurable FPS",
            category="preprocessing",
            io=StageIO(reads=["video_path"], writes=["frames"]),
            config_fields=[
                StageConfigField("fps", "float", 2.0, "Frames per second", min=0.1, max=30.0),
                StageConfigField("max_frames", "int", 500, "Maximum frames to extract", min=1, max=10000),
            ],
            serialize_fn=_ser_extract,
            deserialize_fn=_deser_extract,
        )
    )

    # Stage 2: drop near-duplicate frames before the expensive detection stages.
    register_stage(
        StageDefinition(
            name="filter_scenes",
            label="Scene Filter",
            description="Deduplicate similar frames using perceptual hashing",
            category="preprocessing",
            io=StageIO(reads=["frames"], writes=["filtered_frames"]),
            config_fields=[
                StageConfigField("hamming_threshold", "int", 8, "Hamming distance threshold", min=0, max=64),
                StageConfigField("enabled", "bool", True, "Enable scene filtering"),
            ],
            serialize_fn=_ser_filter,
            deserialize_fn=_deser_filter,
        )
    )
|
||||
45
detect/stages/registry/resolution.py
Normal file
45
detect/stages/registry/resolution.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Registration for resolution stages: brand resolver."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||
from ._serializers import (
|
||||
serialize_dataclass_list,
|
||||
serialize_text_candidates,
|
||||
deserialize_brand_detection,
|
||||
)
|
||||
|
||||
|
||||
def _ser_brands(state: dict, job_id: str) -> dict:
    """Serialize matched brand detections and the candidates left unmatched."""
    resolved = state.get("detections", [])
    pending = state.get("unresolved_candidates", [])
    return {
        "detections": serialize_dataclass_list(resolved),
        "unresolved_candidates": serialize_text_candidates(pending),
    }
|
||||
|
||||
|
||||
def _deser_brands(data: dict, job_id: str) -> dict:
|
||||
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
|
||||
return {
|
||||
"detections": detections,
|
||||
"_unresolved_raw": data.get("unresolved_candidates", []),
|
||||
}
|
||||
|
||||
|
||||
def register():
    """Register the resolution-category stage that matches OCR text to brands."""
    register_stage(
        StageDefinition(
            name="match_brands",
            label="Brand Resolver",
            description="Match OCR text against known brands (session + global DB)",
            category="resolution",
            io=StageIO(
                reads=["text_candidates"],
                writes=["detections", "unresolved_candidates"],
                optional_reads=["session_brands", "source_asset_id"],
            ),
            config_fields=[
                StageConfigField("fuzzy_threshold", "int", 75, "Fuzzy match threshold", min=0, max=100),
            ],
            serialize_fn=_ser_brands,
            deserialize_fn=_deser_brands,
        )
    )
|
||||
Reference in New Issue
Block a user