refactor stage 1
This commit is contained in:
@@ -1,101 +1,131 @@
|
||||
"""
|
||||
Stage protocol — common interface for all pipeline stages.
|
||||
Stage base class — common interface for all pipeline stages.
|
||||
|
||||
Every stage declares:
|
||||
- IO: what it reads/writes from DetectState
|
||||
- Config: tunable parameters for the editor
|
||||
- Serialization: how to persist/restore its own outputs
|
||||
Each stage is a file that subclasses Stage. Auto-discovered via
|
||||
__init_subclass__. No manual registration needed.
|
||||
|
||||
The checkpoint layer is a black box — it asks each stage to serialize its
|
||||
outputs and stores the result. Stages own their data format. Binary data
|
||||
(frames, crops) goes to S3 via the stage itself. The checkpoint just
|
||||
stores the JSON envelope.
|
||||
A stage:
|
||||
- Has a StageDefinition (from schema) with name, config, IO
|
||||
- Implements run(frames, config) → output
|
||||
- Owns its output serialization (opaque blob)
|
||||
- Optionally has a TypeScript port for browser-side execution
|
||||
|
||||
The graph builder uses StageIO to validate that a stage's inputs are
|
||||
satisfied by previous stages' outputs.
|
||||
The checkpoint layer stores stage output as blobs without knowing
|
||||
the format. The stage that wrote it is the only one that can read it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
@dataclass
|
||||
class StageIO:
|
||||
"""Declares what a stage reads and writes from/to DetectState."""
|
||||
reads: list[str]
|
||||
writes: list[str]
|
||||
optional_reads: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StageConfigField:
|
||||
"""A single tunable config parameter for the editor UI."""
|
||||
name: str
|
||||
type: str # "float", "int", "str", "bool", "list[str]"
|
||||
default: Any
|
||||
description: str = ""
|
||||
min: float | None = None
|
||||
max: float | None = None
|
||||
options: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StageDefinition:
|
||||
"""
|
||||
Complete metadata for a pipeline stage.
|
||||
|
||||
The profile editor uses this to build the palette, generate config
|
||||
forms, and validate graph connections. The checkpoint uses serialize_fn
|
||||
and deserialize_fn to persist stage outputs without knowing the internals.
|
||||
"""
|
||||
name: str
|
||||
label: str
|
||||
description: str
|
||||
io: StageIO
|
||||
config_fields: list[StageConfigField] = field(default_factory=list)
|
||||
category: str = "detection"
|
||||
|
||||
# The actual graph node function: (DetectState) → dict
|
||||
fn: Callable | None = None
|
||||
|
||||
# Stage-owned serialization for checkpointing.
|
||||
# serialize_fn: (state: dict, job_id: str) → json-compatible dict
|
||||
# Stage picks its writes from state, serializes them.
|
||||
# Binary data (frames) → S3 via stage, returns refs.
|
||||
# deserialize_fn: (data: dict, job_id: str) → state update dict
|
||||
# Stage restores its writes from the persisted data.
|
||||
serialize_fn: Callable | None = None
|
||||
deserialize_fn: Callable | None = None
|
||||
from core.schema.models.stages import StageConfigField, StageIO, StageDefinition
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# Registry — auto-populated by __init_subclass__ (new stages)
|
||||
# + register_stage() (legacy stages during migration)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_REGISTRY: dict[str, StageDefinition] = {}
|
||||
_REGISTRY: dict[str, type['Stage']] = {}
|
||||
_LEGACY_REGISTRY: dict[str, StageDefinition] = {}
|
||||
|
||||
|
||||
def register_stage(definition: StageDefinition):
|
||||
_REGISTRY[definition.name] = definition
|
||||
"""Legacy registration for stages not yet converted to Stage subclass."""
|
||||
_LEGACY_REGISTRY[definition.name] = definition
|
||||
|
||||
|
||||
class Stage:
|
||||
"""
|
||||
Base class for all pipeline stages.
|
||||
|
||||
Subclass this in detect/stages/<name>.py. Define `definition` as a
|
||||
class attribute. Implement `run()`. Optionally override `serialize()`
|
||||
and `deserialize()` for custom blob formats (default is JSON).
|
||||
"""
|
||||
|
||||
definition: StageDefinition # set by each subclass
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if hasattr(cls, 'definition') and cls.definition is not None:
|
||||
_REGISTRY[cls.definition.name] = cls
|
||||
|
||||
def run(self, frames: list, config: dict) -> Any:
|
||||
"""
|
||||
Run the stage on a list of frames with the given config.
|
||||
|
||||
Config is a dict of parameter values (from slider UI or profile).
|
||||
Returns the stage output — whatever shape this stage produces.
|
||||
Debug overlays are included when config has debug=True.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def serialize(self, output: Any) -> bytes:
|
||||
"""Serialize stage output to bytes for checkpoint storage."""
|
||||
import json
|
||||
return json.dumps(output, default=str).encode()
|
||||
|
||||
def deserialize(self, data: bytes) -> Any:
|
||||
"""Deserialize stage output from checkpoint blob."""
|
||||
import json
|
||||
return json.loads(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discovery API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _all_definitions() -> dict[str, StageDefinition]:
|
||||
"""Merge new Stage subclass registry + legacy registry."""
|
||||
merged = {}
|
||||
# Legacy first, new overwrites (new takes precedence)
|
||||
for name, defn in _LEGACY_REGISTRY.items():
|
||||
merged[name] = defn
|
||||
for name, cls in _REGISTRY.items():
|
||||
merged[name] = cls.definition
|
||||
return merged
|
||||
|
||||
|
||||
def get_stage(name: str) -> StageDefinition:
|
||||
if name not in _REGISTRY:
|
||||
raise KeyError(f"Unknown stage: {name!r}. Registered: {list(_REGISTRY)}")
|
||||
return _REGISTRY[name]
|
||||
"""Get a stage definition by name (works for both new and legacy)."""
|
||||
all_defs = _all_definitions()
|
||||
if name not in all_defs:
|
||||
raise KeyError(f"Unknown stage: {name!r}. Registered: {list(all_defs)}")
|
||||
return all_defs[name]
|
||||
|
||||
|
||||
def get_stage_class(name: str) -> type[Stage] | None:
|
||||
"""Get a Stage subclass by name. Returns None for legacy stages."""
|
||||
return _REGISTRY.get(name)
|
||||
|
||||
|
||||
def get_stage_instance(name: str) -> Stage:
|
||||
"""Get an instantiated Stage by name. Only works for new-style stages."""
|
||||
cls = _REGISTRY.get(name)
|
||||
if cls is None:
|
||||
raise KeyError(f"No Stage subclass for {name!r}. Legacy stages don't have instances.")
|
||||
return cls()
|
||||
|
||||
|
||||
def list_stages() -> list[StageDefinition]:
|
||||
"""List all registered stage definitions (new + legacy)."""
|
||||
return list(_all_definitions().values())
|
||||
|
||||
|
||||
def list_stage_classes() -> list[type[Stage]]:
|
||||
"""List all registered Stage subclasses (new-style only)."""
|
||||
return list(_REGISTRY.values())
|
||||
|
||||
|
||||
def get_palette() -> dict[str, list[StageDefinition]]:
|
||||
"""Group stages by category for the editor palette."""
|
||||
palette: dict[str, list[StageDefinition]] = {}
|
||||
for stage in _REGISTRY.values():
|
||||
if stage.category not in palette:
|
||||
palette[stage.category] = []
|
||||
palette[stage.category].append(stage)
|
||||
for defn in _all_definitions().values():
|
||||
if defn.category not in palette:
|
||||
palette[defn.category] = []
|
||||
palette[defn.category].append(defn)
|
||||
return palette
|
||||
|
||||
Reference in New Issue
Block a user