Files
mediaproc/detect/stages/base.py
2026-03-27 04:23:21 -03:00

132 lines
4.4 KiB
Python

"""
Stage base class — common interface for all pipeline stages.
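Each stage lives in its own module and subclasses Stage. Subclasses are
registered automatically via __init_subclass__ when the class is created
(i.e. when its module is imported), so no manual registration is needed.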
A stage:
- Has a StageDefinition (from schema) with name, config, IO
- Implements run(frames, config) → output
- Owns its output serialization (opaque blob)
- Optionally has a TypeScript port for browser-side execution
The checkpoint layer stores stage output as blobs without knowing
the format. The stage that wrote it is the only one that can read it.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from core.schema.models.stages import StageConfigField, StageIO, StageDefinition
# ---------------------------------------------------------------------------
# Registry — auto-populated by __init_subclass__ (new stages)
# + register_stage() (legacy stages during migration)
# ---------------------------------------------------------------------------
# New-style Stage subclasses keyed by definition name (filled by Stage.__init_subclass__).
_REGISTRY: dict[str, type['Stage']] = {}
# Bare definitions for stages not yet ported to Stage (filled by register_stage()).
_LEGACY_REGISTRY: dict[str, StageDefinition] = {}


def register_stage(definition: StageDefinition):
    """Record a legacy stage definition under its name.

    Transitional hook for stages that have not been converted to a
    ``Stage`` subclass yet. Registering the same name again replaces the
    earlier entry.
    """
    stage_name = definition.name
    _LEGACY_REGISTRY[stage_name] = definition
class Stage:
    """
    Base class for all pipeline stages.

    Subclass this in detect/stages/<name>.py. Define `definition` as a
    class attribute and implement `run()`. Optionally override
    `serialize()` and `deserialize()` for custom blob formats (the default
    is JSON).

    A subclass is added to the stage registry when the class is created,
    but only if it declares its own non-None `definition` in its class
    body. A subclass that merely inherits a parent's `definition` (helper
    base, test double, specialisation) is not registered, so it cannot
    silently take over the parent's registry entry.
    """

    definition: StageDefinition  # set by each concrete subclass

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Look only at this class's own namespace: hasattr()/getattr() would
        # also see an inherited `definition` and re-point the parent's name
        # at this subclass. Setting `definition = None` opts out explicitly.
        own_definition = cls.__dict__.get('definition')
        if own_definition is not None:
            _REGISTRY[own_definition.name] = cls

    def run(self, frames: list, config: dict) -> Any:
        """
        Run the stage on a list of frames with the given config.

        Config is a dict of parameter values (from slider UI or profile).
        Returns the stage output — whatever shape this stage produces.
        Implementations are expected to include debug overlays when config
        has debug=True.

        Raises:
            NotImplementedError: always, on the base class; subclasses must
                override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement run()")

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """Convert a value that json cannot encode natively.

        NumPy arrays become (nested) lists and NumPy scalars become the
        equivalent Python scalar. They are stored as JSON numbers, booleans
        and lists rather than as their str() text. Any other value falls
        back to str(). json re-invokes this hook on whatever it returns, so
        e.g. a datetime produced by `.item()` still ends up as a string.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        return str(obj)

    def serialize(self, output: Any) -> bytes:
        """Serialize stage output to bytes for checkpoint storage.

        The default format is JSON encoded to bytes. NumPy arrays and
        scalars are converted to plain Python values. Other
        non-JSON-native values are stored as their str().
        """
        import json
        return json.dumps(output, default=self._json_default).encode()

    def deserialize(self, data: bytes) -> Any:
        """Deserialize stage output from a checkpoint blob.

        Inverse of the default `serialize()`: parses the JSON bytes back
        into Python lists/dicts/scalars (NumPy arrays come back as lists).
        """
        import json
        return json.loads(data)
# ---------------------------------------------------------------------------
# Discovery API
# ---------------------------------------------------------------------------
def _all_definitions() -> dict[str, StageDefinition]:
    """Return every known stage definition keyed by stage name.

    Legacy entries are copied in first and new-style Stage subclasses are
    applied on top. When both register the same name, the subclass's
    definition wins.
    """
    combined: dict[str, StageDefinition] = dict(_LEGACY_REGISTRY)
    combined.update(
        (stage_name, stage_cls.definition)
        for stage_name, stage_cls in _REGISTRY.items()
    )
    return combined
def get_stage(name: str) -> StageDefinition:
    """Look up a stage definition by name, covering new-style and legacy stages.

    Raises:
        KeyError: if no stage with that name is registered; the message
            lists the registered names.
    """
    known = _all_definitions()
    if name in known:
        return known[name]
    raise KeyError(f"Unknown stage: {name!r}. Registered: {list(known)}")
def get_stage_class(name: str) -> type[Stage] | None:
    """Return the Stage subclass registered under *name*.

    Gives None when the name is unknown or belongs to a legacy
    (definition-only) stage.
    """
    if name in _REGISTRY:
        return _REGISTRY[name]
    return None
def get_stage_instance(name: str) -> Stage:
    """Construct (with no arguments) the new-style stage registered as *name*.

    Raises:
        KeyError: if *name* is not a registered Stage subclass. This
            includes legacy stages, which have no class to instantiate.
    """
    if name not in _REGISTRY:
        raise KeyError(f"No Stage subclass for {name!r}. Legacy stages don't have instances.")
    stage_cls = _REGISTRY[name]
    return stage_cls()
def list_stages() -> list[StageDefinition]:
    """Return all registered stage definitions (new-style and legacy).

    The list follows the order of _all_definitions().
    """
    return [*_all_definitions().values()]
def list_stage_classes() -> list[type[Stage]]:
    """Return every registered Stage subclass; legacy stages are not included."""
    return [*_REGISTRY.values()]
def get_palette() -> dict[str, list[StageDefinition]]:
    """Group all stage definitions by their category for the editor palette.

    Categories appear in the order they are first seen. Definitions within
    a category keep the order of _all_definitions().
    """
    groups: dict[str, list[StageDefinition]] = {}
    for stage_def in _all_definitions().values():
        groups.setdefault(stage_def.category, []).append(stage_def)
    return groups