"""
Stage protocol — common interface for all pipeline stages.

Every stage declares:
  - IO: what it reads/writes from DetectState
  - Config: tunable parameters for the editor
  - Serialization: how to persist/restore its own outputs

The checkpoint layer is a black box — it asks each stage to serialize
its outputs and stores the result. Stages own their data format.
Binary data (frames, crops) goes to S3 via the stage itself; the
checkpoint just stores the JSON envelope.

The graph builder uses StageIO to validate that a stage's inputs are
satisfied by previous stages' outputs.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class StageIO:
    """Declares what a stage reads and writes from/to DetectState."""

    # State keys the stage consumes.
    reads: list[str]
    # State keys the stage produces.
    writes: list[str]
    # State keys the stage can use when present (presumably not required
    # for graph validation — confirm against the graph builder).
    optional_reads: list[str] = field(default_factory=list)


@dataclass
class StageConfigField:
    """A single tunable config parameter for the editor UI."""

    name: str
    # Value type tag: "float", "int", "str", "bool", "list[str]".
    type: str
    default: Any
    description: str = ""
    # Optional numeric bounds for the editor; None when not set.
    min: float | None = None
    max: float | None = None
    # Optional list of allowed choices; None when the field is free-form.
    options: list[str] | None = None


@dataclass
class StageDefinition:
    """
    Complete metadata for a pipeline stage.

    The profile editor uses this to build the palette, generate config
    forms, and validate graph connections. The checkpoint uses
    serialize_fn and deserialize_fn to persist stage outputs without
    knowing the internals.
    """

    name: str
    label: str
    description: str
    io: StageIO
    config_fields: list[StageConfigField] = field(default_factory=list)
    # Palette group key used by get_palette().
    category: str = "detection"

    # The actual graph node function: (DetectState) -> dict
    fn: Callable | None = None

    # Stage-owned serialization for checkpointing.
    # serialize_fn: (state: dict, job_id: str) -> json-compatible dict
    #   Stage picks its writes from state, serializes them.
    #   Binary data (frames) -> S3 via stage, returns refs.
    # deserialize_fn: (data: dict, job_id: str) -> state update dict
    #   Stage restores its writes from the persisted data.
    serialize_fn: Callable | None = None
    deserialize_fn: Callable | None = None


# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------

# Module-level registry, keyed by StageDefinition.name, in registration order.
_REGISTRY: dict[str, StageDefinition] = {}


def register_stage(definition: StageDefinition) -> None:
    """
    Add *definition* to the registry under ``definition.name``.

    Registering a name that already exists replaces the earlier definition.
    Because a dict update keeps the key's original slot, the replacement
    keeps the original position in list_stages()/get_palette() ordering.
    """
    _REGISTRY[definition.name] = definition


def get_stage(name: str) -> StageDefinition:
    """
    Return the registered stage called *name*.

    Raises:
        KeyError: if no stage with that name has been registered; the
            message lists all registered names.
    """
    try:
        return _REGISTRY[name]
    except KeyError:
        # Re-raise with a descriptive message; suppress the bare KeyError
        # context so the traceback shows only the useful error.
        raise KeyError(
            f"Unknown stage: {name!r}. Registered: {list(_REGISTRY)}"
        ) from None


def list_stages() -> list[StageDefinition]:
    """Return all registered stages in registration order (a new list)."""
    return list(_REGISTRY.values())


def get_palette() -> dict[str, list[StageDefinition]]:
    """
    Group stages by category for the editor palette.

    Categories appear in order of first registration; stages within a
    category keep registration order. Returns a fresh dict each call.
    """
    palette: dict[str, list[StageDefinition]] = {}
    for stage in _REGISTRY.values():
        palette.setdefault(stage.category, []).append(stage)
    return palette