"""
Stage base class — common interface for all pipeline stages.

Each stage lives in its own module and subclasses Stage. Subclasses are
auto-discovered via __init_subclass__; no manual registration is needed.

A stage:
- Has a StageDefinition (from schema) with name, config, IO
- Implements run(frames, config) → output
- Owns its output serialization (opaque blob)
- Optionally has a TypeScript port for browser-side execution

The checkpoint layer stores stage output as blobs without knowing
the format. The stage that wrote it is the only one that can read it.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
from core.schema.models.stages import StageConfigField, StageIO, StageDefinition
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Registry — auto-populated by __init_subclass__ (new stages)
# + register_stage() (legacy stages during migration)
# ---------------------------------------------------------------------------
|
|
|
|
# New-style stages: Stage subclasses keyed by ``definition.name``.
# Filled in by ``Stage.__init_subclass__``.
_REGISTRY: dict[str, type['Stage']] = dict()

# Legacy stages: bare definitions keyed by ``definition.name`` for stages not
# yet converted to a Stage subclass. Filled in by ``register_stage()``.
_LEGACY_REGISTRY: dict[str, StageDefinition] = dict()
|
|
|
|
|
|
def register_stage(definition: StageDefinition):
    """Record a bare StageDefinition for a stage without a Stage subclass.

    This is the legacy registration path used during migration. The entry is
    keyed by ``definition.name``; registering the same name again replaces
    the earlier entry. Returns None.
    """
    stage_name = definition.name
    _LEGACY_REGISTRY[stage_name] = definition
|
|
|
|
|
|
class Stage:
    """
    Base class for all pipeline stages.

    Subclass this in detect/stages/<name>.py. Define `definition` as a
    class attribute. Implement `run()`. Optionally override `serialize()`
    and `deserialize()` for custom blob formats (default is JSON).

    Registration happens in ``__init_subclass__``. A subclass is recorded in
    the module registry under ``definition.name`` only when its own class body
    sets a non-None ``definition``. A subclass that merely inherits its
    parent's ``definition`` is not registered, so it cannot silently replace
    the parent's registry entry.
    """

    definition: StageDefinition  # set by each subclass

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Inspect only this class's own namespace. ``hasattr`` would also see
        # a definition inherited from an already-registered parent and would
        # re-register this subclass under the parent's name, evicting it.
        own_definition = cls.__dict__.get('definition')
        if own_definition is not None:
            _REGISTRY[own_definition.name] = cls

    def run(self, frames: list, config: dict) -> Any:
        """
        Run the stage on a list of frames with the given config.

        Config is a dict of parameter values (from slider UI or profile).
        Returns the stage output — whatever shape this stage produces.
        Debug overlays are included when config has debug=True.
        """
        raise NotImplementedError

    def serialize(self, output: Any) -> bytes:
        """Serialize stage output to UTF-8 JSON bytes for checkpoint storage.

        NumPy scalars and arrays are converted to native Python numbers and
        (nested) lists, so they come back from ``deserialize`` as numbers
        rather than as their ``str()`` text. Any other object the ``json``
        module cannot encode natively still falls back to ``str(obj)``.
        """
        import json

        def _encode_fallback(obj: Any) -> Any:
            # json only calls this for objects it cannot encode itself; the
            # returned value is then encoded (recursively, if needed).
            if isinstance(obj, np.generic):
                return obj.item()
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            return str(obj)

        return json.dumps(output, default=_encode_fallback).encode()

    def deserialize(self, data: bytes) -> Any:
        """Deserialize stage output from checkpoint blob."""
        import json

        return json.loads(data)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Discovery API
# ---------------------------------------------------------------------------
|
|
|
|
def _all_definitions() -> dict[str, StageDefinition]:
    """Return a single name -> StageDefinition mapping for every known stage.

    Legacy entries are copied in first, and Stage-subclass entries are applied
    on top. When a name appears in both registries, the subclass's
    ``definition`` wins.
    """
    combined = dict(_LEGACY_REGISTRY)
    combined.update(
        (stage_name, stage_cls.definition)
        for stage_name, stage_cls in _REGISTRY.items()
    )
    return combined
|
|
|
|
|
|
def get_stage(name: str) -> StageDefinition:
    """Look up one stage definition by name, across new-style and legacy stages.

    Raises:
        KeyError: if no stage with that name is registered; the message lists
            every registered name.
    """
    definitions = _all_definitions()
    if name in definitions:
        return definitions[name]
    raise KeyError(f"Unknown stage: {name!r}. Registered: {list(definitions)}")
|
|
|
|
|
|
def get_stage_class(name: str) -> type[Stage] | None:
    """Return the Stage subclass registered under ``name``.

    Legacy stages and unknown names have no subclass, so ``None`` is returned.
    """
    try:
        return _REGISTRY[name]
    except KeyError:
        return None
|
|
|
|
|
|
def get_stage_instance(name: str) -> Stage:
    """Construct a fresh instance of the Stage subclass registered as ``name``.

    A new object is created on every call.

    Raises:
        KeyError: when ``name`` has no Stage subclass (legacy or unknown stage).
    """
    stage_cls = _REGISTRY.get(name)
    if stage_cls is not None:
        return stage_cls()
    raise KeyError(f"No Stage subclass for {name!r}. Legacy stages don't have instances.")
|
|
|
|
|
|
def list_stages() -> list[StageDefinition]:
    """Return every registered stage definition (new-style and legacy) as a new list."""
    return [*_all_definitions().values()]
|
|
|
|
|
|
def list_stage_classes() -> list[type[Stage]]:
    """Return every registered Stage subclass as a new list (legacy stages excluded)."""
    return [*_REGISTRY.values()]
|
|
|
|
|
|
def get_palette() -> dict[str, list[StageDefinition]]:
    """Bucket every registered stage definition by its ``category`` for the editor.

    Returns a plain dict mapping category to a list of definitions. Categories
    appear in first-seen order, and definitions keep registry order within
    each bucket.
    """
    groups: dict[str, list[StageDefinition]] = {}
    for stage_def in _all_definitions().values():
        groups.setdefault(stage_def.category, []).append(stage_def)
    return groups
|