schema clean up and refactor
This commit is contained in:
101
detect/stages/base.py
Normal file
101
detect/stages/base.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
Stage protocol — common interface for all pipeline stages.

Every stage declares:
- IO: what it reads/writes from DetectState
- Config: tunable parameters for the editor
- Serialization: how to persist/restore its own outputs

The checkpoint layer is a black box — it asks each stage to serialize its
outputs and stores the result. Stages own their data format. Binary data
(frames, crops) goes to S3 via the stage itself. The checkpoint just
stores the JSON envelope.

The graph builder uses StageIO to validate that a stage's inputs are
satisfied by previous stages' outputs.
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
@dataclass
class StageIO:
    """Declares what a stage reads and writes from/to DetectState.

    All three attributes are lists of plain string names. ``reads`` and
    ``writes`` must be supplied; ``optional_reads`` defaults to a fresh empty
    list for every instance.
    """

    # Names the stage reads from DetectState.
    reads: list[str]
    # Names the stage writes to DetectState.
    writes: list[str]
    # Additional names the stage may read.
    # NOTE(review): presumably these are not required to be satisfied by
    # earlier stages — confirm against the graph builder.
    optional_reads: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class StageConfigField:
    """A single tunable config parameter for the editor UI.

    Only ``name``, ``type`` and ``default`` are required; the remaining
    attributes default to an empty description and ``None``.
    """

    # Parameter name.
    name: str
    # Type tag, given as a string.
    type: str  # "float", "int", "str", "bool", "list[str]"
    # Default value of the parameter; any Python value is accepted.
    default: Any
    # Free-form description; empty when not provided.
    description: str = ""
    # Bounds and choices; None when not provided.
    # NOTE(review): presumably min/max apply to numeric types and options to
    # enumerated choices — nothing in this class enforces that.
    min: float | None = None
    max: float | None = None
    options: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass
class StageDefinition:
    """
    Complete metadata for a pipeline stage.

    The profile editor uses this to build the palette, generate config
    forms, and validate graph connections. The checkpoint uses serialize_fn
    and deserialize_fn to persist stage outputs without knowing the internals.
    """

    # Stage name; the module registry uses it as the lookup key.
    name: str
    # Display label and description of the stage.
    label: str
    description: str
    # Declared reads/writes of the stage.
    io: StageIO
    # Tunable parameters; a fresh empty list per instance when omitted.
    config_fields: list[StageConfigField] = field(default_factory=list)
    # Palette grouping key.
    category: str = "detection"

    # The actual graph node function: (DetectState) → dict
    fn: Callable | None = None

    # Stage-owned serialization for checkpointing.
    # serialize_fn: (state: dict, job_id: str) → json-compatible dict
    #   Stage picks its writes from state, serializes them.
    #   Binary data (frames) → S3 via stage, returns refs.
    # deserialize_fn: (data: dict, job_id: str) → state update dict
    #   Stage restores its writes from the persisted data.
    serialize_fn: Callable | None = None
    deserialize_fn: Callable | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------

# Module-level table mapping stage name -> StageDefinition. It starts empty
# and is filled through register_stage().
_REGISTRY: dict[str, StageDefinition] = {}
|
||||
|
||||
|
||||
def register_stage(definition: StageDefinition) -> None:
    """Store *definition* in the module registry under ``definition.name``.

    Registering a second definition with a name that is already present
    replaces the earlier entry.
    """
    _REGISTRY[definition.name] = definition
|
||||
|
||||
|
||||
def get_stage(name: str) -> StageDefinition:
    """Return the stage definition registered under *name*.

    Raises:
        KeyError: If *name* is not registered. The message includes the
            requested name and the list of registered names.
    """
    try:
        return _REGISTRY[name]
    except KeyError:
        # Re-raise with a descriptive message; ``from None`` hides the bare
        # lookup failure so the traceback shows only the useful error.
        raise KeyError(
            f"Unknown stage: {name!r}. Registered: {list(_REGISTRY)}"
        ) from None
|
||||
|
||||
|
||||
def list_stages() -> list[StageDefinition]:
    """Return all registered stage definitions as a new list.

    The list follows the registry dict's insertion order and is a copy, so
    mutating it does not affect the registry.
    """
    return list(_REGISTRY.values())
|
||||
|
||||
|
||||
def get_palette() -> dict[str, list[StageDefinition]]:
    """Group stages by category for the editor palette.

    Returns:
        A plain dict mapping each category string to the list of registered
        stages in that category. Categories appear in the order they are
        first met while walking the registry, and stages keep registry order
        within each category.
    """
    palette: dict[str, list[StageDefinition]] = {}
    for stage in _REGISTRY.values():
        # setdefault creates the category list on first sight, then appends.
        palette.setdefault(stage.category, []).append(stage)
    return palette
|
||||
Reference in New Issue
Block a user