"""
Stage base class — common interface for all pipeline stages.

Each stage is a file that subclasses Stage. Subclasses are registered
automatically via __init_subclass__ when the class is defined (i.e. when its
module is imported). No manual registration needed.

A stage:
- Has a StageDefinition (from schema) with name, config, IO
- Implements run(frames, config) → output
- Owns its output serialization (opaque blob)
- Optionally has a TypeScript port for browser-side execution

The checkpoint layer stores stage output as blobs without knowing the format.
The stage that wrote it is the only one that can read it.
"""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from typing import Any

import numpy as np

from core.schema.models.stages import StageConfigField, StageIO, StageDefinition


# ---------------------------------------------------------------------------
# Registry — auto-populated by __init_subclass__ (new stages)
#          + register_stage() (legacy stages during migration)
# ---------------------------------------------------------------------------

_REGISTRY: dict[str, type['Stage']] = {}
_LEGACY_REGISTRY: dict[str, StageDefinition] = {}


def register_stage(definition: StageDefinition):
    """Legacy registration for stages not yet converted to a Stage subclass.

    Stores *definition* under ``definition.name``; registering the same name
    again replaces the earlier entry. A new-style Stage subclass with the same
    name takes precedence in the merged views (get_stage, list_stages,
    get_palette).
    """
    _LEGACY_REGISTRY[definition.name] = definition


def _json_default(obj: Any) -> Any:
    """Fallback encoder passed to ``json.dumps`` by ``Stage.serialize``.

    NumPy arrays and scalars are converted to native Python lists/scalars so
    they are stored as real JSON numbers/arrays. Plain ``str()`` would store
    ``np.int64(5)`` as the string ``"5"`` and an ndarray as its printed repr,
    which NumPy truncates with ``...`` for large arrays (data loss).
    Everything else still falls back to ``str()``.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # .item() gives the nearest Python scalar; json re-applies this
        # fallback if that scalar is itself not JSON-native (e.g. datetime).
        return obj.item()
    return str(obj)


class Stage:
    """
    Base class for all pipeline stages.

    Subclass this in detect/stages/<stage_name>.py. Define `definition` as a
    class attribute. Implement `run()`. Optionally override `serialize()` and
    `deserialize()` for custom blob formats (default is JSON).
    """

    definition: StageDefinition  # set by each subclass

    def __init_subclass__(cls, **kwargs):
        """Register each subclass that exposes a non-None `definition`.

        The registry is keyed by ``definition.name``; a later class with the
        same name replaces the earlier one. `definition` is resolved through
        normal attribute lookup, so a subclass that inherits it without
        overriding re-registers under its parent's name.
        """
        super().__init_subclass__(**kwargs)
        definition = getattr(cls, 'definition', None)
        if definition is not None:
            _REGISTRY[definition.name] = cls

    def run(self, frames: list, config: dict) -> Any:
        """
        Run the stage on a list of frames with the given config.

        Config is a dict of parameter values (from slider UI or profile).
        Returns the stage output — whatever shape this stage produces.
        Debug overlays are included when config has debug=True.
        """
        raise NotImplementedError

    def serialize(self, output: Any) -> bytes:
        """Serialize stage output to bytes for checkpoint storage.

        Default format is JSON encoded to bytes. NumPy arrays/scalars become
        plain lists/numbers. Other non-JSON values are stored as ``str()``
        and do not round-trip to their original type.
        """
        return json.dumps(output, default=_json_default).encode()

    def deserialize(self, data: bytes) -> Any:
        """Deserialize stage output from a checkpoint blob (default: JSON)."""
        return json.loads(data)


# ---------------------------------------------------------------------------
# Discovery API
# ---------------------------------------------------------------------------

def _all_definitions() -> dict[str, StageDefinition]:
    """Merge the legacy registry with the new Stage subclass registry.

    New-style definitions take precedence over legacy ones with the same name.
    Legacy names come first in iteration order; an overridden name keeps its
    legacy position.
    """
    return {
        **_LEGACY_REGISTRY,
        **{name: cls.definition for name, cls in _REGISTRY.items()},
    }


def get_stage(name: str) -> StageDefinition:
    """Get a stage definition by name (works for both new and legacy).

    Raises:
        KeyError: if no stage with that name is registered; the message
            lists every registered name.
    """
    all_defs = _all_definitions()
    try:
        return all_defs[name]
    except KeyError:
        raise KeyError(
            f"Unknown stage: {name!r}. Registered: {list(all_defs)}"
        ) from None


def get_stage_class(name: str) -> type[Stage] | None:
    """Get a Stage subclass by name. Returns None for legacy/unknown stages."""
    return _REGISTRY.get(name)


def get_stage_instance(name: str) -> Stage:
    """Instantiate a new-style Stage by name (no constructor arguments).

    Raises:
        KeyError: if *name* has no Stage subclass (unknown or legacy-only).
    """
    cls = _REGISTRY.get(name)
    if cls is None:
        raise KeyError(
            f"No Stage subclass for {name!r}. Legacy stages don't have instances."
        )
    return cls()


def list_stages() -> list[StageDefinition]:
    """List all registered stage definitions (new + legacy)."""
    return list(_all_definitions().values())


def list_stage_classes() -> list[type[Stage]]:
    """List all registered Stage subclasses (new-style only)."""
    return list(_REGISTRY.values())


def get_palette() -> dict[str, list[StageDefinition]]:
    """Group stage definitions by ``category`` for the editor palette.

    Categories appear in first-seen order; within a category, definitions
    keep the order of list_stages().
    """
    palette: dict[str, list[StageDefinition]] = {}
    for defn in _all_definitions().values():
        palette.setdefault(defn.category, []).append(defn)
    return palette