refactor stage 1

2026-03-27 04:23:21 -03:00
parent df6bcb01e8
commit 291ac8dd40
14 changed files with 688 additions and 450 deletions

View File

@@ -1,21 +1,21 @@
"""
Pipeline stages.
Each stage registers its StageDefinition on import,
declaring IO (what it reads/writes from state),
config fields (what's tunable from the editor),
and serialization (how to checkpoint its outputs).
Each stage is a file with a Stage subclass. Auto-discovered via
__init_subclass__ — importing the file registers the stage.
"""
from .base import (
StageDefinition,
StageIO,
StageConfigField,
register_stage,
Stage,
get_stage,
get_stage_instance,
list_stages,
list_stage_classes,
get_palette,
)
# Populate registry with built-in stages
# Import all stage files to trigger auto-registration
from . import edge_detector # noqa: F401
# Import registry for backward compat (other stages still use old pattern)
from . import registry # noqa: F401
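For orientation, a minimal sketch of what a new auto-registered stage file could look like under this scheme. Everything named blur_filter below is hypothetical; only Stage, StageDefinition, StageIO and StageConfigField come from this commit.

# detect/stages/blur_filter.py (hypothetical example, not part of this commit)
from detect.stages.base import Stage
from core.schema.models.stages import StageDefinition, StageIO, StageConfigField

class BlurFilterStage(Stage):
    # Defining the subclass is enough: __init_subclass__ in base.py
    # records it in the registry under definition.name.
    definition = StageDefinition(
        name="blur_filter",
        label="Blur Filter",
        description="Toy stage for illustration only",
        category="cv_analysis",
        io=StageIO(reads=["filtered_frames"], writes=["blurred_frames"]),
        config_fields=[
            StageConfigField("kernel_size", "int", 5, "Blur kernel size (px)", min=1, max=31),
        ],
    )

    def run(self, frames, config):
        # Output shape is up to the stage; a dict keyed by frame sequence
        # mirrors the edge detector later in this commit.
        return {f.sequence: f.image for f in frames}

The one manual step left is the import in detect/stages/__init__.py (e.g. from . import blur_filter  # noqa: F401), since importing the file is what fires the registration hook.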

View File

@@ -1,101 +1,131 @@
"""
Stage protocol — common interface for all pipeline stages.
Stage base class — common interface for all pipeline stages.
Every stage declares:
- IO: what it reads/writes from DetectState
- Config: tunable parameters for the editor
- Serialization: how to persist/restore its own outputs
Each stage is a file that subclasses Stage. Auto-discovered via
__init_subclass__. No manual registration needed.
The checkpoint layer is a black box — it asks each stage to serialize its
outputs and stores the result. Stages own their data format. Binary data
(frames, crops) goes to S3 via the stage itself. The checkpoint just
stores the JSON envelope.
A stage:
- Has a StageDefinition (from schema) with name, config, IO
- Implements run(frames, config) → output
- Owns its output serialization (opaque blob)
- Optionally has a TypeScript port for browser-side execution
The graph builder uses StageIO to validate that a stage's inputs are
satisfied by previous stages' outputs.
The checkpoint layer stores stage output as blobs without knowing
the format. The stage that wrote it is the only one that can read it.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable
from typing import Any
import numpy as np
@dataclass
class StageIO:
"""Declares what a stage reads and writes from/to DetectState."""
reads: list[str]
writes: list[str]
optional_reads: list[str] = field(default_factory=list)
@dataclass
class StageConfigField:
"""A single tunable config parameter for the editor UI."""
name: str
type: str # "float", "int", "str", "bool", "list[str]"
default: Any
description: str = ""
min: float | None = None
max: float | None = None
options: list[str] | None = None
@dataclass
class StageDefinition:
"""
Complete metadata for a pipeline stage.
The profile editor uses this to build the palette, generate config
forms, and validate graph connections. The checkpoint uses serialize_fn
and deserialize_fn to persist stage outputs without knowing the internals.
"""
name: str
label: str
description: str
io: StageIO
config_fields: list[StageConfigField] = field(default_factory=list)
category: str = "detection"
# The actual graph node function: (DetectState) → dict
fn: Callable | None = None
# Stage-owned serialization for checkpointing.
# serialize_fn: (state: dict, job_id: str) → json-compatible dict
# Stage picks its writes from state, serializes them.
# Binary data (frames) → S3 via stage, returns refs.
# deserialize_fn: (data: dict, job_id: str) → state update dict
# Stage restores its writes from the persisted data.
serialize_fn: Callable | None = None
deserialize_fn: Callable | None = None
from core.schema.models.stages import StageConfigField, StageIO, StageDefinition
# ---------------------------------------------------------------------------
# Registry
# Registry — auto-populated by __init_subclass__ (new stages)
# + register_stage() (legacy stages during migration)
# ---------------------------------------------------------------------------
_REGISTRY: dict[str, StageDefinition] = {}
_REGISTRY: dict[str, type['Stage']] = {}
_LEGACY_REGISTRY: dict[str, StageDefinition] = {}
def register_stage(definition: StageDefinition):
_REGISTRY[definition.name] = definition
"""Legacy registration for stages not yet converted to Stage subclass."""
_LEGACY_REGISTRY[definition.name] = definition
class Stage:
"""
Base class for all pipeline stages.
Subclass this in detect/stages/<name>.py. Define `definition` as a
class attribute. Implement `run()`. Optionally override `serialize()`
and `deserialize()` for custom blob formats (default is JSON).
"""
definition: StageDefinition # set by each subclass
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if hasattr(cls, 'definition') and cls.definition is not None:
_REGISTRY[cls.definition.name] = cls
def run(self, frames: list, config: dict) -> Any:
"""
Run the stage on a list of frames with the given config.
Config is a dict of parameter values (from slider UI or profile).
Returns the stage output — whatever shape this stage produces.
Debug overlays are included when config has debug=True.
"""
raise NotImplementedError
def serialize(self, output: Any) -> bytes:
"""Serialize stage output to bytes for checkpoint storage."""
import json
return json.dumps(output, default=str).encode()
def deserialize(self, data: bytes) -> Any:
"""Deserialize stage output from checkpoint blob."""
import json
return json.loads(data)
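Stages whose outputs are not JSON-friendly are expected to override the two hooks above. A sketch of what that could look like for a numpy-array output; the NpzStage name and the .npy format choice are illustrative, not part of this commit.

# Hypothetical override sketch: a stage whose output is a numpy array.
import io
import numpy as np
from detect.stages.base import Stage

class NpzStage(Stage):
    definition = None  # skipped by __init_subclass__; a real stage sets a StageDefinition

    def run(self, frames, config):
        return np.zeros((len(frames), 4), dtype=np.float32)

    def serialize(self, output) -> bytes:
        buf = io.BytesIO()
        np.save(buf, output)            # binary payload, so the default JSON blob won't do
        return buf.getvalue()

    def deserialize(self, data: bytes):
        return np.load(io.BytesIO(data))

Either way the checkpoint layer only ever sees the bytes, so the format stays private to the stage that wrote it.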
# ---------------------------------------------------------------------------
# Discovery API
# ---------------------------------------------------------------------------
def _all_definitions() -> dict[str, StageDefinition]:
"""Merge new Stage subclass registry + legacy registry."""
merged = {}
# Legacy first, new overwrites (new takes precedence)
for name, defn in _LEGACY_REGISTRY.items():
merged[name] = defn
for name, cls in _REGISTRY.items():
merged[name] = cls.definition
return merged
def get_stage(name: str) -> StageDefinition:
if name not in _REGISTRY:
raise KeyError(f"Unknown stage: {name!r}. Registered: {list(_REGISTRY)}")
return _REGISTRY[name]
"""Get a stage definition by name (works for both new and legacy)."""
all_defs = _all_definitions()
if name not in all_defs:
raise KeyError(f"Unknown stage: {name!r}. Registered: {list(all_defs)}")
return all_defs[name]
def get_stage_class(name: str) -> type[Stage] | None:
"""Get a Stage subclass by name. Returns None for legacy stages."""
return _REGISTRY.get(name)
def get_stage_instance(name: str) -> Stage:
"""Get an instantiated Stage by name. Only works for new-style stages."""
cls = _REGISTRY.get(name)
if cls is None:
raise KeyError(f"No Stage subclass for {name!r}. Legacy stages don't have instances.")
return cls()
def list_stages() -> list[StageDefinition]:
"""List all registered stage definitions (new + legacy)."""
return list(_all_definitions().values())
def list_stage_classes() -> list[type[Stage]]:
"""List all registered Stage subclasses (new-style only)."""
return list(_REGISTRY.values())
def get_palette() -> dict[str, list[StageDefinition]]:
"""Group stages by category for the editor palette."""
palette: dict[str, list[StageDefinition]] = {}
for stage in _REGISTRY.values():
if stage.category not in palette:
palette[stage.category] = []
palette[stage.category].append(stage)
for defn in _all_definitions().values():
if defn.category not in palette:
palette[defn.category] = []
palette[defn.category].append(defn)
return palette
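Putting the discovery API together, a sketch of how a caller such as the profile editor might consume the merged registry; only "detect_edges" is a stage name that exists in this commit, the rest is assumed usage.

# Hypothetical discovery sketch.
from detect import stages  # importing the package imports each stage file,
                           # which is what populates the registry

defn = stages.get_stage("detect_edges")           # StageDefinition, new-style or legacy
edge = stages.get_stage_instance("detect_edges")  # Stage instance; KeyError for legacy-only names

for category, defs in stages.get_palette().items():
    print(category, [d.name for d in defs])       # grouped for the editor palette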

View File

@@ -7,168 +7,227 @@ advertising hoardings. Pure OpenCV, no ML models.
Two modes:
- Remote: calls GPU inference server over HTTP
- Local: imports cv2 directly (OpenCV on same machine)
Emits frame_update events with bounding boxes for the frame viewer.
"""
from __future__ import annotations
import base64
import io
import json
import logging
import os
import time
from typing import Any
from PIL import Image
from detect import emit
from detect.models import BoundingBox, Frame
from detect.profiles.base import RegionAnalysisConfig
from detect.stages.base import Stage
from core.schema.models.stages import StageDefinition, StageConfigField, StageIO
logger = logging.getLogger(__name__)
class EdgeDetectionStage(Stage):
definition = StageDefinition(
name="detect_edges",
label="Edge Detection",
description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
category="cv_analysis",
io=StageIO(
reads=["filtered_frames"],
writes=["edge_regions_by_frame"],
),
config_fields=[
StageConfigField("enabled", "bool", True, "Enable edge detection"),
StageConfigField("edge_canny_low", "int", 50, "Canny low threshold", min=0, max=255),
StageConfigField("edge_canny_high", "int", 150, "Canny high threshold", min=0, max=255),
StageConfigField("edge_hough_threshold", "int", 80, "Hough accumulator threshold", min=1, max=500),
StageConfigField("edge_hough_min_length", "int", 100, "Min line length (px)", min=10, max=2000),
StageConfigField("edge_hough_max_gap", "int", 10, "Max line gap (px)", min=1, max=100),
StageConfigField("edge_pair_max_distance", "int", 200, "Max distance between line pair (px)", min=10, max=500),
StageConfigField("edge_pair_min_distance", "int", 15, "Min distance between line pair (px)", min=5, max=200),
],
)
def run(self, frames: list[Frame], config: dict) -> dict[int, list[BoundingBox]]:
"""
Run edge detection on all frames.
Config keys: enabled, edge_canny_low, edge_canny_high, edge_hough_threshold,
edge_hough_min_length, edge_hough_max_gap, edge_pair_max_distance, edge_pair_min_distance,
debug (bool), inference_url (str|None), job_id (str|None).
Returns dict mapping frame sequence → list of BoundingBox.
"""
enabled = config.get("enabled", True)
job_id = config.get("job_id")
inference_url = config.get("inference_url") or os.environ.get("INFERENCE_URL")
if not enabled:
emit.log(job_id, "EdgeDetection", "INFO", "Edge detection disabled, skipping")
return {}
mode = "remote" if inference_url else "local"
emit.log(job_id, "EdgeDetection", "INFO",
f"Detecting edges in {len(frames)} frames (mode={mode})")
all_boxes: dict[int, list[BoundingBox]] = {}
total_regions = 0
for frame in frames:
t0 = time.monotonic()
if inference_url:
boxes = self._run_remote(frame, config, inference_url, job_id or "")
else:
boxes = self._run_local(frame, config)
ms = (time.monotonic() - t0) * 1000
all_boxes[frame.sequence] = boxes
total_regions += len(boxes)
emit.log(job_id, "EdgeDetection", "DEBUG",
f"Frame {frame.sequence}: {len(boxes)} regions in {ms:.0f}ms"
+ (f" [{', '.join(b.label for b in boxes)}]" if boxes else ""))
if boxes and job_id:
box_dicts = [
{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
"confidence": b.confidence, "label": b.label,
"stage": "detect_edges"}
for b in boxes
]
emit.frame_update(
job_id,
frame_ref=frame.sequence,
timestamp=frame.timestamp,
jpeg_b64=_frame_to_b64(frame),
boxes=box_dicts,
)
emit.log(job_id, "EdgeDetection", "INFO",
f"Found {total_regions} edge regions across {len(frames)} frames")
emit.stats(job_id, cv_regions_detected=total_regions)
return all_boxes
def serialize(self, output: Any) -> bytes:
"""Serialize edge regions to JSON blob."""
serialized = {}
for seq, boxes in output.items():
serialized[str(seq)] = [
{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
"confidence": b.confidence, "label": b.label}
for b in boxes
]
return json.dumps(serialized).encode()
def deserialize(self, data: bytes) -> dict[int, list[BoundingBox]]:
"""Deserialize edge regions from JSON blob."""
raw = json.loads(data)
result = {}
for seq_str, box_dicts in raw.items():
boxes = [
BoundingBox(x=b["x"], y=b["y"], w=b["w"], h=b["h"],
confidence=b["confidence"], label=b["label"])
for b in box_dicts
]
result[int(seq_str)] = boxes
return result
# --- Private helpers ---
def _run_remote(self, frame: Frame, config: dict,
inference_url: str, job_id: str) -> list[BoundingBox]:
from detect.inference import InferenceClient
from detect.emit import _run_log_level
client = InferenceClient(
base_url=inference_url, job_id=job_id, log_level=_run_log_level,
)
results = client.detect_edges(
image=frame.image,
edge_canny_low=config.get("edge_canny_low", 50),
edge_canny_high=config.get("edge_canny_high", 150),
edge_hough_threshold=config.get("edge_hough_threshold", 80),
edge_hough_min_length=config.get("edge_hough_min_length", 100),
edge_hough_max_gap=config.get("edge_hough_max_gap", 10),
edge_pair_max_distance=config.get("edge_pair_max_distance", 200),
edge_pair_min_distance=config.get("edge_pair_min_distance", 15),
)
boxes = []
for r in results:
box = BoundingBox(
x=r.x, y=r.y, w=r.w, h=r.h,
confidence=r.confidence, label=r.label,
)
boxes.append(box)
return boxes
def _run_local(self, frame: Frame, config: dict) -> list[BoundingBox]:
detect_edges_fn = _load_cv_edges().detect_edges
edge_results = detect_edges_fn(
frame.image,
canny_low=config.get("edge_canny_low", 50),
canny_high=config.get("edge_canny_high", 150),
hough_threshold=config.get("edge_hough_threshold", 80),
hough_min_length=config.get("edge_hough_min_length", 100),
hough_max_gap=config.get("edge_hough_max_gap", 10),
pair_max_distance=config.get("edge_pair_max_distance", 200),
pair_min_distance=config.get("edge_pair_min_distance", 15),
)
boxes = []
for r in edge_results:
box = BoundingBox(
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
confidence=r["confidence"], label=r["label"],
)
boxes.append(box)
return boxes
# --- Module-level helpers ---
def _frame_to_b64(frame: Frame) -> str:
"""Encode frame as base64 JPEG for SSE frame_update events."""
img = Image.fromarray(frame.image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=70)
return base64.b64encode(buf.getvalue()).decode()
def _detect_remote(
frame: Frame,
config: RegionAnalysisConfig,
inference_url: str,
job_id: str = "",
log_level: str = "INFO",
) -> list[BoundingBox]:
"""Call the inference server over HTTP."""
from detect.inference import InferenceClient
client = InferenceClient(
base_url=inference_url, job_id=job_id, log_level=log_level,
)
results = client.detect_edges(
image=frame.image,
edge_canny_low=config.edge_canny_low,
edge_canny_high=config.edge_canny_high,
edge_hough_threshold=config.edge_hough_threshold,
edge_hough_min_length=config.edge_hough_min_length,
edge_hough_max_gap=config.edge_hough_max_gap,
edge_pair_max_distance=config.edge_pair_max_distance,
edge_pair_min_distance=config.edge_pair_min_distance,
)
boxes = []
for r in results:
box = BoundingBox(
x=r.x, y=r.y, w=r.w, h=r.h,
confidence=r.confidence, label=r.label,
)
boxes.append(box)
return boxes
_cv_edges_mod = None
def _load_cv_edges():
"""Load edges module directly — gpu/models/__init__.py has GPU-container-only imports."""
global _cv_edges_mod
if _cv_edges_mod is None:
import importlib.util
from pathlib import Path
spec = importlib.util.spec_from_file_location("cv_edges", Path("gpu/models/cv/edges.py"))
_cv_edges_mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(_cv_edges_mod)
return _cv_edges_mod
def _detect_local(frame: Frame, config: RegionAnalysisConfig) -> list[BoundingBox]:
"""Run edge detection in-process (requires opencv-python)."""
detect_edges_fn = _load_cv_edges().detect_edges
# --- Backward compat: standalone function for graph.py ---
edge_results = detect_edges_fn(
frame.image,
canny_low=config.edge_canny_low,
canny_high=config.edge_canny_high,
hough_threshold=config.edge_hough_threshold,
hough_min_length=config.edge_hough_min_length,
hough_max_gap=config.edge_hough_max_gap,
pair_max_distance=config.edge_pair_max_distance,
pair_min_distance=config.edge_pair_min_distance,
)
boxes = []
for r in edge_results:
box = BoundingBox(
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
confidence=r["confidence"], label=r["label"],
)
boxes.append(box)
return boxes
def detect_edge_regions(
frames: list[Frame],
config: RegionAnalysisConfig,
inference_url: str | None = None,
job_id: str | None = None,
) -> dict[int, list[BoundingBox]]:
"""
Run edge detection on all frames.
Returns a dict mapping frame sequence → list of bounding boxes.
"""
if not config.enabled:
emit.log(job_id, "EdgeDetection", "INFO", "Edge detection disabled, skipping")
return {}
mode = "remote" if inference_url else "local"
emit.log(job_id, "EdgeDetection", "INFO",
f"Detecting edges in {len(frames)} frames (mode={mode})")
all_boxes: dict[int, list[BoundingBox]] = {}
total_regions = 0
for i, frame in enumerate(frames):
t0 = time.monotonic()
if inference_url:
from detect.emit import _run_log_level
boxes = _detect_remote(
frame, config, inference_url,
job_id=job_id or "", log_level=_run_log_level,
)
else:
boxes = _detect_local(frame, config)
analysis_ms = (time.monotonic() - t0) * 1000
all_boxes[frame.sequence] = boxes
total_regions += len(boxes)
emit.log(job_id, "EdgeDetection", "DEBUG",
f"Frame {frame.sequence}: {len(boxes)} regions in {analysis_ms:.0f}ms"
+ (f" [{', '.join(b.label for b in boxes)}]" if boxes else ""))
if boxes and job_id:
box_dicts = [
{
"x": b.x, "y": b.y, "w": b.w, "h": b.h,
"confidence": b.confidence, "label": b.label,
"stage": "detect_edges",
}
for b in boxes
]
emit.frame_update(
job_id,
frame_ref=frame.sequence,
timestamp=frame.timestamp,
jpeg_b64=_frame_to_b64(frame),
boxes=box_dicts,
)
emit.log(job_id, "EdgeDetection", "INFO",
f"Found {total_regions} edge regions across {len(frames)} frames")
emit.stats(job_id, cv_regions_detected=total_regions)
return all_boxes
def detect_edge_regions(frames, config, inference_url=None, job_id=None):
"""Convenience wrapper — calls EdgeDetectionStage.run()."""
stage = EdgeDetectionStage()
cfg = {
"enabled": config.enabled,
"edge_canny_low": config.edge_canny_low,
"edge_canny_high": config.edge_canny_high,
"edge_hough_threshold": config.edge_hough_threshold,
"edge_hough_min_length": config.edge_hough_min_length,
"edge_hough_max_gap": config.edge_hough_max_gap,
"edge_pair_max_distance": config.edge_pair_max_distance,
"edge_pair_min_distance": config.edge_pair_min_distance,
"inference_url": inference_url,
"job_id": job_id,
}
return stage.run(frames, cfg)
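
A sketch of driving the new stage directly with a plain config dict, the same path the slider UI or a profile would take. The Frame construction, array size, and threshold values are illustrative assumptions, not from this commit.

# Hypothetical driver sketch.
import numpy as np
from detect.models import Frame
from detect.stages.edge_detector import EdgeDetectionStage

frames = [Frame(sequence=0, timestamp=0.0,
                image=np.zeros((720, 1280, 3), dtype=np.uint8))]  # assumed Frame fields

stage = EdgeDetectionStage()
regions = stage.run(frames, {
    "enabled": True,
    "edge_canny_low": 50,
    "edge_canny_high": 150,
    "inference_url": None,   # None selects the local OpenCV path via _load_cv_edges()
    "job_id": None,          # no job, so no SSE frame_update events
})

blob = stage.serialize(regions)      # JSON envelope handed to the checkpoint layer
restored = stage.deserialize(blob)   # back to dict[int, list[BoundingBox]]

Existing callers keep the old entry point: detect_edge_regions(frames, config, ...) now just translates a RegionAnalysisConfig into this config dict and delegates to EdgeDetectionStage.run().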