def _stage_to_info(stage) -> StageConfigInfo:
    """Convert a registered stage object into its ``StageConfigInfo`` API model.

    Copies the stage's identifying attributes, flattens every entry of
    ``stage.config_fields`` into a plain dict, and lifts the ``reads`` /
    ``writes`` declarations from ``stage.io``.

    Args:
        stage: A stage definition exposing ``name``, ``label``,
            ``description``, ``category``, ``config_fields`` and ``io``.

    Returns:
        The populated ``StageConfigInfo``.
    """
    # Attributes copied verbatim from each config field; the tuple order is
    # also the key order of the emitted dicts.
    field_keys = ("name", "type", "default", "description", "min", "max", "options")

    return StageConfigInfo(
        name=stage.name,
        label=stage.label,
        description=stage.description,
        category=stage.category,
        config_fields=[
            {key: getattr(spec, key) for key in field_keys}
            for spec in stage.config_fields
        ],
        reads=stage.io.reads,
        writes=stage.io.writes,
    )
@router.get("/scenarios", response_model=list[ScenarioInfo])
def list_scenarios_endpoint():
    """Return every bookmarked checkpoint (scenario) as a ``ScenarioInfo``.

    ``frame_count`` is the number of entries in the checkpoint's frames
    manifest (0 when the manifest is missing or empty) and ``created_at`` is
    the stringified timestamp, or ``""`` when the row has none.
    """
    from core.db.detect import list_scenarios

    def _as_info(row) -> ScenarioInfo:
        # A missing manifest counts as zero frames.
        frame_paths = row.frames_manifest or {}
        return ScenarioInfo(
            job_id=str(row.job_id),
            stage=row.stage,
            scenario_label=row.scenario_label,
            profile_name=row.profile_name,
            video_path=row.video_path,
            frame_count=len(frame_paths),
            created_at=str(row.created_at) if row.created_at else "",
        )

    return [_as_info(row) for row in list_scenarios()]
def _gpu_url() -> str:
    """Return the GPU inference server base URL without a trailing slash.

    The value comes from the ``INFERENCE_URL`` environment variable. When the
    variable is unset, empty, or only whitespace, the local default
    ``http://localhost:8000`` is used instead, so the proxy endpoints never
    build a schemeless target such as ``/detect_edges``.

    Returns:
        The base URL with surrounding whitespace and trailing ``/`` removed.
    """
    # `or` (rather than a .get() default) also covers INFERENCE_URL="" as
    # commonly left behind by an unfilled .env entry.
    configured = os.environ.get("INFERENCE_URL", "").strip()
    return (configured or "http://localhost:8000").rstrip("/")
def list_scenarios() -> list[StageCheckpoint]:
    """List all checkpoints marked as scenarios.

    Selects every ``StageCheckpoint`` row whose ``is_scenario`` flag is set,
    ordered by ``created_at`` descending, and returns them as a plain list.

    Returns:
        The matching rows; an empty list when no checkpoint is flagged.
    """
    with get_session() as session:
        stmt = (
            select(StageCheckpoint)
            # `== True` is intentional: on a mapped column it builds a SQL
            # comparison expression instead of evaluating a Python bool.
            .where(StageCheckpoint.is_scenario == True)  # noqa: E712
            .order_by(StageCheckpoint.created_at.desc())
        )
        # Materialise the result while the session is still open.
        # NOTE(review): the rows are handed back after the session context
        # exits — confirm callers only read attributes that are already loaded.
        return list(session.exec(stmt).all())
video_path: str = "" profile_name: str = "" + is_scenario: bool = False + scenario_label: str = "" created_at: Optional[datetime] = Field(default_factory=datetime.utcnow) class KnownBrand(SQLModel, table=True): diff --git a/core/schema/modelgen.json b/core/schema/modelgen.json index 37658ca..9783025 100644 --- a/core/schema/modelgen.json +++ b/core/schema/modelgen.json @@ -40,6 +40,11 @@ "target": "typescript", "output": "ui/detection-app/src/types/store-state.ts", "include": ["ui_state_views"] + }, + { + "target": "pydantic", + "output": "gpu/models/inference_contract.py", + "include": ["inference_views"] } ] } diff --git a/core/schema/models/__init__.py b/core/schema/models/__init__.py index 8222f9b..b8bdd48 100644 --- a/core/schema/models/__init__.py +++ b/core/schema/models/__init__.py @@ -33,6 +33,7 @@ from .detect_jobs import ( from .media import AssetStatus, MediaAsset from .presets import BUILTIN_PRESETS, TranscodePreset from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader +from .inference import INFERENCE_VIEWS # noqa: F401 — GPU inference server API types from .ui_state import UI_STATE_VIEWS # noqa: F401 — UI store state types from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent from .sources import ChunkInfo, SourceJob, SourceType diff --git a/core/schema/models/detect.py b/core/schema/models/detect.py index 33c5325..67ec246 100644 --- a/core/schema/models/detect.py +++ b/core/schema/models/detect.py @@ -53,6 +53,7 @@ class BoundingBoxEvent: label: str resolved_brand: Optional[str] = None source: Optional[str] = None + stage: Optional[str] = None @dataclass @@ -85,6 +86,7 @@ class StatsUpdate: frames_extracted: int = 0 frames_after_scene_filter: int = 0 + cv_regions_detected: int = 0 regions_detected: int = 0 regions_resolved_by_ocr: int = 0 regions_escalated_to_local_vlm: int = 0 @@ -166,6 +168,8 @@ class CheckpointInfo: """Available checkpoint for a stage.""" stage: str + is_scenario: bool = False 
+ scenario_label: str = "" @dataclass diff --git a/core/schema/models/detect_jobs.py b/core/schema/models/detect_jobs.py index 72a8232..3354c68 100644 --- a/core/schema/models/detect_jobs.py +++ b/core/schema/models/detect_jobs.py @@ -93,13 +93,12 @@ class StageCheckpoint: frames_meta: List[Dict[str, Any]] = field(default_factory=list) # sequence, chunk_id, timestamp, hash filtered_frame_sequences: List[int] = field(default_factory=list) - # Detection state (full structured data, not just summaries) - boxes_by_frame: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict) - text_candidates: List[Dict[str, Any]] = field(default_factory=list) - unresolved_candidates: List[Dict[str, Any]] = field(default_factory=list) - detections: List[Dict[str, Any]] = field(default_factory=list) + # Stage output — stored as blob in MinIO: checkpoints/{job_id}/stages/{stage}.bson + # Each stage's serialize_fn/deserialize_fn owns the format. + # Postgres only stores the S3 key, not the data itself. + stage_output_key: str = "" # s3 key to the serialized stage output - # Pipeline state + # Pipeline state (small, stays in Postgres) stats: Dict[str, Any] = field(default_factory=dict) config_snapshot: Dict[str, Any] = field(default_factory=dict) config_overrides: Dict[str, Any] = field(default_factory=dict) @@ -108,6 +107,13 @@ class StageCheckpoint: video_path: str = "" profile_name: str = "" + # Scenario — a checkpoint bookmarked for the editor workflow. + # Created by seeders (manual scripts that populate state from real footage) + # or captured from a running pipeline. Loaded via URL: + # /detection/?job=&stage=&editor=true + is_scenario: bool = False + scenario_label: str = "" # human-readable name, e.g. 
"chelsea_edges_lowcanny" + # Timestamps created_at: Optional[datetime] = None diff --git a/core/schema/models/detect_pipeline.py b/core/schema/models/detect_pipeline.py index e3673be..84cdd57 100644 --- a/core/schema/models/detect_pipeline.py +++ b/core/schema/models/detect_pipeline.py @@ -70,6 +70,7 @@ class BrandStats: class PipelineStats: frames_extracted: int = 0 frames_after_scene_filter: int = 0 + cv_regions_detected: int = 0 regions_detected: int = 0 regions_resolved_by_ocr: int = 0 regions_escalated_to_local_vlm: int = 0 diff --git a/core/schema/models/inference.py b/core/schema/models/inference.py new file mode 100644 index 0000000..117c129 --- /dev/null +++ b/core/schema/models/inference.py @@ -0,0 +1,197 @@ +""" +Inference Server API Schema Definitions + +Source of truth for GPU inference server request/response types. +Generates: Pydantic (gpu/models/inference_contract.py) + +These are the wire-format types for the HTTP API between the +pipeline (detect/) and the inference server (gpu/). 
+""" + +from dataclasses import dataclass, field +from typing import List, Optional + + +# --- Object Detection (YOLO) --- + + +@dataclass +class DetectRequest: + """Request body for object detection.""" + + image: str # base64 JPEG + model: Optional[str] = None + confidence: Optional[float] = None + target_classes: Optional[List[str]] = None + + +@dataclass +class BBox: + """A detected bounding box.""" + + x: int + y: int + w: int + h: int + confidence: float + label: str + + +@dataclass +class DetectResponse: + """Response from object detection.""" + + detections: List[BBox] = field(default_factory=list) + + +# --- OCR --- + + +@dataclass +class OCRRequest: + """Request body for OCR.""" + + image: str # base64 JPEG + languages: Optional[List[str]] = None + + +@dataclass +class OCRTextResult: + """A single OCR text extraction result.""" + + text: str + confidence: float + bbox: List[int] = field(default_factory=list) # [x, y, w, h] + + +@dataclass +class OCRResponse: + """Response from OCR.""" + + results: List[OCRTextResult] = field(default_factory=list) + + +# --- Preprocessing --- + + +@dataclass +class PreprocessRequest: + """Request body for image preprocessing.""" + + image: str # base64 JPEG + binarize: bool = False + deskew: bool = False + contrast: bool = True + + +@dataclass +class PreprocessResponse: + """Response from preprocessing.""" + + image: str # base64 JPEG of processed image + + +# --- VLM --- + + +@dataclass +class VLMRequest: + """Request body for visual language model query.""" + + image: str # base64 JPEG + prompt: str + model: Optional[str] = None + + +@dataclass +class VLMResponse: + """Response from VLM.""" + + brand: str + confidence: float + reasoning: str + + +# --- CV Region Analysis --- + + +@dataclass +class AnalyzeRegionsRequest: + """Request body for CV region analysis.""" + + image: str # base64 JPEG + # Edge detection (Canny + HoughLinesP) + edge_canny_low: int = 50 + edge_canny_high: int = 150 + edge_hough_threshold: int = 80 
+ edge_hough_min_length: int = 100 + edge_hough_max_gap: int = 10 + edge_pair_max_distance: int = 200 + edge_pair_min_distance: int = 15 + + +@dataclass +class RegionBox: + """A candidate region from CV analysis.""" + + x: int + y: int + w: int + h: int + confidence: float + label: str + + +@dataclass +class AnalyzeRegionsResponse: + """Response from CV region analysis.""" + + regions: List[RegionBox] = field(default_factory=list) + + +@dataclass +class AnalyzeRegionsDebugResponse: + """Response from CV region analysis with debug overlays.""" + + regions: List[RegionBox] = field(default_factory=list) + edge_overlay_b64: str = "" # Canny edge image as base64 JPEG + lines_overlay_b64: str = "" # frame with Hough lines drawn + horizontal_count: int = 0 + pair_count: int = 0 + + +# --- Server Config --- + + +@dataclass +class ConfigUpdate: + """Request body for updating server configuration.""" + + device: Optional[str] = None + yolo_model: Optional[str] = None + yolo_confidence: Optional[float] = None + vram_budget_mb: Optional[int] = None + strategy: Optional[str] = None + ocr_languages: Optional[List[str]] = None + ocr_min_confidence: Optional[float] = None + + +# --- Export list for modelgen --- + +INFERENCE_VIEWS = [ + DetectRequest, + BBox, + DetectResponse, + OCRRequest, + OCRTextResult, + OCRResponse, + PreprocessRequest, + PreprocessResponse, + VLMRequest, + VLMResponse, + AnalyzeRegionsRequest, + RegionBox, + AnalyzeRegionsResponse, + AnalyzeRegionsDebugResponse, + ConfigUpdate, +] diff --git a/ctrl/.env.template b/ctrl/.env.template index 8492437..de1f0a5 100644 --- a/ctrl/.env.template +++ b/ctrl/.env.template @@ -1,13 +1,13 @@ # MPR Environment Configuration # Copy to .env and adjust values as needed -# Database +# Database (must match ctrl/k8s/base/postgres.yaml configmap) POSTGRES_DB=mpr -POSTGRES_USER=mpr_user -POSTGRES_PASSWORD=mpr_pass +POSTGRES_USER=mpr +POSTGRES_PASSWORD=mpr POSTGRES_HOST=postgres POSTGRES_PORT=5432 
-DATABASE_URL=postgresql://mpr_user:mpr_pass@postgres:5432/mpr +DATABASE_URL=postgresql://mpr:mpr@postgres:5432/mpr # Redis REDIS_HOST=redis @@ -27,6 +27,9 @@ GRPC_HOST=grpc GRPC_PORT=50051 GRPC_MAX_WORKERS=10 +# Media — host path for kind cluster mount +MEDIA_HOST_PATH=/home/you/wdir/mpr/media + # S3 Storage (MinIO locally, real S3 on AWS) # In k8s/docker: http://minio:9000 # On dev machine (port-forward): http://localhost:9000 diff --git a/ctrl/Tiltfile b/ctrl/Tiltfile index b9511c3..43521e7 100644 --- a/ctrl/Tiltfile +++ b/ctrl/Tiltfile @@ -4,6 +4,10 @@ allow_k8s_contexts('kind-mpr') +# Create namespace first — kustomize includes it but Tilt may apply +# all resources in parallel, causing "namespace not found" races +local('kubectl create namespace mpr --dry-run=client -o yaml | kubectl apply -f -') + # Apply k8s manifests via kustomize (dev overlay) k8s_yaml(kustomize('k8s/overlays/dev')) diff --git a/ctrl/kind-create.sh b/ctrl/kind-create.sh index f7fdea1..8cbc870 100755 --- a/ctrl/kind-create.sh +++ b/ctrl/kind-create.sh @@ -1,11 +1,17 @@ #!/bin/bash # Create the kind cluster with host media mount. -# Usage: MEDIA_HOST_PATH=/home/you/mpr/media ./kind-create.sh +# Reads MEDIA_HOST_PATH from ctrl/.env or environment. set -euo pipefail -: "${MEDIA_HOST_PATH:?Set MEDIA_HOST_PATH to your local media directory (e.g. 
# Load MEDIA_HOST_PATH from ctrl/.env when the caller has not already set it.
#
# The value is parsed instead of being passed through `export $(... | xargs)`:
#   * with no matching line that form degenerates to a bare `export`, which
#     prints every exported variable (secrets included) to stdout;
#   * the unquoted expansion word-splits and globs, so a media path that
#     contains spaces was silently truncated.
if [[ -z "${MEDIA_HOST_PATH:-}" && -f "$SCRIPT_DIR/.env" ]]; then
    # Last assignment wins, as it would if the file were sourced. `|| true`
    # keeps `set -e` / `pipefail` from aborting when the key is absent; the
    # `:?` check that follows still reports the missing variable.
    env_line="$(grep -E '^MEDIA_HOST_PATH=' "$SCRIPT_DIR/.env" | tail -n 1 || true)"
    if [[ -n "$env_line" ]]; then
        media_path="${env_line#MEDIA_HOST_PATH=}"
        # Drop one pair of matching surrounding quotes, mirroring what xargs did.
        for quote in '"' "'"; do
            if [[ "$media_path" == "$quote"*"$quote" ]]; then
                media_path="${media_path:1:${#media_path}-2}"
            fi
        done
        export MEDIA_HOST_PATH="$media_path"
    fi
fi
+ + For detect_edges: returns {"edge_regions_by_frame": {seq: [box, ...]}} + With debug=True, also returns {"debug": {seq: {edge_overlay_b64, lines_overlay_b64, ...}}} + """ + if stage not in NODES: + raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}") + + stage_idx = NODES.index(stage) + if stage_idx == 0: + raise ValueError("Cannot replay the first stage — just run the full pipeline") + + previous_stage = NODES[stage_idx - 1] + + available = list_checkpoints(job_id) + if previous_stage not in available: + raise ValueError( + f"No checkpoint for stage {previous_stage!r} (job {job_id}). " + f"Available: {available}" + ) + + logger.info("Single-stage replay: job %s, stage %s (loading checkpoint: %s, debug=%s)", + job_id, stage, previous_stage, debug) + + state = load_checkpoint(job_id, previous_stage) + + # Build profile with overrides + from detect.profiles import get_profile + profile = get_profile(state.get("profile_name", "soccer_broadcast")) + if config_overrides: + profile = OverrideProfile(profile, config_overrides) + + # Run the stage function directly (not through the graph) + if stage == "detect_edges": + return _replay_detect_edges(state, profile, frame_refs, job_id, debug) + else: + raise ValueError( + f"Single-stage replay not yet implemented for {stage!r}. " + f"Use replay_from() for full pipeline replay." 
def _replay_detect_edges(
    state: dict,
    profile,
    frame_refs: list[int] | None,
    job_id: str,
    debug: bool,
) -> dict:
    """Re-run edge detection on frames restored from a checkpoint.

    Args:
        state: Restored pipeline state; frames are read from
            ``state["filtered_frames"]``.
        profile: Profile (possibly override-wrapped) supplying
            ``region_analysis_config()``.
        frame_refs: Frame sequence numbers to keep. ``None`` or an empty list
            keeps every frame.
        job_id: Job identifier forwarded to the detector / inference client.
        debug: When true and at least one frame is selected, per-frame debug
            overlays are collected as well.

    Returns:
        ``{"edge_regions_by_frame": ...}`` and, when overlays were collected,
        a ``"debug"`` entry keyed by frame sequence.
    """
    import os
    from detect.stages.edge_detector import detect_edge_regions

    # Tunables shared by both debug back-ends. The config and the remote
    # client use the "edge_"-prefixed names; the local CV module takes the
    # bare names.
    tunables = (
        "canny_low",
        "canny_high",
        "hough_threshold",
        "hough_min_length",
        "hough_max_gap",
        "pair_max_distance",
        "pair_min_distance",
    )
    overlay_keys = (
        "edge_overlay_b64",
        "lines_overlay_b64",
        "horizontal_count",
        "pair_count",
    )

    config = profile.region_analysis_config()
    selected = state.get("filtered_frames", [])
    if frame_refs:
        wanted = set(frame_refs)
        selected = [fr for fr in selected if fr.sequence in wanted]

    inference_url = os.environ.get("INFERENCE_URL")

    # The regular detection pass always runs — it produces the boxes.
    output = {
        "edge_regions_by_frame": detect_edge_regions(
            frames=selected,
            config=config,
            inference_url=inference_url,
            job_id=job_id,
        )
    }

    if not (debug and selected):
        return output

    overlays = {}
    if inference_url:
        # Remote mode: ask the inference server's debug endpoint per frame.
        from detect.inference import InferenceClient
        client = InferenceClient(base_url=inference_url, job_id=job_id)
        for fr in selected:
            reply = client.detect_edges_debug(
                image=fr.image,
                **{f"edge_{name}": getattr(config, f"edge_{name}") for name in tunables},
            )
            overlays[fr.sequence] = {key: getattr(reply, key) for key in overlay_keys}
    else:
        # Local mode: call the CV edge module directly; it returns a mapping.
        from detect.stages.edge_detector import _load_cv_edges
        edges_mod = _load_cv_edges()
        for fr in selected:
            reply = edges_mod.detect_edges_debug(
                fr.image,
                **{name: getattr(config, f"edge_{name}") for name in tunables},
            )
            overlays[fr.sequence] = {key: reply[key] for key in overlay_keys}

    output["debug"] = overlays
    return output
+Structured data (stage output, stats, config) → Postgres """ from __future__ import annotations -import json import logging -import os -import tempfile -from pathlib import Path -from .frames import save_frames, load_frames, BUCKET, CHECKPOINT_PREFIX +from .frames import save_frames, load_frames, CHECKPOINT_PREFIX from .serializer import serialize_state, deserialize_state logger = logging.getLogger(__name__) -def _has_db() -> bool: - """Check if Postgres is reachable.""" - try: - from core.db.connection import get_session - from sqlmodel import text - with get_session() as session: - session.exec(text("SELECT 1")) - return True - except Exception: - return False - - # --------------------------------------------------------------------------- # Save # --------------------------------------------------------------------------- @@ -45,34 +25,24 @@ def save_checkpoint( stage_index: int, state: dict, frames_manifest: dict[int, str] | None = None, + is_scenario: bool = False, + scenario_label: str = "", ) -> str: """ Save a stage checkpoint. Saves frame images to S3 (if not already saved), then persists - structured state to Postgres (or S3 JSON fallback). + structured state to Postgres. - Returns the checkpoint identifier (DB id or S3 key). + Returns the checkpoint DB id. 
""" - # Save frames to S3 if no manifest provided + from core.db.detect import save_stage_checkpoint + if frames_manifest is None: all_frames = state.get("frames", []) frames_manifest = save_frames(job_id, all_frames) checkpoint_data = serialize_state(state, frames_manifest) - - if _has_db(): - checkpoint_id = _save_to_db(job_id, stage, stage_index, checkpoint_data) - else: - checkpoint_id = _save_to_s3(job_id, stage, checkpoint_data) - - return checkpoint_id - - -def _save_to_db(job_id: str, stage: str, stage_index: int, data: dict) -> str: - """Save checkpoint structured data to Postgres.""" - from core.db.detect import save_stage_checkpoint - frames_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/frames/" checkpoint = save_stage_checkpoint( @@ -80,44 +50,24 @@ def _save_to_db(job_id: str, stage: str, stage_index: int, data: dict) -> str: stage=stage, stage_index=stage_index, frames_prefix=frames_prefix, - frames_manifest=data.get("frames_manifest", {}), - frames_meta=data.get("frames_meta", []), - filtered_frame_sequences=data.get("filtered_frame_sequences", []), - boxes_by_frame=data.get("boxes_by_frame", {}), - text_candidates=data.get("text_candidates", []), - unresolved_candidates=data.get("unresolved_candidates", []), - detections=data.get("detections", []), - stats=data.get("stats", {}), - config_snapshot=data.get("config_overrides", {}), - config_overrides=data.get("config_overrides", {}), - video_path=data.get("video_path", ""), - profile_name=data.get("profile_name", ""), + frames_manifest=checkpoint_data.get("frames_manifest", {}), + frames_meta=checkpoint_data.get("frames_meta", []), + filtered_frame_sequences=checkpoint_data.get("filtered_frame_sequences", []), + stage_output_key=checkpoint_data.get("stage_output_key", ""), + stats=checkpoint_data.get("stats", {}), + config_snapshot=checkpoint_data.get("config_overrides", {}), + config_overrides=checkpoint_data.get("config_overrides", {}), + video_path=checkpoint_data.get("video_path", ""), + 
profile_name=checkpoint_data.get("profile_name", ""), + is_scenario=is_scenario, + scenario_label=scenario_label, ) - logger.info("Checkpoint saved to DB: %s/%s (id=%s)", job_id, stage, checkpoint.id) + logger.info("Checkpoint saved: %s/%s (id=%s, scenario=%s)", + job_id, stage, checkpoint.id, is_scenario) return str(checkpoint.id) -def _save_to_s3(job_id: str, stage: str, data: dict) -> str: - """Fallback: save checkpoint as JSON to S3 (before modelgen generates DB models).""" - from core.storage.s3 import upload_file - - checkpoint_json = json.dumps(data, default=str) - key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json" - - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp: - tmp.write(checkpoint_json) - tmp_path = tmp.name - - try: - upload_file(tmp_path, BUCKET, key) - finally: - os.unlink(tmp_path) - - logger.info("Checkpoint saved to S3: s3://%s/%s", BUCKET, key) - return key - - # --------------------------------------------------------------------------- # Load # --------------------------------------------------------------------------- @@ -125,30 +75,12 @@ def _save_to_s3(job_id: str, stage: str, data: dict) -> str: def load_checkpoint(job_id: str, stage: str) -> dict: """ Load a stage checkpoint and reconstitute full DetectState. - - Tries Postgres first, falls back to S3 JSON. 
""" - if _has_db(): - data = _load_from_db(job_id, stage) - else: - data = _load_from_s3(job_id, stage) - - raw_manifest = data.get("frames_manifest", {}) - manifest = {int(k): v for k, v in raw_manifest.items()} - frame_metadata = data.get("frames_meta", []) - frames = load_frames(manifest, frame_metadata) - - state = deserialize_state(data, frames) - - logger.info("Checkpoint loaded: %s/%s (%d frames)", job_id, stage, len(frames)) - return state - - -def _load_from_db(job_id: str, stage: str) -> dict: - """Load checkpoint data from Postgres via core/db.""" from core.db.detect import get_stage_checkpoint checkpoint = get_stage_checkpoint(job_id, stage) + if not checkpoint: + raise ValueError(f"No checkpoint for {job_id}/{stage}") data = { "job_id": str(checkpoint.job_id), @@ -158,28 +90,20 @@ def _load_from_db(job_id: str, stage: str) -> dict: "frames_manifest": checkpoint.frames_manifest, "frames_meta": checkpoint.frames_meta, "filtered_frame_sequences": checkpoint.filtered_frame_sequences, - "boxes_by_frame": checkpoint.boxes_by_frame, - "text_candidates": checkpoint.text_candidates, - "unresolved_candidates": checkpoint.unresolved_candidates, - "detections": checkpoint.detections, + "stage_output_key": checkpoint.stage_output_key, "stats": checkpoint.stats, } - return data + raw_manifest = data.get("frames_manifest", {}) + manifest = {int(k): v for k, v in raw_manifest.items()} + frame_metadata = data.get("frames_meta", []) + frames = load_frames(manifest, frame_metadata) -def _load_from_s3(job_id: str, stage: str) -> dict: - """Fallback: load checkpoint JSON from S3.""" - from core.storage.s3 import download_to_temp + state = deserialize_state(data, frames) - key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json" - tmp_path = download_to_temp(BUCKET, key) - try: - with open(tmp_path) as f: - data = json.load(f) - finally: - os.unlink(tmp_path) - - return data + logger.info("Checkpoint loaded: %s/%s (%d frames, scenario=%s)", + job_id, stage, len(frames), 
checkpoint.is_scenario) + return state # --------------------------------------------------------------------------- @@ -188,25 +112,5 @@ def _load_from_s3(job_id: str, stage: str) -> dict: def list_checkpoints(job_id: str) -> list[str]: """List available checkpoint stages for a job.""" - if _has_db(): - return _list_from_db(job_id) - return _list_from_s3(job_id) - - -def _list_from_db(job_id: str) -> list[str]: from core.db.detect import list_stage_checkpoints return list_stage_checkpoints(job_id) - - -def _list_from_s3(job_id: str) -> list[str]: - from core.storage.s3 import list_objects - - prefix = f"{CHECKPOINT_PREFIX}/{job_id}/stages/" - objects = list_objects(BUCKET, prefix) - - stages = [] - for obj in objects: - name = Path(obj["key"]).stem - stages.append(name) - - return stages diff --git a/detect/graph.py b/detect/graph.py index cac13d8..5bfb603 100644 --- a/detect/graph.py +++ b/detect/graph.py @@ -17,6 +17,7 @@ from detect.profiles import SoccerBroadcastProfile from detect.state import DetectState from detect.stages.frame_extractor import extract_frames from detect.stages.scene_filter import scene_filter +from detect.stages.edge_detector import detect_edge_regions from detect.stages.yolo_detector import detect_objects from detect.stages.preprocess import preprocess_regions from detect.stages.ocr_stage import run_ocr @@ -31,6 +32,7 @@ INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode NODES = [ "extract_frames", "filter_scenes", + "detect_edges", "detect_objects", "preprocess", "run_ocr", @@ -119,6 +121,28 @@ def node_filter_scenes(state: DetectState) -> dict: return {"filtered_frames": kept, "stats": stats} +def node_detect_edges(state: DetectState) -> dict: + _emit_transition(state, "detect_edges", "running") + + with trace_node(state, "detect_edges") as span: + profile = _get_profile(state) + config = profile.region_analysis_config() + frames = state.get("filtered_frames", []) + job_id = state.get("job_id") + + regions = 
detect_edge_regions( + frames, config, inference_url=INFERENCE_URL, job_id=job_id, + ) + total = sum(len(r) for r in regions.values()) + span.set_output({"frames": len(frames), "edge_regions": total}) + + stats = state.get("stats", PipelineStats()) + stats.cv_regions_detected = total + + _emit_transition(state, "detect_edges", "done") + return {"edge_regions_by_frame": regions, "stats": stats} + + def node_detect_objects(state: DetectState) -> dict: _emit_transition(state, "detect_objects", "running") @@ -359,6 +383,7 @@ def _checkpointing_node(node_name: str, node_fn): NODE_FUNCTIONS = [ ("extract_frames", node_extract_frames), ("filter_scenes", node_filter_scenes), + ("detect_edges", node_detect_edges), ("detect_objects", node_detect_objects), ("preprocess", node_preprocess), ("run_ocr", node_run_ocr), diff --git a/detect/inference/client.py b/detect/inference/client.py index f75db73..2f1bb90 100644 --- a/detect/inference/client.py +++ b/detect/inference/client.py @@ -16,7 +16,7 @@ import numpy as np import requests from PIL import Image -from .types import DetectResult, OCRResult, ServerStatus, VLMResult +from .types import DetectResult, OCRResult, RegionDebugResult, RegionResult, ServerStatus, VLMResult logger = logging.getLogger(__name__) @@ -145,6 +145,92 @@ class InferenceClient: reasoning=data.get("reasoning", ""), ) + def detect_edges( + self, + image: np.ndarray, + edge_canny_low: int = 50, + edge_canny_high: int = 150, + edge_hough_threshold: int = 80, + edge_hough_min_length: int = 100, + edge_hough_max_gap: int = 10, + edge_pair_max_distance: int = 200, + edge_pair_min_distance: int = 15, + ) -> list[RegionResult]: + """Run edge detection on an image.""" + payload = { + "image": _encode_image(image), + "edge_canny_low": edge_canny_low, + "edge_canny_high": edge_canny_high, + "edge_hough_threshold": edge_hough_threshold, + "edge_hough_min_length": edge_hough_min_length, + "edge_hough_max_gap": edge_hough_max_gap, + "edge_pair_max_distance": 
edge_pair_max_distance, + "edge_pair_min_distance": edge_pair_min_distance, + } + + resp = self.session.post( + f"{self.base_url}/detect_edges", + json=payload, + timeout=self.timeout, + ) + resp.raise_for_status() + + results = [] + for r in resp.json().get("regions", []): + result = RegionResult( + x=r["x"], y=r["y"], w=r["w"], h=r["h"], + confidence=r["confidence"], label=r["label"], + ) + results.append(result) + return results + + def detect_edges_debug( + self, + image: np.ndarray, + edge_canny_low: int = 50, + edge_canny_high: int = 150, + edge_hough_threshold: int = 80, + edge_hough_min_length: int = 100, + edge_hough_max_gap: int = 10, + edge_pair_max_distance: int = 200, + edge_pair_min_distance: int = 15, + ) -> RegionDebugResult: + """Run edge detection with debug overlays.""" + payload = { + "image": _encode_image(image), + "edge_canny_low": edge_canny_low, + "edge_canny_high": edge_canny_high, + "edge_hough_threshold": edge_hough_threshold, + "edge_hough_min_length": edge_hough_min_length, + "edge_hough_max_gap": edge_hough_max_gap, + "edge_pair_max_distance": edge_pair_max_distance, + "edge_pair_min_distance": edge_pair_min_distance, + } + + resp = self.session.post( + f"{self.base_url}/detect_edges/debug", + json=payload, + timeout=self.timeout, + ) + resp.raise_for_status() + + data = resp.json() + regions = [] + for r in data.get("regions", []): + region = RegionResult( + x=r["x"], y=r["y"], w=r["w"], h=r["h"], + confidence=r["confidence"], label=r["label"], + ) + regions.append(region) + + return RegionDebugResult( + regions=regions, + edge_overlay_b64=data.get("edge_overlay_b64", ""), + lines_overlay_b64=data.get("lines_overlay_b64", ""), + horizontal_count=data.get("horizontal_count", 0), + pair_count=data.get("pair_count", 0), + ) + def load_model(self, model: str, quantization: str = "fp16") -> None: """Request the server to load a model into VRAM.""" self.session.post( diff --git a/detect/inference/types.py b/detect/inference/types.py index 
ccb66c3..90e3ba7 100644 --- a/detect/inference/types.py +++ b/detect/inference/types.py @@ -38,6 +38,27 @@ class VLMResult: reasoning: str +@dataclass +class RegionResult: + """A candidate region from CV analysis.""" + x: int + y: int + w: int + h: int + confidence: float + label: str + + +@dataclass +class RegionDebugResult: + """CV region analysis with debug overlays.""" + regions: list[RegionResult] = field(default_factory=list) + edge_overlay_b64: str = "" + lines_overlay_b64: str = "" + horizontal_count: int = 0 + pair_count: int = 0 + + @dataclass class ModelInfo: """Info about a loaded model.""" diff --git a/detect/profiles/__init__.py b/detect/profiles/__init__.py index c77ed2e..4b21b5b 100644 --- a/detect/profiles/__init__.py +++ b/detect/profiles/__init__.py @@ -9,6 +9,19 @@ from .base import ( ) from .soccer import SoccerBroadcastProfile +_PROFILES: dict[str, type] = { + "soccer_broadcast": SoccerBroadcastProfile, +} + + +def get_profile(name: str) -> ContentTypeProfile: + """Get a profile instance by name.""" + cls = _PROFILES.get(name) + if cls is None: + raise ValueError(f"Unknown profile: {name!r}. 
Available: {list(_PROFILES)}") + return cls() + + __all__ = [ "ContentTypeProfile", "CropContext", @@ -18,4 +31,5 @@ __all__ = [ "ResolverConfig", "SceneFilterConfig", "SoccerBroadcastProfile", + "get_profile", ] diff --git a/detect/profiles/base.py b/detect/profiles/base.py index 00b6419..b0c4c03 100644 --- a/detect/profiles/base.py +++ b/detect/profiles/base.py @@ -44,6 +44,19 @@ class ResolverConfig: fuzzy_threshold: int = 75 +@dataclass +class RegionAnalysisConfig: + enabled: bool = True + # Edge detection (Canny + HoughLinesP) + edge_canny_low: int = 50 + edge_canny_high: int = 150 + edge_hough_threshold: int = 80 + edge_hough_min_length: int = 100 + edge_hough_max_gap: int = 10 + edge_pair_max_distance: int = 200 + edge_pair_min_distance: int = 15 + + @dataclass class CropContext: image: bytes @@ -56,6 +69,7 @@ class ContentTypeProfile(Protocol): def frame_extraction_config(self) -> FrameExtractionConfig: ... def scene_filter_config(self) -> SceneFilterConfig: ... + def region_analysis_config(self) -> RegionAnalysisConfig: ... def detection_config(self) -> DetectionConfig: ... def ocr_config(self) -> OCRConfig: ... def resolver_config(self) -> ResolverConfig: ... 
diff --git a/detect/profiles/soccer.py b/detect/profiles/soccer.py index 916b651..60e2a66 100644 --- a/detect/profiles/soccer.py +++ b/detect/profiles/soccer.py @@ -9,6 +9,7 @@ from .base import ( DetectionConfig, FrameExtractionConfig, OCRConfig, + RegionAnalysisConfig, ResolverConfig, SceneFilterConfig, ) @@ -23,6 +24,17 @@ class SoccerBroadcastProfile: def scene_filter_config(self) -> SceneFilterConfig: return SceneFilterConfig(hamming_threshold=8, enabled=True) + def region_analysis_config(self) -> RegionAnalysisConfig: + return RegionAnalysisConfig( + edge_canny_low=50, + edge_canny_high=150, + edge_hough_threshold=80, + edge_hough_min_length=100, + edge_hough_max_gap=10, + edge_pair_max_distance=200, + edge_pair_min_distance=15, + ) + def detection_config(self) -> DetectionConfig: return DetectionConfig( model_name="yolov8n.pt", diff --git a/detect/sse_contract.py b/detect/sse_contract.py index d5161d0..b772cc3 100644 --- a/detect/sse_contract.py +++ b/detect/sse_contract.py @@ -34,6 +34,7 @@ class BoundingBoxEvent(BaseModel): label: str resolved_brand: Optional[str] = None source: Optional[str] = None + stage: Optional[str] = None class BrandSummary(BaseModel): """Per-brand stats in the final report.""" @@ -54,6 +55,7 @@ class StatsUpdate(BaseModel): """Funnel statistics snapshot. SSE event: stats_update""" frames_extracted: int = 0 frames_after_scene_filter: int = 0 + cv_regions_detected: int = 0 regions_detected: int = 0 regions_resolved_by_ocr: int = 0 regions_escalated_to_local_vlm: int = 0 diff --git a/detect/stages/edge_detector.py b/detect/stages/edge_detector.py new file mode 100644 index 0000000..405f6cb --- /dev/null +++ b/detect/stages/edge_detector.py @@ -0,0 +1,174 @@ +""" +Stage — Edge Detection + +Canny + HoughLinesP to find horizontal line pairs that bound +advertising hoardings. Pure OpenCV, no ML models. 
+ +Two modes: + - Remote: calls GPU inference server over HTTP + - Local: imports cv2 directly (OpenCV on same machine) + +Emits frame_update events with bounding boxes for the frame viewer. +""" + +from __future__ import annotations + +import base64 +import io +import logging +import time + +from PIL import Image + +from detect import emit +from detect.models import BoundingBox, Frame +from detect.profiles.base import RegionAnalysisConfig + +logger = logging.getLogger(__name__) + + +def _frame_to_b64(frame: Frame) -> str: + """Encode frame as base64 JPEG for SSE frame_update events.""" + img = Image.fromarray(frame.image) + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=70) + return base64.b64encode(buf.getvalue()).decode() + + +def _detect_remote( + frame: Frame, + config: RegionAnalysisConfig, + inference_url: str, + job_id: str = "", + log_level: str = "INFO", +) -> list[BoundingBox]: + """Call the inference server over HTTP.""" + from detect.inference import InferenceClient + + client = InferenceClient( + base_url=inference_url, job_id=job_id, log_level=log_level, + ) + results = client.detect_edges( + image=frame.image, + edge_canny_low=config.edge_canny_low, + edge_canny_high=config.edge_canny_high, + edge_hough_threshold=config.edge_hough_threshold, + edge_hough_min_length=config.edge_hough_min_length, + edge_hough_max_gap=config.edge_hough_max_gap, + edge_pair_max_distance=config.edge_pair_max_distance, + edge_pair_min_distance=config.edge_pair_min_distance, + ) + boxes = [] + for r in results: + box = BoundingBox( + x=r.x, y=r.y, w=r.w, h=r.h, + confidence=r.confidence, label=r.label, + ) + boxes.append(box) + return boxes + + +_cv_edges_mod = None + + +def _load_cv_edges(): + """Load edges module directly — gpu/models/__init__.py has GPU-container-only imports.""" + global _cv_edges_mod + if _cv_edges_mod is None: + import importlib.util + from pathlib import Path + + spec = importlib.util.spec_from_file_location("cv_edges", 
Path("gpu/models/cv/edges.py")) + _cv_edges_mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(_cv_edges_mod) + return _cv_edges_mod + + +def _detect_local(frame: Frame, config: RegionAnalysisConfig) -> list[BoundingBox]: + """Run edge detection in-process (requires opencv-python).""" + detect_edges_fn = _load_cv_edges().detect_edges + + edge_results = detect_edges_fn( + frame.image, + canny_low=config.edge_canny_low, + canny_high=config.edge_canny_high, + hough_threshold=config.edge_hough_threshold, + hough_min_length=config.edge_hough_min_length, + hough_max_gap=config.edge_hough_max_gap, + pair_max_distance=config.edge_pair_max_distance, + pair_min_distance=config.edge_pair_min_distance, + ) + + boxes = [] + for r in edge_results: + box = BoundingBox( + x=r["x"], y=r["y"], w=r["w"], h=r["h"], + confidence=r["confidence"], label=r["label"], + ) + boxes.append(box) + return boxes + + +def detect_edge_regions( + frames: list[Frame], + config: RegionAnalysisConfig, + inference_url: str | None = None, + job_id: str | None = None, +) -> dict[int, list[BoundingBox]]: + """ + Run edge detection on all frames. + + Returns a dict mapping frame sequence → list of bounding boxes. 
+ """ + if not config.enabled: + emit.log(job_id, "EdgeDetection", "INFO", "Edge detection disabled, skipping") + return {} + + mode = "remote" if inference_url else "local" + emit.log(job_id, "EdgeDetection", "INFO", + f"Detecting edges in {len(frames)} frames (mode={mode})") + + all_boxes: dict[int, list[BoundingBox]] = {} + total_regions = 0 + + for i, frame in enumerate(frames): + t0 = time.monotonic() + if inference_url: + from detect.emit import _run_log_level + boxes = _detect_remote( + frame, config, inference_url, + job_id=job_id or "", log_level=_run_log_level, + ) + else: + boxes = _detect_local(frame, config) + analysis_ms = (time.monotonic() - t0) * 1000 + + all_boxes[frame.sequence] = boxes + total_regions += len(boxes) + + emit.log(job_id, "EdgeDetection", "DEBUG", + f"Frame {frame.sequence}: {len(boxes)} regions in {analysis_ms:.0f}ms" + + (f" [{', '.join(b.label for b in boxes)}]" if boxes else "")) + + if boxes and job_id: + box_dicts = [ + { + "x": b.x, "y": b.y, "w": b.w, "h": b.h, + "confidence": b.confidence, "label": b.label, + "stage": "detect_edges", + } + for b in boxes + ] + emit.frame_update( + job_id, + frame_ref=frame.sequence, + timestamp=frame.timestamp, + jpeg_b64=_frame_to_b64(frame), + boxes=box_dicts, + ) + + emit.log(job_id, "EdgeDetection", "INFO", + f"Found {total_regions} edge regions across {len(frames)} frames") + emit.stats(job_id, cv_regions_detected=total_regions) + + return all_boxes diff --git a/detect/stages/registry/__init__.py b/detect/stages/registry/__init__.py index 1e410df..8fd7ce6 100644 --- a/detect/stages/registry/__init__.py +++ b/detect/stages/registry/__init__.py @@ -3,6 +3,7 @@ Stage registry — registers all built-in stages. 
Split by category: preprocessing.py — extract_frames, filter_scenes + cv_analysis.py — detect_edges (+ future: detect_contours, detect_color, merge_regions) detection.py — detect_objects, run_ocr resolution.py — match_brands escalation.py — escalate_vlm, escalate_cloud @@ -11,6 +12,7 @@ Split by category: """ from . import preprocessing +from . import cv_analysis from . import detection from . import resolution from . import escalation @@ -19,6 +21,7 @@ from . import output def register_all(): preprocessing.register() + cv_analysis.register() detection.register() resolution.register() escalation.register() diff --git a/detect/stages/registry/cv_analysis.py b/detect/stages/registry/cv_analysis.py new file mode 100644 index 0000000..8ffb6c8 --- /dev/null +++ b/detect/stages/registry/cv_analysis.py @@ -0,0 +1,45 @@ +"""Registration for CV analysis stages: edge detection.""" + +from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage +from ._serializers import serialize_dataclass_list, deserialize_bounding_box + + +def _ser_regions(state: dict, job_id: str) -> dict: + regions = state.get("edge_regions_by_frame", {}) + serialized = { + str(seq): serialize_dataclass_list(bl) for seq, bl in regions.items() + } + return {"edge_regions_by_frame": serialized} + + +def _deser_regions(data: dict, job_id: str) -> dict: + regions = {} + for seq_str, box_dicts in data.get("edge_regions_by_frame", {}).items(): + regions[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts] + return {"edge_regions_by_frame": regions} + + +def register(): + edge_detection = StageDefinition( + name="detect_edges", + label="Edge Detection", + description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)", + category="cv_analysis", + io=StageIO( + reads=["filtered_frames"], + writes=["edge_regions_by_frame"], + ), + config_fields=[ + StageConfigField("enabled", "bool", True, "Enable region analysis"), + StageConfigField("edge_canny_low", 
"int", 50, "Canny low threshold", min=0, max=255), + StageConfigField("edge_canny_high", "int", 150, "Canny high threshold", min=0, max=255), + StageConfigField("edge_hough_threshold", "int", 80, "Hough accumulator threshold", min=1, max=500), + StageConfigField("edge_hough_min_length", "int", 100, "Min line length (px)", min=10, max=2000), + StageConfigField("edge_hough_max_gap", "int", 10, "Max line gap (px)", min=1, max=100), + StageConfigField("edge_pair_max_distance", "int", 200, "Max distance between line pair (px)", min=10, max=500), + StageConfigField("edge_pair_min_distance", "int", 15, "Min distance between line pair (px)", min=5, max=200), + ], + serialize_fn=_ser_regions, + deserialize_fn=_deser_regions, + ) + register_stage(edge_detection) diff --git a/detect/state.py b/detect/state.py index 0cf4b4a..bff6cbe 100644 --- a/detect/state.py +++ b/detect/state.py @@ -22,6 +22,7 @@ class DetectState(TypedDict, total=False): # Stage outputs frames: list[Frame] filtered_frames: list[Frame] + edge_regions_by_frame: dict[int, list[BoundingBox]] boxes_by_frame: dict[int, list[BoundingBox]] preprocessed_crops: dict # "{frame_seq}_{box_idx}" → np.ndarray text_candidates: list[TextCandidate] diff --git a/gpu/models/cv/__init__.py b/gpu/models/cv/__init__.py new file mode 100644 index 0000000..339bc4c --- /dev/null +++ b/gpu/models/cv/__init__.py @@ -0,0 +1 @@ +"""CV operations — pure OpenCV, no ML models.""" diff --git a/gpu/models/cv/edges.py b/gpu/models/cv/edges.py new file mode 100644 index 0000000..7a29f49 --- /dev/null +++ b/gpu/models/cv/edges.py @@ -0,0 +1,258 @@ +""" +Edge detection — Canny + HoughLinesP → parallel line pairs → bounding boxes. + +Finds horizontal line pairs with consistent spacing, which correspond to +the top and bottom edges of advertising hoardings. 
+""" + +from __future__ import annotations + +import base64 +import io + +import cv2 +import numpy as np + + +def detect_edges( + image: np.ndarray, + canny_low: int = 50, + canny_high: int = 150, + hough_threshold: int = 80, + hough_min_length: int = 100, + hough_max_gap: int = 10, + pair_max_distance: int = 200, + pair_min_distance: int = 15, +) -> list[dict]: + """ + Find horizontal line pairs that likely bound advertising hoardings. + + Returns list of dicts with keys: x, y, w, h, confidence, label. + Each box represents the region between a detected pair of parallel + horizontal lines. + """ + gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + edges = cv2.Canny(gray, canny_low, canny_high) + + raw_lines = cv2.HoughLinesP( + edges, + rho=1, + theta=np.pi / 180, + threshold=hough_threshold, + minLineLength=hough_min_length, + maxLineGap=hough_max_gap, + ) + + if raw_lines is None: + return [] + + # Filter to near-horizontal lines (within 10 degrees) + horizontals = _filter_horizontal(raw_lines, max_angle_deg=10) + + if len(horizontals) < 2: + return [] + + # Find pairs of parallel horizontals with consistent spacing + pairs = _find_line_pairs( + horizontals, + min_distance=pair_min_distance, + max_distance=pair_max_distance, + ) + + # Convert pairs to bounding boxes + h, w = image.shape[:2] + results = [] + for top_line, bottom_line in pairs: + box = _pair_to_bbox(top_line, bottom_line, frame_width=w, frame_height=h) + if box is not None: + results.append(box) + + return results + + +def _filter_horizontal(lines: np.ndarray, max_angle_deg: float = 10) -> list[tuple]: + """Keep only lines within max_angle_deg of horizontal.""" + max_slope = np.tan(np.radians(max_angle_deg)) + result = [] + for line in lines: + x1, y1, x2, y2 = line[0] + dx = x2 - x1 + if dx == 0: + continue + slope = abs((y2 - y1) / dx) + if slope <= max_slope: + y_mid = (y1 + y2) / 2 + x_min = min(x1, x2) + x_max = max(x1, x2) + length = np.sqrt(dx**2 + (y2 - y1) ** 2) + result.append((x_min, 
x_max, y_mid, length)) + return result + + +def _find_line_pairs( + horizontals: list[tuple], + min_distance: int, + max_distance: int, +) -> list[tuple]: + """ + Find pairs of horizontal lines that could be top/bottom of a hoarding. + + Lines must overlap horizontally and be spaced within [min_distance, max_distance]. + """ + # Sort by y position + sorted_lines = sorted(horizontals, key=lambda l: l[2]) + + pairs = [] + used = set() + + for i, top in enumerate(sorted_lines): + if i in used: + continue + for j, bottom in enumerate(sorted_lines[i + 1 :], start=i + 1): + if j in used: + continue + + y_gap = bottom[2] - top[2] + if y_gap < min_distance: + continue + if y_gap > max_distance: + break # sorted by y, no point checking further + + # Check horizontal overlap + overlap_start = max(top[0], bottom[0]) + overlap_end = min(top[1], bottom[1]) + overlap = overlap_end - overlap_start + + # Require at least 50% overlap relative to shorter line + shorter_length = min(top[1] - top[0], bottom[1] - bottom[0]) + if shorter_length > 0 and overlap / shorter_length >= 0.5: + pairs.append((top, bottom)) + used.add(i) + used.add(j) + break + + return pairs + + +def _pair_to_bbox( + top: tuple, + bottom: tuple, + frame_width: int, + frame_height: int, +) -> dict | None: + """Convert a line pair to a bounding box dict.""" + x = int(max(0, min(top[0], bottom[0]))) + y = int(max(0, top[2])) + x2 = int(min(frame_width, max(top[1], bottom[1]))) + y2 = int(min(frame_height, bottom[2])) + w = x2 - x + h = y2 - y + + if w < 20 or h < 5: + return None + + # Confidence based on line lengths relative to box width + avg_line_length = (top[3] + bottom[3]) / 2 + coverage = min(1.0, avg_line_length / max(w, 1)) + + return { + "x": x, + "y": y, + "w": w, + "h": h, + "confidence": round(coverage, 3), + "label": "edge_region", + } + + +def _np_to_b64_jpeg(image: np.ndarray, quality: int = 70) -> str: + """Encode a numpy image (BGR or grayscale) as base64 JPEG.""" + ok, buf = cv2.imencode(".jpg", 
image, [cv2.IMWRITE_JPEG_QUALITY, quality]) + if not ok: + return "" + return base64.b64encode(buf.tobytes()).decode() + + +def detect_edges_debug( + image: np.ndarray, + canny_low: int = 50, + canny_high: int = 150, + hough_threshold: int = 80, + hough_min_length: int = 100, + hough_max_gap: int = 10, + pair_max_distance: int = 200, + pair_min_distance: int = 15, +) -> dict: + """ + Same as detect_edges but returns intermediate visualizations. + + Returns dict with: + regions: list[dict] — same boxes as detect_edges + edge_overlay_b64: str — Canny edge image as base64 JPEG + lines_overlay_b64: str — frame with Hough lines drawn + horizontal_count: int — number of horizontal lines found + pair_count: int — number of line pairs found + """ + gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + edges = cv2.Canny(gray, canny_low, canny_high) + + # Edge overlay — Canny output as-is (white edges on black) + edge_overlay_b64 = _np_to_b64_jpeg(edges) + + raw_lines = cv2.HoughLinesP( + edges, + rho=1, + theta=np.pi / 180, + threshold=hough_threshold, + minLineLength=hough_min_length, + maxLineGap=hough_max_gap, + ) + + # Lines overlay — draw all Hough lines on a copy of the frame + lines_vis = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + if raw_lines is not None: + for line in raw_lines: + x1, y1, x2, y2 = line[0] + cv2.line(lines_vis, (x1, y1), (x2, y2), (0, 0, 255), 1) + + horizontals = [] + if raw_lines is not None: + horizontals = _filter_horizontal(raw_lines, max_angle_deg=10) + + # Draw horizontal lines in cyan, thicker + for h_line in horizontals: + x_min, x_max, y_mid, _ = h_line + cv2.line(lines_vis, (int(x_min), int(y_mid)), (int(x_max), int(y_mid)), (255, 255, 0), 2) + + pairs = [] + if len(horizontals) >= 2: + pairs = _find_line_pairs( + horizontals, + min_distance=pair_min_distance, + max_distance=pair_max_distance, + ) + + # Draw paired lines in green + for top_line, bottom_line in pairs: + cv2.line(lines_vis, (int(top_line[0]), int(top_line[2])), + 
(int(top_line[1]), int(top_line[2])), (0, 255, 0), 2) + cv2.line(lines_vis, (int(bottom_line[0]), int(bottom_line[2])), + (int(bottom_line[1]), int(bottom_line[2])), (0, 255, 0), 2) + + lines_overlay_b64 = _np_to_b64_jpeg(lines_vis) + + # Build region boxes (same logic as detect_edges) + h, w = image.shape[:2] + regions = [] + for top_line, bottom_line in pairs: + box = _pair_to_bbox(top_line, bottom_line, frame_width=w, frame_height=h) + if box is not None: + regions.append(box) + + return { + "regions": regions, + "edge_overlay_b64": edge_overlay_b64, + "lines_overlay_b64": lines_overlay_b64, + "horizontal_count": len(horizontals), + "pair_count": len(pairs), + } diff --git a/gpu/models/inference_contract.py b/gpu/models/inference_contract.py new file mode 100644 index 0000000..5f63639 --- /dev/null +++ b/gpu/models/inference_contract.py @@ -0,0 +1,112 @@ +""" +Pydantic Models - GENERATED FILE + +Do not edit directly. Regenerate using modelgen. +""" + +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional +from uuid import UUID + +from pydantic import BaseModel, Field + +class DetectRequest(BaseModel): + """Request body for object detection.""" + image: str + model: Optional[str] = None + confidence: Optional[float] = None + target_classes: Optional[List[str]] = None + +class BBox(BaseModel): + """A detected bounding box.""" + x: int + y: int + w: int + h: int + confidence: float + label: str + +class DetectResponse(BaseModel): + """Response from object detection.""" + detections: List[BBox] = Field(default_factory=list) + +class OCRRequest(BaseModel): + """Request body for OCR.""" + image: str + languages: Optional[List[str]] = None + +class OCRTextResult(BaseModel): + """A single OCR text extraction result.""" + text: str + confidence: float + bbox: List[int] = Field(default_factory=list) + +class OCRResponse(BaseModel): + """Response from OCR.""" + results: List[OCRTextResult] = Field(default_factory=list) + +class 
PreprocessRequest(BaseModel): + """Request body for image preprocessing.""" + image: str + binarize: bool = False + deskew: bool = False + contrast: bool = True + +class PreprocessResponse(BaseModel): + """Response from preprocessing.""" + image: str + +class VLMRequest(BaseModel): + """Request body for visual language model query.""" + image: str + prompt: str + model: Optional[str] = None + +class VLMResponse(BaseModel): + """Response from VLM.""" + brand: str + confidence: float + reasoning: str + +class AnalyzeRegionsRequest(BaseModel): + """Request body for CV region analysis.""" + image: str + edge_canny_low: int = 50 + edge_canny_high: int = 150 + edge_hough_threshold: int = 80 + edge_hough_min_length: int = 100 + edge_hough_max_gap: int = 10 + edge_pair_max_distance: int = 200 + edge_pair_min_distance: int = 15 + +class RegionBox(BaseModel): + """A candidate region from CV analysis.""" + x: int + y: int + w: int + h: int + confidence: float + label: str + +class AnalyzeRegionsResponse(BaseModel): + """Response from CV region analysis.""" + regions: List[RegionBox] = Field(default_factory=list) + +class AnalyzeRegionsDebugResponse(BaseModel): + """Response from CV region analysis with debug overlays.""" + regions: List[RegionBox] = Field(default_factory=list) + edge_overlay_b64: str = "" + lines_overlay_b64: str = "" + horizontal_count: int = 0 + pair_count: int = 0 + +class ConfigUpdate(BaseModel): + """Request body for updating server configuration.""" + device: Optional[str] = None + yolo_model: Optional[str] = None + yolo_confidence: Optional[float] = None + vram_budget_mb: Optional[int] = None + strategy: Optional[str] = None + ocr_languages: Optional[List[str]] = None + ocr_min_confidence: Optional[float] = None diff --git a/gpu/server.py b/gpu/server.py index 5c22138..76ce8a6 100644 --- a/gpu/server.py +++ b/gpu/server.py @@ -52,74 +52,25 @@ def _gpu_log(job_id: str, log_level: str, stage: str, level: str, msg: str): emit_log(job_id, stage, level, 
msg, log_level=log_level) -# --- Request/Response models --- +# --- Request/Response models (generated from core/schema/models/inference.py) --- -class DetectRequest(BaseModel): - image: str - model: str | None = None - confidence: float | None = None - target_classes: list[str] | None = None - - -class BBox(BaseModel): - x: int - y: int - w: int - h: int - confidence: float - label: str - - -class DetectResponse(BaseModel): - detections: list[BBox] - - -class OCRRequest(BaseModel): - image: str - languages: list[str] | None = None - - -class OCRTextResult(BaseModel): - text: str - confidence: float - bbox: list[int] - - -class OCRResponse(BaseModel): - results: list[OCRTextResult] - - -class PreprocessRequest(BaseModel): - image: str - binarize: bool = False - deskew: bool = False - contrast: bool = True - - -class PreprocessResponse(BaseModel): - image: str # base64 JPEG of processed image - - -class VLMRequest(BaseModel): - image: str - prompt: str - model: str | None = None - - -class VLMResponse(BaseModel): - brand: str - confidence: float - reasoning: str - - -class ConfigUpdate(BaseModel): - device: str | None = None - yolo_model: str | None = None - yolo_confidence: float | None = None - vram_budget_mb: int | None = None - strategy: str | None = None - ocr_languages: list[str] | None = None - ocr_min_confidence: float | None = None +from models.inference_contract import ( + AnalyzeRegionsDebugResponse, + AnalyzeRegionsRequest, + AnalyzeRegionsResponse, + BBox, + ConfigUpdate, + DetectRequest, + DetectResponse, + OCRRequest, + OCRResponse, + OCRTextResult, + PreprocessRequest, + PreprocessResponse, + RegionBox, + VLMRequest, + VLMResponse, +) # --- App --- @@ -281,6 +232,84 @@ def vlm(req: VLMRequest, request: Request): return VLMResponse(**result) +@app.post("/detect_edges", response_model=AnalyzeRegionsResponse) +def detect_edges_endpoint(req: AnalyzeRegionsRequest, request: Request): + job_id, log_level = _job_ctx(request) + + try: + image = 
_decode_image(req.image) + h, w = image.shape[:2] + except Exception as e: + raise HTTPException(status_code=400, detail=f"Bad image: {e}") + + try: + t0 = time.monotonic() + from models.cv.edges import detect_edges + + edge_regions = detect_edges( + image, + canny_low=req.edge_canny_low, + canny_high=req.edge_canny_high, + hough_threshold=req.edge_hough_threshold, + hough_min_length=req.edge_hough_min_length, + hough_max_gap=req.edge_hough_max_gap, + pair_max_distance=req.edge_pair_max_distance, + pair_min_distance=req.edge_pair_min_distance, + ) + infer_ms = (time.monotonic() - t0) * 1000 + + _gpu_log(job_id, log_level, "GPU:CV", "DEBUG", + f"Edge analysis {w}x{h}: {infer_ms:.0f}ms → {len(edge_regions)} regions") + except Exception as e: + raise HTTPException(status_code=500, detail=f"Region analysis failed: {e}") + + boxes = [RegionBox(**r) for r in edge_regions] + return AnalyzeRegionsResponse(regions=boxes) + + +@app.post("/detect_edges/debug", response_model=AnalyzeRegionsDebugResponse) +def detect_edges_debug_endpoint(req: AnalyzeRegionsRequest, request: Request): + job_id, log_level = _job_ctx(request) + + try: + image = _decode_image(req.image) + h, w = image.shape[:2] + except Exception as e: + raise HTTPException(status_code=400, detail=f"Bad image: {e}") + + try: + t0 = time.monotonic() + from models.cv.edges import detect_edges_debug + + result = detect_edges_debug( + image, + canny_low=req.edge_canny_low, + canny_high=req.edge_canny_high, + hough_threshold=req.edge_hough_threshold, + hough_min_length=req.edge_hough_min_length, + hough_max_gap=req.edge_hough_max_gap, + pair_max_distance=req.edge_pair_max_distance, + pair_min_distance=req.edge_pair_min_distance, + ) + infer_ms = (time.monotonic() - t0) * 1000 + + _gpu_log(job_id, log_level, "GPU:CV", "DEBUG", + f"Edge debug {w}x{h}: {infer_ms:.0f}ms → {len(result['regions'])} regions, " + f"{result['horizontal_count']} horizontals, {result['pair_count']} pairs") + except Exception as e: + raise 
HTTPException(status_code=500, detail=f"Region debug analysis failed: {e}") + + boxes = [RegionBox(**r) for r in result["regions"]] + response = AnalyzeRegionsDebugResponse( + regions=boxes, + edge_overlay_b64=result["edge_overlay_b64"], + lines_overlay_b64=result["lines_overlay_b64"], + horizontal_count=result["horizontal_count"], + pair_count=result["pair_count"], + ) + return response + + if __name__ == "__main__": import uvicorn diff --git a/tests/detect/manual/list_scenarios.py b/tests/detect/manual/list_scenarios.py new file mode 100644 index 0000000..cd126c3 --- /dev/null +++ b/tests/detect/manual/list_scenarios.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +""" +List available scenarios and open one in the browser. + +Usage: + python tests/detect/manual/list_scenarios.py # list all + python tests/detect/manual/list_scenarios.py --open 1 # open scenario #1 + python tests/detect/manual/list_scenarios.py --open chelsea_edges_default # by label + +Prerequisites: + kubectl port-forward svc/postgres 5432:5432 & +""" + +from __future__ import annotations + +import argparse +import logging +import os +import sys +import webbrowser + +parser = argparse.ArgumentParser(description="List and open scenarios") +parser.add_argument("--open", type=str, default=None, + help="Open scenario by number (1-based) or label") +parser.add_argument("--db-url", + default=os.environ.get("DATABASE_URL", "postgresql://mpr:mpr@localhost:5432/mpr")) +parser.add_argument("--base-url", default="http://mpr.local.ar/detection/") +args = parser.parse_args() + +os.environ["DATABASE_URL"] = args.db_url +sys.path.insert(0, ".") + +logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s") +logger = logging.getLogger(__name__) + + +def main(): + from core.db.detect import list_scenarios + + scenarios = list_scenarios() + + if not scenarios: + logger.info("No scenarios found. 
Create one with:") + logger.info(" python tests/detect/manual/seed_scenario.py") + return + + logger.info("") + logger.info("%3s %-35s %-12s %-18s %6s %s", "#", "Label", "Job ID", "Stage", "Frames", "Created") + logger.info("─" * 100) + + for i, s in enumerate(scenarios, 1): + manifest = s.frames_manifest or {} + created = str(s.created_at)[:19] if s.created_at else "—" + job_short = str(s.job_id)[:8] + logger.info("%3d %-35s %-12s %-18s %6d %s", + i, s.scenario_label, job_short, s.stage, len(manifest), created) + + logger.info("") + + if args.open: + target = None + try: + idx = int(args.open) - 1 + if 0 <= idx < len(scenarios): + target = scenarios[idx] + except ValueError: + for s in scenarios: + if s.scenario_label == args.open: + target = s + break + + if not target: + logger.error("Scenario not found: %s", args.open) + return + + url = f"{args.base_url}?job={target.job_id}#/editor/detect_edges" + logger.info("Opening: %s", url) + webbrowser.open(url) + else: + logger.info("To open a scenario:") + logger.info(" python tests/detect/manual/list_scenarios.py --open 1") + logger.info(" python tests/detect/manual/list_scenarios.py --open chelsea_edges_default") + + +if __name__ == "__main__": + main() diff --git a/tests/detect/manual/run_region_analysis.py b/tests/detect/manual/run_region_analysis.py new file mode 100644 index 0000000..f0fe663 --- /dev/null +++ b/tests/detect/manual/run_region_analysis.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +""" +Run edge detection on test video frames — visual verification. + +Uses a minimal 3-stage pipeline: extract_frames → filter_scenes → detect_edges. +No YOLO, OCR, or downstream stages. 
+ +Usage: + python tests/detect/manual/run_region_analysis.py [--job JOB_ID] [--port PORT] [--local] + +Opens: http://mpr.local.ar/detection/?job=JOB_ID + +What to look for in the frame viewer: + - "Edges" toggle appears (cyan) + - Cyan boxes around horizontal line pairs (hoarding edges) + - No boxes on players, ball, or sky + - Boxes concentrated in the lower third of the frame +""" + +import argparse +import logging +import os +import sys +import time as _time + +parser = argparse.ArgumentParser() +parser.add_argument("--job", default=f"cv-{int(_time.time()) % 100000}") +parser.add_argument("--port", type=int, default=6379) +parser.add_argument("--local", action="store_true", help="Run CV locally (no inference server)") +args = parser.parse_args() + +os.environ["REDIS_URL"] = f"redis://localhost:{args.port}/0" +if args.local: + os.environ.pop("INFERENCE_URL", None) + +logging.basicConfig(level=logging.DEBUG, format="%(levelname)-7s %(name)s — %(message)s") + +sys.path.insert(0, ".") + +from langgraph.graph import END, StateGraph + +from detect import emit +from detect.models import PipelineStats +from detect.profiles.soccer import SoccerBroadcastProfile +from detect.stages.frame_extractor import extract_frames +from detect.stages.scene_filter import scene_filter +from detect.stages.edge_detector import detect_edge_regions +from detect.state import DetectState + +logger = logging.getLogger(__name__) + +VIDEO = "media/mpr/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4" +INFERENCE_URL = os.environ.get("INFERENCE_URL") + + +# --- 3-stage pipeline --- + +NODES = ["extract_frames", "filter_scenes", "detect_edges"] + + +def _emit_transition(job_id: str, node: str, status: str, node_states: dict): + node_states[node] = status + nodes = [{"id": n, "status": node_states.get(n, "pending")} for n in NODES] + emit.graph_update(job_id, nodes) + + +def node_extract(state: DetectState) -> dict: + job_id = state.get("job_id", "") + ns = state.get("_node_states", {n: 
"pending" for n in NODES}) + _emit_transition(job_id, "extract_frames", "running", ns) + + profile = SoccerBroadcastProfile() + config = profile.frame_extraction_config() + frames = extract_frames(state["video_path"], config, job_id=job_id) + + _emit_transition(job_id, "extract_frames", "done", ns) + return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames)), "_node_states": ns} + + +def node_filter(state: DetectState) -> dict: + job_id = state.get("job_id", "") + ns = state.get("_node_states", {}) + _emit_transition(job_id, "filter_scenes", "running", ns) + + profile = SoccerBroadcastProfile() + config = profile.scene_filter_config() + kept = scene_filter(state.get("frames", []), config, job_id=job_id) + + stats = state.get("stats", PipelineStats()) + stats.frames_after_scene_filter = len(kept) + + _emit_transition(job_id, "filter_scenes", "done", ns) + return {"filtered_frames": kept, "stats": stats, "_node_states": ns} + + +def node_edges(state: DetectState) -> dict: + job_id = state.get("job_id", "") + ns = state.get("_node_states", {}) + _emit_transition(job_id, "detect_edges", "running", ns) + + profile = SoccerBroadcastProfile() + config = profile.region_analysis_config() + regions = detect_edge_regions( + state.get("filtered_frames", []), config, + inference_url=INFERENCE_URL, job_id=job_id, + ) + total = sum(len(r) for r in regions.values()) + + stats = state.get("stats", PipelineStats()) + stats.cv_regions_detected = total + + _emit_transition(job_id, "detect_edges", "done", ns) + return {"edge_regions_by_frame": regions, "stats": stats, "_node_states": ns} + + +def build_3stage_graph() -> StateGraph: + graph = StateGraph(DetectState) + graph.add_node("extract_frames", node_extract) + graph.add_node("filter_scenes", node_filter) + graph.add_node("detect_edges", node_edges) + graph.set_entry_point("extract_frames") + graph.add_edge("extract_frames", "filter_scenes") + graph.add_edge("filter_scenes", "detect_edges") + 
graph.add_edge("detect_edges", END) + return graph + + +def main(): + logger.info("Job: %s", args.job) + logger.info("Mode: %s", "remote" if INFERENCE_URL else "local") + logger.info("Pipeline: extract_frames → filter_scenes → detect_edges") + logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job) + input("\nPress Enter to start...") + + emit.set_run_context(run_id=args.job, parent_job_id=args.job, run_type="initial", log_level="DEBUG") + + graph = build_3stage_graph() + pipeline = graph.compile() + + initial_state = { + "video_path": VIDEO, + "job_id": args.job, + "profile_name": "soccer_broadcast", + } + + result = pipeline.invoke(initial_state) + + # Print results + regions = result.get("edge_regions_by_frame", {}) + total = sum(len(boxes) for boxes in regions.values()) + frames_with_regions = sum(1 for boxes in regions.values() if boxes) + + logger.info("Results:") + logger.info(" Total edge regions: %d", total) + logger.info(" Frames with regions: %d / %d", + frames_with_regions, len(result.get("filtered_frames", []))) + + for seq, boxes in sorted(regions.items()): + if boxes: + labels = [f"{b.label}({b.confidence:.2f})" for b in boxes] + logger.info(" Frame %d: %s", seq, ", ".join(labels)) + + logger.info("Done. 
Check the frame viewer for cyan boxes.") + logger.info("") + + # --- Parameter sensitivity --- + logger.info("=== Parameter sensitivity (local debug) ===") + + from detect.stages.edge_detector import _load_cv_edges + edges_mod = _load_cv_edges() + + filtered = result.get("filtered_frames", []) + if filtered: + sample = filtered[0] + for canny_low in [20, 50, 80, 120]: + dbg = edges_mod.detect_edges_debug(sample.image, canny_low=canny_low) + logger.info( + " canny_low=%d → %d horizontals, %d pairs, %d regions", + canny_low, dbg["horizontal_count"], dbg["pair_count"], len(dbg["regions"]), + ) + + logger.info("") + logger.info("=== Editor test ===") + logger.info(" Dashboard: http://mpr.local.ar/detection/?job=%s", args.job) + logger.info(" Editor: http://mpr.local.ar/detection/?job=%s#/editor/detect_edges", args.job) + + +if __name__ == "__main__": + main() diff --git a/tests/detect/manual/seed_scenario.py b/tests/detect/manual/seed_scenario.py new file mode 100644 index 0000000..104bd8e --- /dev/null +++ b/tests/detect/manual/seed_scenario.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +""" +Seed a scenario checkpoint from a video chunk. + +Extracts frames via ffmpeg, uploads to MinIO, creates a StageCheckpoint +in Postgres marked as a scenario. No pipeline, no Redis, no SSE. 
+ +Prerequisites: + - Postgres reachable (port-forward or local) + - MinIO reachable (port-forward or local) + +Usage: + # With K8s port-forwards: + kubectl port-forward svc/postgres 5432:5432 & + kubectl port-forward svc/minio 9000:9000 & + + python tests/detect/manual/seed_scenario.py + + # Custom video: + python tests/detect/manual/seed_scenario.py --video media/mpr/out/chunks/.../chunk_0001.mp4 + +Then open: + http://mpr.local.ar/detection/?job=JOB_ID#/editor/detect_edges +""" + +from __future__ import annotations + +import argparse +import logging +import os +import sys +import uuid + +parser = argparse.ArgumentParser(description="Seed a scenario checkpoint") +parser.add_argument("--video", + default="media/mpr/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4") +parser.add_argument("--label", default="chelsea_edges_default", + help="Scenario label for bookmarking") +parser.add_argument("--fps", type=float, default=2.0, help="Frames per second to extract") +parser.add_argument("--max-frames", type=int, default=20, help="Max frames to extract") +parser.add_argument("--db-url", + default=os.environ.get("DATABASE_URL", "postgresql://mpr:mpr@localhost:5432/mpr")) +parser.add_argument("--s3-url", + default=os.environ.get("S3_ENDPOINT_URL", "http://localhost:9000")) +args = parser.parse_args() + +# Set env before imports +os.environ["DATABASE_URL"] = args.db_url +os.environ["S3_ENDPOINT_URL"] = args.s3_url +os.environ.setdefault("AWS_ACCESS_KEY_ID", "minioadmin") +os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "minioadmin") + +sys.path.insert(0, ".") + +logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s") +logger = logging.getLogger(__name__) + + +def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int): + """Extract frames using ffmpeg subprocess — no pipeline dependencies.""" + import subprocess + import tempfile + from pathlib import Path + + import numpy as np + from PIL import Image + + from 
detect.models import Frame + + tmpdir = tempfile.mkdtemp(prefix="scenario_") + pattern = os.path.join(tmpdir, "frame_%04d.jpg") + + cmd = [ + "ffmpeg", "-i", video_path, + "-vf", f"fps={fps}", + "-frames:v", str(max_frames), + "-q:v", "2", + pattern, + "-y", "-loglevel", "error", + ] + subprocess.run(cmd, check=True) + + frames = [] + for jpg in sorted(Path(tmpdir).glob("frame_*.jpg")): + seq = int(jpg.stem.split("_")[1]) - 1 # 0-indexed + img = Image.open(jpg).convert("RGB") + image_array = np.array(img) + frame = Frame( + sequence=seq, + chunk_id=0, + timestamp=seq / fps, + image=image_array, + ) + frames.append(frame) + jpg.unlink() + + Path(tmpdir).rmdir() + return frames + + +def main(): + job_id = str(uuid.uuid4()) + video_path = args.video + + if not os.path.exists(video_path): + logger.error("Video not found: %s", video_path) + sys.exit(1) + + logger.info("Video: %s", video_path) + logger.info("Job ID: %s", job_id) + logger.info("Label: %s", args.label) + + # Ensure DB tables exist + from core.db.connection import create_tables + create_tables() + + # Extract frames + logger.info("Extracting frames (fps=%.1f, max=%d)...", args.fps, args.max_frames) + frames = extract_frames_ffmpeg(video_path, args.fps, args.max_frames) + logger.info("Extracted %d frames", len(frames)) + + # Upload frames to MinIO + from detect.checkpoint.frames import save_frames + logger.info("Uploading frames to MinIO...") + manifest = save_frames(job_id, frames) + logger.info("Uploaded %d frames", len(manifest)) + + # Build frame metadata + frames_meta = [ + { + "sequence": f.sequence, + "chunk_id": f.chunk_id, + "timestamp": f.timestamp, + "perceptual_hash": "", + } + for f in frames + ] + + # All frames are "filtered" (no scene filter ran) + filtered_sequences = [f.sequence for f in frames] + + # Save checkpoint as scenario + from core.db.detect import save_stage_checkpoint + from detect.checkpoint.frames import CHECKPOINT_PREFIX + + checkpoint = save_stage_checkpoint( + job_id=job_id, 
+ stage="filter_scenes", + stage_index=1, + frames_prefix=f"{CHECKPOINT_PREFIX}/{job_id}/frames/", + frames_manifest={str(k): v for k, v in manifest.items()}, + frames_meta=frames_meta, + filtered_frame_sequences=filtered_sequences, + stage_output_key="", + stats={"frames_extracted": len(frames), "frames_after_scene_filter": len(frames)}, + config_snapshot={}, + config_overrides={}, + video_path=video_path, + profile_name="soccer_broadcast", + is_scenario=True, + scenario_label=args.label, + ) + + logger.info("") + logger.info("Scenario created:") + logger.info(" ID: %s", checkpoint.id) + logger.info(" Job: %s", job_id) + logger.info(" Label: %s", args.label) + logger.info(" Frames: %d", len(frames)) + logger.info("") + logger.info("Open in editor:") + logger.info(" http://mpr.local.ar/detection/?job=%s#/editor/detect_edges", job_id) + + +if __name__ == "__main__": + main() diff --git a/tests/detect/test_checkpoint.py b/tests/detect/test_checkpoint.py index c55b96f..620f10d 100644 --- a/tests/detect/test_checkpoint.py +++ b/tests/detect/test_checkpoint.py @@ -161,3 +161,36 @@ def test_all_serialized_is_json_compatible(): roundtrip = json.loads(json_str) assert roundtrip["frame_meta"]["sequence"] == frame.sequence + + +# --- OverrideProfile --- + +def test_override_profile_region_analysis(): + """OverrideProfile must patch region_analysis_config with overrides.""" + from detect.checkpoint.replay import OverrideProfile + from detect.profiles.soccer import SoccerBroadcastProfile + from detect.profiles.base import RegionAnalysisConfig + + base = SoccerBroadcastProfile() + original = base.region_analysis_config() + + overrides = {"region_analysis": {"edge_canny_low": 25, "edge_canny_high": 200}} + wrapped = OverrideProfile(base, overrides) + patched = wrapped.region_analysis_config() + + assert isinstance(patched, RegionAnalysisConfig) + assert patched.edge_canny_low == 25 + assert patched.edge_canny_high == 200 + # Unmodified fields keep their defaults + assert 
patched.edge_hough_threshold == original.edge_hough_threshold + + +def test_override_profile_passthrough(): + """OverrideProfile without region_analysis key passes through unchanged.""" + from detect.checkpoint.replay import OverrideProfile + from detect.profiles.soccer import SoccerBroadcastProfile + + base = SoccerBroadcastProfile() + wrapped = OverrideProfile(base, {"ocr": {"min_confidence": 0.1}}) + config = wrapped.region_analysis_config() + assert config.edge_canny_low == base.region_analysis_config().edge_canny_low diff --git a/tests/detect/test_edge_sensitivity.py b/tests/detect/test_edge_sensitivity.py new file mode 100644 index 0000000..dac30d0 --- /dev/null +++ b/tests/detect/test_edge_sensitivity.py @@ -0,0 +1,87 @@ +"""Parameter sensitivity tests for edge detection. + +Verifies that adjusting parameters in expected directions produces +expected changes in detection counts. Uses synthetic frames with +known geometry. +""" + +import importlib.util +from pathlib import Path + +import cv2 +import numpy as np +import pytest + + +# Load edges module directly +_spec = importlib.util.spec_from_file_location( + "cv_edges", Path("gpu/models/cv/edges.py"), +) +_edges_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_edges_mod) +detect_edges = _edges_mod.detect_edges +detect_edges_debug = _edges_mod.detect_edges_debug + + +def _frame_with_lines(n_pairs: int = 3, line_thickness: int = 2) -> np.ndarray: + """Create a frame with multiple horizontal line pairs.""" + image = np.zeros((1080, 1920, 3), dtype=np.uint8) + y_start = 300 + for i in range(n_pairs): + y_top = y_start + i * 120 + y_bot = y_top + 40 + i * 10 # varying gap + cv2.line(image, (100, y_top), (1800, y_top), (255, 255, 255), line_thickness) + cv2.line(image, (100, y_bot), (1800, y_bot), (255, 255, 255), line_thickness) + return image + + +def test_canny_low_sensitivity(): + """Lowering canny_low should find same or more horizontal lines.""" + image = _frame_with_lines() + + 
high_threshold = detect_edges_debug(image, canny_low=100, canny_high=200) + low_threshold = detect_edges_debug(image, canny_low=30, canny_high=200) + + assert low_threshold["horizontal_count"] >= high_threshold["horizontal_count"] + + +def test_hough_threshold_sensitivity(): + """Lowering hough_threshold should find same or more lines.""" + image = _frame_with_lines() + + strict = detect_edges_debug(image, hough_threshold=150) + lenient = detect_edges_debug(image, hough_threshold=40) + + assert lenient["horizontal_count"] >= strict["horizontal_count"] + + +def test_pair_distance_range(): + """Widening pair distance range should find same or more pairs.""" + image = _frame_with_lines() + + narrow = detect_edges_debug(image, pair_min_distance=30, pair_max_distance=50) + wide = detect_edges_debug(image, pair_min_distance=10, pair_max_distance=200) + + assert wide["pair_count"] >= narrow["pair_count"] + + +def test_hough_min_length_sensitivity(): + """Shorter min_length should find same or more lines.""" + image = _frame_with_lines() + + long_min = detect_edges_debug(image, hough_min_length=500) + short_min = detect_edges_debug(image, hough_min_length=50) + + assert short_min["horizontal_count"] >= long_min["horizontal_count"] + + +def test_blank_frame_no_false_positives(): + """Even very lenient parameters on a blank frame should produce zero regions.""" + image = np.zeros((720, 1280, 3), dtype=np.uint8) + + # Very lenient parameters + results = detect_edges( + image, canny_low=10, canny_high=50, hough_threshold=10, + hough_min_length=20, pair_min_distance=5, pair_max_distance=500, + ) + assert results == [] diff --git a/tests/detect/test_region_analyzer.py b/tests/detect/test_region_analyzer.py new file mode 100644 index 0000000..8260bf4 --- /dev/null +++ b/tests/detect/test_region_analyzer.py @@ -0,0 +1,195 @@ +"""Tests for CV region analysis stage — regression checks only.""" + +import importlib.util +from pathlib import Path + +import numpy as np +import pytest + 
+from detect.models import BoundingBox, Frame +from detect.profiles.base import RegionAnalysisConfig +from detect.profiles.soccer import SoccerBroadcastProfile + + +# Load edges module directly — gpu/models/__init__.py has GPU-only imports +_spec = importlib.util.spec_from_file_location( + "cv_edges", Path("gpu/models/cv/edges.py"), +) +_edges_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_edges_mod) +detect_edges = _edges_mod.detect_edges +detect_edges_debug = _edges_mod.detect_edges_debug + + +def _make_frame(seq: int = 0, h: int = 1080, w: int = 1920) -> Frame: + """Create a blank frame for testing.""" + image = np.zeros((h, w, 3), dtype=np.uint8) + return Frame(sequence=seq, chunk_id=0, timestamp=seq * 0.5, image=image) + + +def _make_frame_with_lines(seq: int = 0) -> Frame: + """Create a frame with two strong horizontal lines (simulates hoarding edges).""" + import cv2 + + image = np.zeros((1080, 1920, 3), dtype=np.uint8) + cv2.line(image, (100, 800), (1800, 800), (255, 255, 255), 3) + cv2.line(image, (100, 860), (1800, 860), (255, 255, 255), 3) + return Frame(sequence=seq, chunk_id=0, timestamp=seq * 0.5, image=image) + + +# --- Config --- + +def test_soccer_profile_has_region_analysis_config(): + profile = SoccerBroadcastProfile() + config = profile.region_analysis_config() + assert isinstance(config, RegionAnalysisConfig) + assert config.enabled is True + + +def test_region_analysis_config_defaults(): + config = RegionAnalysisConfig() + assert config.edge_canny_low == 50 + assert config.edge_canny_high == 150 + assert config.edge_hough_threshold == 80 + + +# --- Edge detection (GPU side, loaded standalone) --- + +def test_detect_edges_blank_frame(): + """Blank frame should produce no regions.""" + image = np.zeros((1080, 1920, 3), dtype=np.uint8) + results = detect_edges(image) + assert results == [] + + +def test_detect_edges_with_lines(): + """Frame with parallel horizontal lines should produce at least one region.""" + import cv2 
+ + image = np.zeros((1080, 1920, 3), dtype=np.uint8) + cv2.line(image, (100, 800), (1800, 800), (255, 255, 255), 3) + cv2.line(image, (100, 860), (1800, 860), (255, 255, 255), 3) + + results = detect_edges(image) + assert len(results) >= 1 + + for r in results: + assert "x" in r and "y" in r and "w" in r and "h" in r + assert r["label"] == "edge_region" + assert 0 <= r["confidence"] <= 1 + + +def test_detect_edges_returns_dict_format(): + """Each result must have the expected keys.""" + import cv2 + + image = np.zeros((720, 1280, 3), dtype=np.uint8) + cv2.line(image, (50, 400), (1200, 400), (255, 255, 255), 2) + cv2.line(image, (50, 450), (1200, 450), (255, 255, 255), 2) + + results = detect_edges(image) + if results: + r = results[0] + expected_keys = {"x", "y", "w", "h", "confidence", "label"} + assert set(r.keys()) == expected_keys + + +# --- Debug function --- + +def test_detect_edges_debug_returns_overlays(): + """Debug function must return overlay images and counts.""" + import cv2 + + image = np.zeros((1080, 1920, 3), dtype=np.uint8) + cv2.line(image, (100, 800), (1800, 800), (255, 255, 255), 3) + cv2.line(image, (100, 860), (1800, 860), (255, 255, 255), 3) + + result = detect_edges_debug(image) + assert "regions" in result + assert "edge_overlay_b64" in result + assert "lines_overlay_b64" in result + assert "horizontal_count" in result + assert "pair_count" in result + assert isinstance(result["edge_overlay_b64"], str) + assert len(result["edge_overlay_b64"]) > 0 # non-empty base64 + assert isinstance(result["lines_overlay_b64"], str) + assert len(result["lines_overlay_b64"]) > 0 + assert result["horizontal_count"] >= 2 + assert result["pair_count"] >= 1 + assert len(result["regions"]) >= 1 + + +def test_detect_edges_debug_blank_frame(): + """Debug on blank frame should still return structure.""" + image = np.zeros((720, 1280, 3), dtype=np.uint8) + result = detect_edges_debug(image) + assert result["regions"] == [] + assert result["horizontal_count"] == 0 
+ assert result["pair_count"] == 0 + assert isinstance(result["edge_overlay_b64"], str) + + +# --- Stage function --- + +def test_stage_disabled(monkeypatch): + """When disabled, returns empty dict.""" + monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None) + + from detect.stages.edge_detector import detect_edge_regions + + config = RegionAnalysisConfig(enabled=False) + result = detect_edge_regions([_make_frame()], config, job_id="test") + assert result == {} + + +def test_stage_local_blank(monkeypatch): + """Local mode on blank frames returns empty boxes.""" + monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None) + + from detect.stages.edge_detector import detect_edge_regions + + config = RegionAnalysisConfig() + result = detect_edge_regions([_make_frame()], config, job_id="test") + assert isinstance(result, dict) + assert all(isinstance(v, list) for v in result.values()) + + +def test_stage_local_with_lines(monkeypatch): + """Local mode on frame with lines should find regions.""" + monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None) + + from detect.stages.edge_detector import detect_edge_regions + + config = RegionAnalysisConfig() + frame = _make_frame_with_lines() + result = detect_edge_regions([frame], config, job_id="test") + + boxes = result.get(frame.sequence, []) + assert len(boxes) >= 1 + assert all(isinstance(b, BoundingBox) for b in boxes) + assert all(b.label == "edge_region" for b in boxes) + + +# --- Graph wiring --- + +def test_detect_edges_in_nodes(): + """detect_edges must be in the pipeline node list.""" + from detect.graph import NODES, NODE_FUNCTIONS + + assert "detect_edges" in NODES + node_names = [name for name, _ in NODE_FUNCTIONS] + assert "detect_edges" in node_names + + # Must be after filter_scenes, before detect_objects + idx = NODES.index("detect_edges") + assert NODES[idx - 1] == "filter_scenes" + assert NODES[idx + 1] == "detect_objects" + + +# --- State --- + +def 
test_state_has_edge_regions_field(): + from detect.state import DetectState + + hints = DetectState.__annotations__ + assert "edge_regions_by_frame" in hints diff --git a/tests/detect/test_replay.py b/tests/detect/test_replay.py index dc3cc64..4b3cabb 100644 --- a/tests/detect/test_replay.py +++ b/tests/detect/test_replay.py @@ -3,7 +3,8 @@ import pytest from detect.profiles.soccer import SoccerBroadcastProfile -from detect.checkpoint.replay import OverrideProfile +from detect.profiles.base import RegionAnalysisConfig +from detect.checkpoint.replay import OverrideProfile, replay_single_stage def test_override_profile_patches_ocr(): @@ -65,3 +66,31 @@ def test_override_profile_ignores_unknown_fields(): assert not hasattr(config, "nonexistent_field") assert config.min_confidence == base.ocr_config().min_confidence + + +# --- OverrideProfile for region_analysis --- + +def test_override_profile_patches_region_analysis(): + base = SoccerBroadcastProfile() + overrides = {"region_analysis": {"edge_canny_low": 25, "edge_canny_high": 200}} + profile = OverrideProfile(base, overrides) + + config = profile.region_analysis_config() + + assert isinstance(config, RegionAnalysisConfig) + assert config.edge_canny_low == 25 + assert config.edge_canny_high == 200 + # Unchanged fields keep defaults + assert config.edge_hough_threshold == base.region_analysis_config().edge_hough_threshold + + +# --- replay_single_stage --- + +def test_replay_single_stage_unknown_stage(): + with pytest.raises(ValueError, match="Unknown stage"): + replay_single_stage("fake-job", "nonexistent_stage") + + +def test_replay_single_stage_first_stage(): + with pytest.raises(ValueError, match="Cannot replay the first stage"): + replay_single_stage("fake-job", "extract_frames") diff --git a/tests/detect/test_stage_registry.py b/tests/detect/test_stage_registry.py index 52e0db0..6419989 100644 --- a/tests/detect/test_stage_registry.py +++ b/tests/detect/test_stage_registry.py @@ -4,7 +4,7 @@ from detect.stages 
import list_stages, get_stage, get_palette EXPECTED_STAGES = [ - "extract_frames", "filter_scenes", "detect_objects", "preprocess", + "extract_frames", "filter_scenes", "detect_edges", "detect_objects", "preprocess", "run_ocr", "match_brands", "escalate_vlm", "escalate_cloud", "compile_report", ] diff --git a/ui/detection-app/src/App.vue b/ui/detection-app/src/App.vue index 793a144..ecc396c 100644 --- a/ui/detection-app/src/App.vue +++ b/ui/detection-app/src/App.vue @@ -10,7 +10,9 @@ import BrandTablePanel from './panels/BrandTablePanel.vue' import TimelinePanel from './panels/TimelinePanel.vue' import CostStatsPanel from './panels/CostStatsPanel.vue' import SourceSelector from './panels/SourceSelector.vue' +import StageConfigSliders from './components/StageConfigSliders.vue' import type { StatsUpdate, RunContext } from './types/sse-contract' +import type { FrameOverlay } from 'mpr-ui-framework/src/renderers/FrameRenderer.vue' import { usePipelineStore } from './stores/pipeline' const pipeline = usePipelineStore() @@ -21,9 +23,10 @@ const stats = ref(null) const runContext = ref(null) const status = ref<'idle' | 'live' | 'processing' | 'error'>('idle') const logPanel = ref<{ clear: () => void } | null>(null) +const sseConnected = ref(false) -// No job selected → open source selector -if (!jobParam) { +// No job selected and no hash route → open source selector +if (!jobParam && !window.location.hash.replace(/^#\/?/, '')) { pipeline.openSourceSelector() } @@ -35,7 +38,6 @@ const source = new SSEDataSource({ source.on('stats_update', (e) => { stats.value = e - // Capture run context from first event that carries it if (!runContext.value && e.run_id) { runContext.value = { run_id: (e as any).run_id, @@ -53,7 +55,7 @@ source.on<{ report?: { status?: string, error?: string } }>('job_complete', (e) // Resizable splits const pipelineWidth = ref(320) -const detectionsFlex = ref(3) // ratio for detections vs stats +const detectionsFlex = ref(3) const viewerHeight = ref(240) 
const timelineFlex = ref(1) const tableFlex = ref(1) @@ -82,11 +84,19 @@ const statusMap: Record = { live: 'live', error: 'error', } -const checkStatus = () => { status.value = statusMap[source.status.value] ?? 'idle' } +const checkStatus = () => { + if (sseConnected.value) { + status.value = statusMap[source.status.value] ?? 'idle' + } +} setInterval(checkStatus, 500) -if (jobId.value) { +// Only connect SSE for live pipeline runs (no hash route = dashboard mode) +// Scenario URLs use hash routing and load from checkpoint instead +const isScenarioMode = pipeline.isEditing || pipeline.layoutMode !== 'normal' +if (jobId.value && !isScenarioMode) { source.connect() + sseConnected.value = true } async function stopPipeline() { @@ -96,6 +106,37 @@ async function stopPipeline() { } catch { /* ignore — UI will see the cancel event via SSE */ } } +// Current frame image (base64) — tracked for the editor's direct GPU calls +const currentFrameImage = ref(null) +const currentFrameRef = ref(null) + +source.on<{ frame_ref: number; jpeg_b64: string }>('frame_update', (e) => { + currentFrameImage.value = e.jpeg_b64 + currentFrameRef.value = e.frame_ref +}) + +// Debug overlays from replay-stage results +const editorOverlays = ref([]) + +function onReplayResult(result: { + debug?: Record +}) { + const overlays: FrameOverlay[] = [] + if (result.debug) { + // Take first frame's debug data (editor shows one frame at a time) + const firstDebug = Object.values(result.debug)[0] + if (firstDebug) { + if (firstDebug.edge_overlay_b64) { + overlays.push({ src: firstDebug.edge_overlay_b64, label: 'Canny edges', visible: true, opacity: 0.4 }) + } + if (firstDebug.lines_overlay_b64) { + overlays.push({ src: firstDebug.lines_overlay_b64, label: 'Hough lines', visible: true, opacity: 0.5 }) + } + } + } + editorOverlays.value = overlays +} + function onJobStarted(newJobId: string) { jobId.value = newJobId // Reset UI state @@ -113,7 +154,7 @@ function onJobStarted(newJobId: string) { 
source.disconnect() source.setUrl(`/api/detect/stream/${newJobId}`) source.connect() - // Switch to normal layout (reset sets it to normal already) + sseConnected.value = true } @@ -131,7 +172,7 @@ function onJobStarted(newJobId: string) { @@ -233,10 +280,24 @@ function onJobStarted(newJobId: string) {