102 lines
3.1 KiB
Python
102 lines
3.1 KiB
Python
"""
|
|
Stage protocol — common interface for all pipeline stages.
|
|
|
|
Every stage declares:
|
|
- IO: what it reads/writes from DetectState
|
|
- Config: tunable parameters for the editor
|
|
- Serialization: how to persist/restore its own outputs
|
|
|
|
The checkpoint layer is a black box — it asks each stage to serialize its
|
|
outputs and stores the result. Stages own their data format. Binary data
|
|
(frames, crops) goes to S3 via the stage itself. The checkpoint just
|
|
stores the JSON envelope.
|
|
|
|
The graph builder uses StageIO to validate that a stage's inputs are
|
|
satisfied by previous stages' outputs.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable
|
|
|
|
|
|
@dataclass
class StageIO:
    """Declare the DetectState keys a stage consumes and produces.

    The graph builder uses this to check that a stage's inputs are covered
    by the outputs of the stages before it.

    Attributes:
        reads: State keys the stage reads.
        writes: State keys the stage writes.
        optional_reads: Additional keys the stage reads. The name suggests
            these need not be provided upstream; confirm against the graph
            builder. Defaults to an empty list.
    """

    reads: list[str]
    writes: list[str]
    # default_factory gives each instance its own list instead of one shared list.
    optional_reads: list[str] = field(default_factory=list)
|
|
|
|
|
|
@dataclass
class StageConfigField:
    """Describe one tunable parameter for the editor UI.

    Attributes:
        name: Parameter identifier.
        type: Type tag as a string. The expected tags are "float", "int",
            "str", "bool" and "list[str]"; nothing here enforces that.
        default: Default value for the parameter.
        description: Help text; empty string by default.
        min: Optional lower bound; None unless given.
        max: Optional upper bound; None unless given.
        options: Optional list of string choices; None unless given.
    """

    # NOTE: `type`, `min` and `max` shadow builtins inside the class namespace,
    # but they are public field names, so renaming them would break callers.
    name: str
    type: str
    default: Any
    description: str = ""
    min: float | None = None
    max: float | None = None
    options: list[str] | None = None
|
|
|
|
|
|
@dataclass
class StageDefinition:
    """Full metadata describing one pipeline stage.

    The profile editor reads this to build its palette, generate config
    forms and validate graph connections. The checkpoint layer calls
    ``serialize_fn`` / ``deserialize_fn`` so it can persist a stage's outputs
    without knowing their internal format.

    Attributes:
        name: Stage identifier.
        label: Human-facing label (presumably shown in the editor).
        description: Human-readable summary of the stage.
        io: The DetectState keys the stage reads and writes.
        config_fields: Tunable parameters; empty list by default.
        category: Grouping category; "detection" by default.
        fn: The graph node function, ``(DetectState) -> dict``.
        serialize_fn: ``(state: dict, job_id: str) -> JSON-compatible dict``.
            Picks the stage's writes out of ``state`` and serializes them.
            Binary data (frames) is sent to S3 by the stage, which returns
            references in its place.
        deserialize_fn: ``(data: dict, job_id: str) -> state update dict``.
            Restores the stage's writes from the persisted data.
    """

    name: str
    label: str
    description: str
    io: StageIO
    config_fields: list[StageConfigField] = field(default_factory=list)
    category: str = "detection"

    # Graph node callable; None until one is attached.
    fn: Callable | None = None

    # Stage-owned checkpoint hooks (signatures documented above); both default to None.
    serialize_fn: Callable | None = None
    deserialize_fn: Callable | None = None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------

# Module-level stage registry keyed by StageDefinition.name. register_stage()
# fills it; get_stage(), list_stages() and get_palette() read it. Dicts keep
# insertion order, so listings follow the order in which names were first added.
_REGISTRY: dict[str, StageDefinition] = dict()
|
|
|
|
|
|
def register_stage(definition: StageDefinition) -> None:
    """Add *definition* to the module registry under its ``name``.

    Registering another definition with the same name silently replaces the
    earlier one; no error is raised.
    """
    key = definition.name
    _REGISTRY[key] = definition
|
|
|
|
|
|
def get_stage(name: str) -> StageDefinition:
    """Return the registered stage called *name*.

    Raises:
        KeyError: if no stage with that name is registered. The message lists
            the names that are registered.
    """
    try:
        return _REGISTRY[name]
    except KeyError:
        registered = list(_REGISTRY)
        # `from None` keeps the traceback focused on the friendlier error.
        raise KeyError(f"Unknown stage: {name!r}. Registered: {registered}") from None
|
|
|
|
|
|
def list_stages() -> list[StageDefinition]:
    """Return a new list of every registered stage, in registry insertion order."""
    return [*_REGISTRY.values()]
|
|
|
|
|
|
def get_palette() -> dict[str, list[StageDefinition]]:
    """Group the registered stages by ``category`` for the editor palette.

    Returns:
        A new dict mapping each category to its list of stages. Categories
        appear in the order their first stage sits in the registry, and
        stages within a category keep registry order.
    """
    grouped: dict[str, list[StageDefinition]] = {}
    for definition in _REGISTRY.values():
        # setdefault creates the category's bucket the first time it is seen.
        grouped.setdefault(definition.category, []).append(definition)
    return grouped
|