refactor stage 1

This commit is contained in:
2026-03-27 04:23:21 -03:00
parent df6bcb01e8
commit 291ac8dd40
14 changed files with 688 additions and 450 deletions

View File

@@ -1,101 +1,131 @@
"""
Stage protocol — common interface for all pipeline stages.
Stage base class — common interface for all pipeline stages.
Every stage declares:
- IO: what it reads/writes from DetectState
- Config: tunable parameters for the editor
- Serialization: how to persist/restore its own outputs
Each stage is a file that subclasses Stage. Auto-discovered via
__init_subclass__. No manual registration needed.
The checkpoint layer is a black box — it asks each stage to serialize its
outputs and stores the result. Stages own their data format. Binary data
(frames, crops) goes to S3 via the stage itself. The checkpoint just
stores the JSON envelope.
A stage:
- Has a StageDefinition (from schema) with name, config, IO
- Implements run(frames, config) → output
- Owns its output serialization (opaque blob)
- Optionally has a TypeScript port for browser-side execution
The graph builder uses StageIO to validate that a stage's inputs are
satisfied by previous stages' outputs.
The checkpoint layer stores stage output as blobs without knowing
the format. The stage that wrote it is the only one that can read it.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable
from typing import Any
import numpy as np
@dataclass
class StageIO:
"""Declares what a stage reads and writes from/to DetectState."""
reads: list[str]
writes: list[str]
optional_reads: list[str] = field(default_factory=list)
@dataclass
class StageConfigField:
"""A single tunable config parameter for the editor UI."""
name: str
type: str # "float", "int", "str", "bool", "list[str]"
default: Any
description: str = ""
min: float | None = None
max: float | None = None
options: list[str] | None = None
@dataclass
class StageDefinition:
"""
Complete metadata for a pipeline stage.
The profile editor uses this to build the palette, generate config
forms, and validate graph connections. The checkpoint uses serialize_fn
and deserialize_fn to persist stage outputs without knowing the internals.
"""
name: str
label: str
description: str
io: StageIO
config_fields: list[StageConfigField] = field(default_factory=list)
category: str = "detection"
# The actual graph node function: (DetectState) → dict
fn: Callable | None = None
# Stage-owned serialization for checkpointing.
# serialize_fn: (state: dict, job_id: str) → json-compatible dict
# Stage picks its writes from state, serializes them.
# Binary data (frames) → S3 via stage, returns refs.
# deserialize_fn: (data: dict, job_id: str) → state update dict
# Stage restores its writes from the persisted data.
serialize_fn: Callable | None = None
deserialize_fn: Callable | None = None
from core.schema.models.stages import StageConfigField, StageIO, StageDefinition
# ---------------------------------------------------------------------------
# Registry
# Registry — auto-populated by __init_subclass__ (new stages)
# + register_stage() (legacy stages during migration)
# ---------------------------------------------------------------------------
_REGISTRY: dict[str, StageDefinition] = {}
_REGISTRY: dict[str, type['Stage']] = {}
_LEGACY_REGISTRY: dict[str, StageDefinition] = {}
def register_stage(definition: StageDefinition):
_REGISTRY[definition.name] = definition
"""Legacy registration for stages not yet converted to Stage subclass."""
_LEGACY_REGISTRY[definition.name] = definition
class Stage:
"""
Base class for all pipeline stages.
Subclass this in detect/stages/<name>.py. Define `definition` as a
class attribute. Implement `run()`. Optionally override `serialize()`
and `deserialize()` for custom blob formats (default is JSON).
"""
definition: StageDefinition # set by each subclass
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if hasattr(cls, 'definition') and cls.definition is not None:
_REGISTRY[cls.definition.name] = cls
def run(self, frames: list, config: dict) -> Any:
"""
Run the stage on a list of frames with the given config.
Config is a dict of parameter values (from slider UI or profile).
Returns the stage output — whatever shape this stage produces.
Debug overlays are included when config has debug=True.
"""
raise NotImplementedError
def serialize(self, output: Any) -> bytes:
"""Serialize stage output to bytes for checkpoint storage."""
import json
return json.dumps(output, default=str).encode()
def deserialize(self, data: bytes) -> Any:
"""Deserialize stage output from checkpoint blob."""
import json
return json.loads(data)
# ---------------------------------------------------------------------------
# Discovery API
# ---------------------------------------------------------------------------
def _all_definitions() -> dict[str, StageDefinition]:
"""Merge new Stage subclass registry + legacy registry."""
merged = {}
# Legacy first, new overwrites (new takes precedence)
for name, defn in _LEGACY_REGISTRY.items():
merged[name] = defn
for name, cls in _REGISTRY.items():
merged[name] = cls.definition
return merged
def get_stage(name: str) -> StageDefinition:
if name not in _REGISTRY:
raise KeyError(f"Unknown stage: {name!r}. Registered: {list(_REGISTRY)}")
return _REGISTRY[name]
"""Get a stage definition by name (works for both new and legacy)."""
all_defs = _all_definitions()
if name not in all_defs:
raise KeyError(f"Unknown stage: {name!r}. Registered: {list(all_defs)}")
return all_defs[name]
def get_stage_class(name: str) -> type[Stage] | None:
"""Get a Stage subclass by name. Returns None for legacy stages."""
return _REGISTRY.get(name)
def get_stage_instance(name: str) -> Stage:
"""Get an instantiated Stage by name. Only works for new-style stages."""
cls = _REGISTRY.get(name)
if cls is None:
raise KeyError(f"No Stage subclass for {name!r}. Legacy stages don't have instances.")
return cls()
def list_stages() -> list[StageDefinition]:
"""List all registered stage definitions (new + legacy)."""
return list(_all_definitions().values())
def list_stage_classes() -> list[type[Stage]]:
"""List all registered Stage subclasses (new-style only)."""
return list(_REGISTRY.values())
def get_palette() -> dict[str, list[StageDefinition]]:
"""Group stages by category for the editor palette."""
palette: dict[str, list[StageDefinition]] = {}
for stage in _REGISTRY.values():
if stage.category not in palette:
palette[stage.category] = []
palette[stage.category].append(stage)
for defn in _all_definitions().values():
if defn.category not in palette:
palette[defn.category] = []
palette[defn.category].append(defn)
return palette