diff --git a/admin/mpr/media_assets/models.py b/admin/mpr/media_assets/models.py index 6576093..a9a1f51 100644 --- a/admin/mpr/media_assets/models.py +++ b/admin/mpr/media_assets/models.py @@ -28,6 +28,25 @@ class ChunkJobStatus(models.TextChoices): FAILED = "failed", "Failed" CANCELLED = "cancelled", "Cancelled" +class DetectJobStatus(models.TextChoices): + PENDING = "pending", "Pending" + RUNNING = "running", "Running" + PAUSED = "paused", "Paused" + COMPLETED = "completed", "Completed" + FAILED = "failed", "Failed" + CANCELLED = "cancelled", "Cancelled" + +class RunType(models.TextChoices): + INITIAL = "initial", "Initial" + REPLAY = "replay", "Replay" + RETRY = "retry", "Retry" + +class BrandSource(models.TextChoices): + OCR = "ocr", "OCR" + VLM = "local_vlm", "Local VLM" + CLOUD = "cloud_llm", "Cloud LLM" + MANUAL = "manual", "Manual" + class MediaAsset(models.Model): """A video/audio file registered in the system.""" @@ -148,3 +167,104 @@ class ChunkJob(models.Model): def __str__(self): return str(self.id) + +class DetectJob(models.Model): + """A detection pipeline job.""" + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + source_asset_id = models.UUIDField() + video_path = models.CharField(max_length=1000) + profile_name = models.CharField(max_length=255) + parent_job_id = models.UUIDField(null=True, blank=True) + run_type = models.CharField(max_length=20, choices=RunType.choices, default=RunType.INITIAL) + replay_from_stage = models.CharField(max_length=255, null=True, blank=True) + config_overrides = models.JSONField(default=dict, blank=True) + status = models.CharField(max_length=20, choices=DetectJobStatus.choices, default=DetectJobStatus.PENDING) + current_stage = models.CharField(max_length=255, null=True, blank=True) + progress = models.FloatField(default=0.0) + error_message = models.TextField(blank=True, default='') + total_detections = models.IntegerField(default=0) + brands_found = models.IntegerField(default=0) + cloud_llm_calls = models.IntegerField(default=0) + estimated_cost_usd = models.FloatField(default=0.0) + celery_task_id = models.CharField(max_length=255, null=True, blank=True) + priority = models.IntegerField(default=0) + created_at = models.DateTimeField(auto_now_add=True) + started_at = models.DateTimeField(null=True, blank=True) + completed_at = models.DateTimeField(null=True, blank=True) + + class Meta: + ordering = ["-created_at"] + + def __str__(self): + return str(self.id) + + +class StageCheckpoint(models.Model): + """A checkpoint saved after a pipeline stage completes.""" + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + job_id = models.UUIDField() + stage = models.CharField(max_length=255) + stage_index = models.IntegerField() + frames_prefix = models.CharField(max_length=255) + frames_manifest = models.JSONField(default=dict, blank=True) + frames_meta = models.JSONField(default=list, blank=True) + filtered_frame_sequences = models.JSONField(default=list, blank=True) + boxes_by_frame = models.JSONField(default=dict, blank=True) + text_candidates = models.JSONField(default=list, blank=True) + unresolved_candidates = models.JSONField(default=list, blank=True) + detections = models.JSONField(default=list, blank=True) + stats = models.JSONField(default=dict, blank=True) + config_snapshot = models.JSONField(default=dict, blank=True) + config_overrides = models.JSONField(default=dict, blank=True) + video_path = models.CharField(max_length=1000) + profile_name = models.CharField(max_length=255) +
created_at = models.DateTimeField(auto_now_add=True) + + class Meta: + ordering = ["-created_at"] + + def __str__(self): + return str(self.id) + + +class KnownBrand(models.Model): + """A brand discovered or registered in the system.""" + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + canonical_name = models.CharField(max_length=255) + aliases = models.JSONField(default=list, blank=True) + first_source = models.CharField(max_length=20, choices=BrandSource.choices, default=BrandSource.OCR) + total_occurrences = models.IntegerField(default=0) + confirmed = models.BooleanField(default=False) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + ordering = ["-created_at"] + + def __str__(self): + return str(self.id) + + +class SourceBrandSighting(models.Model): + """A brand seen in a specific source (video/asset).""" + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + source_asset_id = models.UUIDField() + brand_id = models.UUIDField() + brand_name = models.CharField(max_length=255) + first_seen_timestamp = models.FloatField(default=0.0) + last_seen_timestamp = models.FloatField(default=0.0) + occurrences = models.IntegerField(default=0) + detection_source = models.CharField(max_length=20, choices=BrandSource.choices, default=BrandSource.OCR) + avg_confidence = models.FloatField(default=0.0) + created_at = models.DateTimeField(auto_now_add=True) + + class Meta: + ordering = ["-created_at"] + + def __str__(self): + return str(self.id) + diff --git a/core/api/detect_config.py b/core/api/detect_config.py new file mode 100644 index 0000000..75034b6 --- /dev/null +++ b/core/api/detect_config.py @@ -0,0 +1,88 @@ +""" +Runtime config endpoint for the detection pipeline. + +GET /detect/config — read current config +PUT /detect/config — update config (takes effect on next run) +GET /detect/config/stages — list stage palette with config fields +""" + +from __future__ import annotations + +import logging + +from fastapi import APIRouter +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/detect", tags=["detect"]) + +# In-memory config — persists until server restart. +# Phase 12+ moves this to DB. 
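+# Illustrative update call (hypothetical host/port — adjust to the deployment):
+#
+#   curl -X PUT http://localhost:8000/detect/config \
+#     -H 'Content-Type: application/json' \
+#     -d '{"preprocessing": {"binarize": true}}'
+#
+# Sections are shallow-merged: only keys present in the body change; other keys
+# in that section, and other sections, keep their current values.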
+_runtime_config: dict = {} + + +class ConfigUpdate(BaseModel): + detection: dict | None = None + ocr: dict | None = None + resolver: dict | None = None + escalation: dict | None = None + preprocessing: dict | None = None + + +class StageConfigInfo(BaseModel): + name: str + label: str + description: str + category: str + config_fields: list[dict] + reads: list[str] + writes: list[str] + + +@router.get("/config") +def read_config(): + return _runtime_config + + +@router.put("/config") +def write_config(update: ConfigUpdate): + changes = update.model_dump(exclude_none=True) + for section, values in changes.items(): + if section not in _runtime_config: + _runtime_config[section] = {} + _runtime_config[section].update(values) + + logger.info("Config updated: %s", list(changes.keys())) + return _runtime_config + + +@router.get("/config/stages", response_model=list[StageConfigInfo]) +def list_stage_configs(): + """Return the stage palette with config field metadata for the editor.""" + from detect.stages import list_stages + + result = [] + for stage in list_stages(): + info = StageConfigInfo( + name=stage.name, + label=stage.label, + description=stage.description, + category=stage.category, + config_fields=[ + { + "name": f.name, + "type": f.type, + "default": f.default, + "description": f.description, + "min": f.min, + "max": f.max, + "options": f.options, + } + for f in stage.config_fields + ], + reads=stage.io.reads, + writes=stage.io.writes, + ) + result.append(info) + return result diff --git a/core/api/main.py b/core/api/main.py index 9f02e55..64c84f8 100644 --- a/core/api/main.py +++ b/core/api/main.py @@ -26,6 +26,7 @@ from strawberry.fastapi import GraphQLRouter from core.api.chunker_sse import router as chunker_router from core.api.detect_sse import router as detect_router from core.api.detect_replay import router as detect_replay_router +from core.api.detect_config import router as detect_config_router from core.api.graphql import schema as graphql_schema CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "") @@ -60,6 +61,9 @@ app.include_router(detect_router) # Detection replay/retry app.include_router(detect_replay_router) +# Detection config +app.include_router(detect_config_router) + @app.get("/health") def health(): diff --git a/core/schema/modelgen.json b/core/schema/modelgen.json index 9df4022..6705af6 100644 --- a/core/schema/modelgen.json +++ b/core/schema/modelgen.json @@ -30,6 +30,11 @@ "target": "typescript", "output": "ui/detection-app/src/types/sse-contract.ts", "include": ["detect_views"] + }, + { + "target": "typescript", + "output": "ui/detection-app/src/types/store-state.ts", + "include": ["ui_state_views"] } ] } diff --git a/core/schema/models/__init__.py b/core/schema/models/__init__.py index d66d450..9f1d3dc 100644 --- a/core/schema/models/__init__.py +++ b/core/schema/models/__init__.py @@ -33,6 +33,7 @@ from .detect_jobs import ( from .media import AssetStatus, MediaAsset from .presets import BUILTIN_PRESETS, TranscodePreset from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader +from .ui_state import UI_STATE_VIEWS # noqa: F401 — UI store state types from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent # Core domain models - generates Django, Pydantic, TypeScript diff --git a/core/schema/models/ui_state.py b/core/schema/models/ui_state.py new file mode 100644 index 0000000..d06e6d5 --- /dev/null +++ b/core/schema/models/ui_state.py @@ -0,0 +1,139 @@ +""" +UI application state models. 
+ +Source of truth for all frontend store state shapes. +Generates TypeScript types via modelgen. +The store implementation (Pinia, etc.) is just the reactive container. +""" + +from dataclasses import dataclass, field +from typing import List, Optional + + +# --------------------------------------------------------------------------- +# Pipeline store +# --------------------------------------------------------------------------- + +@dataclass +class NodeState: + """A pipeline node's current status.""" + id: str + status: str = "pending" # pending | running | done | error + has_checkpoint: bool = False + has_region_editor: bool = False # stage works with visual regions + has_config_editor: bool = True # all stages have config + + +@dataclass +class PipelineState: + """Full pipeline run state.""" + job_id: str = "" + status: str = "idle" # idle | running | paused | completed | error + layout_mode: str = "normal" # normal | bbox_editor | stage_editor + editor_stage: Optional[str] = None # which stage's editor is open + nodes: List[NodeState] = field(default_factory=list) + current_stage: Optional[str] = None + run_id: Optional[str] = None + parent_job_id: Optional[str] = None + run_type: str = "initial" # initial | replay | retry + error: Optional[str] = None + + +# --------------------------------------------------------------------------- +# Config store +# --------------------------------------------------------------------------- + +@dataclass +class DetectionConfigOverrides: + """Tunable detection stage config.""" + model_name: Optional[str] = None + confidence_threshold: Optional[float] = None + target_classes: Optional[List[str]] = None + + +@dataclass +class OCRConfigOverrides: + """Tunable OCR stage config.""" + languages: Optional[List[str]] = None + min_confidence: Optional[float] = None + + +@dataclass +class ResolverConfigOverrides: + """Tunable brand resolver config.""" + fuzzy_threshold: Optional[int] = None + + +@dataclass +class EscalationConfigOverrides: + """Tunable escalation config.""" + vlm_min_confidence: Optional[float] = None + cloud_min_confidence: Optional[float] = None + cloud_provider: Optional[str] = None + + +@dataclass +class PreprocessingConfigOverrides: + """Tunable preprocessing config.""" + binarize: Optional[bool] = None + deskew: Optional[bool] = None + contrast: Optional[bool] = None + + +@dataclass +class ConfigOverrides: + """Aggregated config overrides from all panels.""" + detection: Optional[DetectionConfigOverrides] = None + ocr: Optional[OCRConfigOverrides] = None + resolver: Optional[ResolverConfigOverrides] = None + escalation: Optional[EscalationConfigOverrides] = None + preprocessing: Optional[PreprocessingConfigOverrides] = None + + +@dataclass +class ConfigState: + """Config store state.""" + current: ConfigOverrides = field(default_factory=ConfigOverrides) + pending: ConfigOverrides = field(default_factory=ConfigOverrides) + dirty: bool = False + + +# --------------------------------------------------------------------------- +# Selection store +# --------------------------------------------------------------------------- + +@dataclass +class BboxRegion: + """A user-drawn bounding box region.""" + x: int + y: int + w: int + h: int + + +@dataclass +class SelectionState: + """Cross-panel selection state.""" + selected_frame: Optional[int] = None + selected_brand: Optional[str] = None + hovered_timestamp: Optional[float] = None + bbox_region: Optional[BboxRegion] = None + + +# 
--------------------------------------------------------------------------- +# Export for modelgen +# --------------------------------------------------------------------------- + +UI_STATE_VIEWS = [ + NodeState, + PipelineState, + DetectionConfigOverrides, + OCRConfigOverrides, + ResolverConfigOverrides, + EscalationConfigOverrides, + PreprocessingConfigOverrides, + ConfigOverrides, + ConfigState, + BboxRegion, + SelectionState, +]
diff --git a/detect/graph.py b/detect/graph.py index b945f3b..f6017ed 100644 --- a/detect/graph.py +++ b/detect/graph.py @@ -18,6 +18,7 @@ from detect.state import DetectState from detect.stages.frame_extractor import extract_frames from detect.stages.scene_filter import scene_filter from detect.stages.yolo_detector import detect_objects +from detect.stages.preprocess import preprocess_regions from detect.stages.ocr_stage import run_ocr from detect.stages.brand_resolver import resolve_brands from detect.stages.vlm_local import escalate_vlm @@ -31,6 +32,7 @@ NODES = [ "extract_frames", "filter_scenes", "detect_objects", + "preprocess", "run_ocr", "match_brands", "escalate_vlm", @@ -137,6 +139,36 @@ def node_detect_objects(state: DetectState) -> dict: return {"boxes_by_frame": all_boxes, "stats": stats} +def node_preprocess(state: DetectState) -> dict: + _emit_transition(state, "preprocess", "running") + + with trace_node(state, "preprocess") as span: + frames = state.get("filtered_frames", []) + boxes = state.get("boxes_by_frame", {}) + job_id = state.get("job_id") + + # Preprocessing config comes from run-level overrides, falling back to defaults + overrides = state.get("config_overrides", {}) + prep_config = overrides.get("preprocessing", {}) + do_contrast = prep_config.get("contrast", True) + do_deskew = prep_config.get("deskew", False) + do_binarize = prep_config.get("binarize", False) + + preprocessed = preprocess_regions( + frames, boxes, + do_contrast=do_contrast, + do_deskew=do_deskew, + do_binarize=do_binarize, + inference_url=INFERENCE_URL, + job_id=job_id, + ) + span.set_output({"regions_preprocessed": len(preprocessed)}) + + _emit_transition(state, "preprocess", "done") + return {"preprocessed_crops": preprocessed} + + def node_run_ocr(state: DetectState) -> dict: _emit_transition(state, "run_ocr", "running") @@ -304,6 +336,7 @@ NODE_FUNCTIONS = [ ("extract_frames", node_extract_frames), ("filter_scenes", node_filter_scenes), ("detect_objects", node_detect_objects), + ("preprocess", node_preprocess), ("run_ocr", node_run_ocr), ("match_brands", node_match_brands), ("escalate_vlm", node_escalate_vlm),
diff --git a/detect/sse_contract.py b/detect/sse_contract.py index 7615015..d5161d0 100644 --- a/detect/sse_contract.py +++ b/detect/sse_contract.py @@ -101,3 +101,40 @@ class JobComplete(BaseModel): """Final report when pipeline finishes.
SSE event: job_complete""" job_id: str report: Optional[DetectionReportSummary] = None + +class RunContext(BaseModel): + """Run context injected into all SSE events for grouping.""" + run_id: str + parent_job_id: str + run_type: str = "initial" + +class CheckpointInfo(BaseModel): + """Available checkpoint for a stage.""" + stage: str + +class ReplayRequest(BaseModel): + """Request to replay pipeline from a specific stage.""" + job_id: str + start_stage: str + config_overrides: Optional[Dict[str, Any]] = None + +class ReplayResponse(BaseModel): + """Result of a replay invocation.""" + status: str + job_id: str + start_stage: str + detections: int = 0 + brands_found: int = 0 + +class RetryRequest(BaseModel): + """Request to queue async retry with different config.""" + job_id: str + config_overrides: Optional[Dict[str, Any]] = None + start_stage: str = "escalate_vlm" + schedule_seconds: Optional[float] = None + +class RetryResponse(BaseModel): + """Result of queueing a retry task.""" + status: str + task_id: str + job_id: str diff --git a/detect/stages/preprocess.py b/detect/stages/preprocess.py new file mode 100644 index 0000000..a63c9be --- /dev/null +++ b/detect/stages/preprocess.py @@ -0,0 +1,128 @@ +""" +Stage 3.5 — Preprocessing + +Runs between YOLO detection and OCR. Applies configurable image +preprocessing to each detected region crop: contrast enhancement, +deskewing, binarization. + +Operates on the crops derived from boxes_by_frame, produces +preprocessed_crops keyed by (frame_sequence, box_index). +""" + +from __future__ import annotations + +import logging + +import numpy as np + +from detect import emit +from detect.models import BoundingBox, Frame + +logger = logging.getLogger(__name__) + + +def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray: + h, w = frame.image.shape[:2] + x1 = max(0, box.x) + y1 = max(0, box.y) + x2 = min(w, box.x + box.w) + y2 = min(h, box.y + box.h) + return frame.image[y1:y2, x1:x2] + + +def preprocess_regions( + frames: list[Frame], + boxes_by_frame: dict[int, list[BoundingBox]], + do_contrast: bool = True, + do_deskew: bool = False, + do_binarize: bool = False, + inference_url: str | None = None, + job_id: str | None = None, +) -> dict[str, np.ndarray]: + """ + Preprocess cropped regions from YOLO detections. + + Returns dict keyed by "{frame_seq}_{box_idx}" → preprocessed crop. + These are passed to the OCR stage instead of raw crops. 
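+
+    Example (illustrative call — real shapes depend on the video):
+
+        crops = preprocess_regions(frames, boxes_by_frame, do_contrast=True)
+        crop = crops.get("12_0")   # processed crop for frame 12, box 0
+
+    An empty dict means every step was disabled and OCR should fall back
+    to the raw crops.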
+ """ + total_regions = sum(len(boxes) for boxes in boxes_by_frame.values()) + any_active = do_contrast or do_deskew or do_binarize + + if not any_active: + emit.log(job_id, "Preprocess", "INFO", + f"Preprocessing disabled, passing {total_regions} regions through") + return {} + + mode = "remote" if inference_url else "local" + emit.log(job_id, "Preprocess", "INFO", + f"Preprocessing {total_regions} regions (mode={mode}, " + f"contrast={do_contrast}, deskew={do_deskew}, binarize={do_binarize})") + + frame_map = {f.sequence: f for f in frames} + preprocessed: dict[str, np.ndarray] = {} + processed_count = 0 + + for seq, boxes in boxes_by_frame.items(): + frame = frame_map.get(seq) + if not frame: + continue + + for idx, box in enumerate(boxes): + crop = _crop_region(frame, box) + if crop.size == 0: + continue + + key = f"{seq}_{idx}" + + if inference_url: + result = _preprocess_remote(crop, inference_url, + do_contrast, do_deskew, do_binarize) + else: + result = _preprocess_local(crop, do_contrast, do_deskew, do_binarize) + + preprocessed[key] = result + processed_count += 1 + + emit.log(job_id, "Preprocess", "INFO", + f"Preprocessed {processed_count} regions") + + return preprocessed + + +def _preprocess_remote(crop: np.ndarray, inference_url: str, + do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray: + """Call GPU server /preprocess endpoint.""" + import base64 + import io + + import requests + from PIL import Image + + img = Image.fromarray(crop) + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=85) + image_b64 = base64.b64encode(buf.getvalue()).decode() + + resp = requests.post( + f"{inference_url.rstrip('/')}/preprocess", + json={ + "image": image_b64, + "contrast": do_contrast, + "deskew": do_deskew, + "binarize": do_binarize, + }, + timeout=30, + ) + resp.raise_for_status() + data = resp.json() + + result_bytes = base64.b64decode(data["image"]) + result_img = Image.open(io.BytesIO(result_bytes)).convert("RGB") + return np.array(result_img) + + +def _preprocess_local(crop: np.ndarray, + do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray: + """Run preprocessing in-process (requires opencv-python-headless).""" + from gpu.models.preprocess import preprocess + return preprocess(crop, do_binarize=do_binarize, do_deskew=do_deskew, do_contrast=do_contrast) diff --git a/detect/stages/registry/preprocessing.py b/detect/stages/registry/preprocessing.py index 11d40f5..cdc3f4f 100644 --- a/detect/stages/registry/preprocessing.py +++ b/detect/stages/registry/preprocessing.py @@ -1,4 +1,4 @@ -"""Registration for preprocessing stages: frame extraction, scene filter.""" +"""Registration for preprocessing stages: frame extraction, scene filter, image preprocessing.""" from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage from ._serializers import serialize_frames, deserialize_frames @@ -25,6 +25,17 @@ def _deser_filter(data: dict, job_id: str) -> dict: return {"_filtered_sequences": data["filtered_frame_sequences"]} +def _ser_preprocess(state: dict, job_id: str) -> dict: + # Preprocessed crops are numpy arrays — regenerable from frames + boxes + config + crops = state.get("preprocessed_crops", {}) + return {"crop_keys": list(crops.keys()), "count": len(crops)} + + +def _deser_preprocess(data: dict, job_id: str) -> dict: + # Crops are regenerable — no need to restore from checkpoint + return {"preprocessed_crops": {}} + + def register(): extract = StageDefinition( name="extract_frames", @@ -55,3 +66,22 @@ def register(): 
deserialize_fn=_deser_filter, ) register_stage(scene_filter) + + preprocess = StageDefinition( + name="preprocess", + label="Preprocess", + description="Image preprocessing on detected regions before OCR", + category="preprocessing", + io=StageIO( + reads=["filtered_frames", "boxes_by_frame"], + writes=["preprocessed_crops"], + ), + config_fields=[ + StageConfigField("contrast", "bool", True, "CLAHE contrast enhancement"), + StageConfigField("deskew", "bool", False, "Correct slight rotation"), + StageConfigField("binarize", "bool", False, "Otsu binarization"), + ], + serialize_fn=_ser_preprocess, + deserialize_fn=_deser_preprocess, + ) + register_stage(preprocess)
diff --git a/detect/state.py b/detect/state.py index 56b4d27..0cf4b4a 100644 --- a/detect/state.py +++ b/detect/state.py @@ -23,6 +23,7 @@ class DetectState(TypedDict, total=False): frames: list[Frame] filtered_frames: list[Frame] boxes_by_frame: dict[int, list[BoundingBox]] + preprocessed_crops: dict # "{frame_seq}_{box_idx}" → np.ndarray text_candidates: list[TextCandidate] unresolved_candidates: list[TextCandidate] detections: list[BrandDetection]
diff --git a/gpu/models/preprocess.py b/gpu/models/preprocess.py new file mode 100644 index 0000000..4c060e2 --- /dev/null +++ b/gpu/models/preprocess.py @@ -0,0 +1,117 @@ +""" +Image preprocessing pipeline for crops before OCR. + +Each step is independently toggleable via config. +Operates on numpy arrays (BGR or RGB), returns processed array. +""" + +from __future__ import annotations + +import logging + +import numpy as np + +logger = logging.getLogger(__name__) + + +def binarize(image: np.ndarray, threshold: int = 128) -> np.ndarray: + """Convert to grayscale and apply Otsu binarization.""" + import cv2 + + if len(image.shape) == 3: + gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + else: + gray = image + + # With THRESH_OTSU set, the fixed threshold argument is ignored — Otsu picks its own + _, binary = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + + # Convert back to 3-channel for downstream compatibility + result = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB) + return result + + +def deskew(image: np.ndarray) -> np.ndarray: + """Correct slight rotation using minimum area rectangle.""" + import cv2 + + if len(image.shape) == 3: + gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + else: + gray = image + + coords = np.column_stack(np.where(gray < 128)).astype(np.float32) # cv2.minAreaRect needs float32/int32 points; np.where yields int64 + if len(coords) < 10: + return image + + rect = cv2.minAreaRect(coords) + angle = rect[-1] + + # Normalize angle + if angle < -45: + angle = -(90 + angle) + else: + angle = -angle + + if abs(angle) < 0.5: + return image + + h, w = image.shape[:2] + center = (w // 2, h // 2) + rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + result = cv2.warpAffine( + image, rotation_matrix, (w, h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REPLICATE, + ) + return result + + +def enhance_contrast(image: np.ndarray) -> np.ndarray: + """Apply CLAHE (adaptive histogram equalization) for contrast normalization.""" + import cv2 + + if len(image.shape) == 3: + lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB) + l_channel = lab[:, :, 0] + else: + l_channel = image + + clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) + enhanced = clahe.apply(l_channel) + + if len(image.shape) == 3: + lab[:, :, 0] = enhanced + result = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB) + else: + result = enhanced + + return result + + +def preprocess( + image: np.ndarray, + do_binarize: bool = False, + do_deskew: bool = False, + do_contrast: bool = True, +) -> np.ndarray: + """ + Run the preprocessing pipeline on a crop
image. + + Each step is independently toggleable. Order: contrast → deskew → binarize. + Contrast first (works best on color), binarize last (destroys color info). + """ + result = image + + if do_contrast: + result = enhance_contrast(result) + logger.debug("Preprocessing: contrast enhanced") + + if do_deskew: + result = deskew(result) + logger.debug("Preprocessing: deskewed") + + if do_binarize: + result = binarize(result) + logger.debug("Preprocessing: binarized") + + return result diff --git a/gpu/requirements.txt b/gpu/requirements.txt index 771aadb..0fc2427 100644 --- a/gpu/requirements.txt +++ b/gpu/requirements.txt @@ -25,3 +25,6 @@ paddleocr>=3.0.0 # (all_tied_weights_keys API change). Also needs accelerate for device_map. transformers>=4.40.0,<5 accelerate>=0.27.0 + +# Preprocessing (phase 12) +opencv-python-headless>=4.8.0 diff --git a/gpu/server.py b/gpu/server.py index e9a46fa..c23df6c 100644 --- a/gpu/server.py +++ b/gpu/server.py @@ -73,6 +73,17 @@ class OCRResponse(BaseModel): results: list[OCRTextResult] +class PreprocessRequest(BaseModel): + image: str + binarize: bool = False + deskew: bool = False + contrast: bool = True + + +class PreprocessResponse(BaseModel): + image: str # base64 JPEG of processed image + + class VLMRequest(BaseModel): image: str prompt: str @@ -183,6 +194,34 @@ def ocr(req: OCRRequest): return OCRResponse(results=[OCRTextResult(**r) for r in results]) +@app.post("/preprocess", response_model=PreprocessResponse) +def preprocess_image(req: PreprocessRequest): + try: + image = _decode_image(req.image) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Bad image: {e}") + + try: + from models.preprocess import preprocess + processed = preprocess( + image, + do_binarize=req.binarize, + do_deskew=req.deskew, + do_contrast=req.contrast, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Preprocessing failed: {e}") + + from PIL import Image as PILImage + import io + img = PILImage.fromarray(processed) + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=90) + result_b64 = base64.b64encode(buf.getvalue()).decode() + + return PreprocessResponse(image=result_b64) + + @app.post("/vlm", response_model=VLMResponse) def vlm(req: VLMRequest): try: diff --git a/tests/detect/manual/push_pipeline.py b/tests/detect/manual/push_pipeline.py index c18de17..763d327 100644 --- a/tests/detect/manual/push_pipeline.py +++ b/tests/detect/manual/push_pipeline.py @@ -35,7 +35,7 @@ def push(r, key, event): def main(): parser = argparse.ArgumentParser() - parser.add_argument("--job", default="pipeline-test") + parser.add_argument("--job", default=f"pipeline-{int(__import__('time').time()) % 100000}") parser.add_argument("--port", type=int, default=6382) parser.add_argument("--delay", type=float, default=0.5) args = parser.parse_args() diff --git a/tests/detect/manual/run_extract_filter.py b/tests/detect/manual/run_extract_filter.py index ba2948b..4b967eb 100644 --- a/tests/detect/manual/run_extract_filter.py +++ b/tests/detect/manual/run_extract_filter.py @@ -15,7 +15,7 @@ import sys # Parse args early so we can set REDIS_URL before imports parser = argparse.ArgumentParser() -parser.add_argument("--job", default="extract-filter-test") +parser.add_argument("--job", default=f"extract-{int(__import__('time').time()) % 100000}") parser.add_argument("--port", type=int, default=6382) args = parser.parse_args() diff --git a/tests/detect/manual/run_graph.py b/tests/detect/manual/run_graph.py index 1aad4a7..267445c 100644 --- 
a/tests/detect/manual/run_graph.py +++ b/tests/detect/manual/run_graph.py @@ -13,8 +13,9 @@ import logging import os import sys +import time as _time parser = argparse.ArgumentParser() -parser.add_argument("--job", default="graph-test") +parser.add_argument("--job", default=f"graph-{int(_time.time()) % 100000}") parser.add_argument("--port", type=int, default=6382) args = parser.parse_args() diff --git a/tests/detect/manual/test_brand_table_e2e.py b/tests/detect/manual/test_brand_table_e2e.py index 12b9973..8851d74 100644 --- a/tests/detect/manual/test_brand_table_e2e.py +++ b/tests/detect/manual/test_brand_table_e2e.py @@ -55,7 +55,7 @@ def push(r, key, event): def main(): parser = argparse.ArgumentParser() - parser.add_argument("--job", default="brand-table-test") + parser.add_argument("--job", default=f"brand-{int(__import__('time').time()) % 100000}") parser.add_argument("--port", type=int, default=6382) parser.add_argument("--delay", type=float, default=0.6) args = parser.parse_args() diff --git a/tests/detect/manual/test_escalation_e2e.py b/tests/detect/manual/test_escalation_e2e.py index b763c58..e4ddb5b 100644 --- a/tests/detect/manual/test_escalation_e2e.py +++ b/tests/detect/manual/test_escalation_e2e.py @@ -23,8 +23,8 @@ import redis logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s") logger = logging.getLogger(__name__) -NODES = ["extract_frames", "filter_scenes", "detect_objects", "run_ocr", - "match_brands", "escalate_vlm", "escalate_cloud", "compile_report"] +NODES = ["extract_frames", "filter_scenes", "detect_objects", "preprocess", + "run_ocr", "match_brands", "escalate_vlm", "escalate_cloud", "compile_report"] def ts(): @@ -70,12 +70,22 @@ def push_stats(r, key, **fields): push(r, key, base) +_bbox_idx = 0 + def push_detection(r, key, brand, conf, source, timestamp, frame_ref, delay): + global _bbox_idx + # Spread fake bboxes across the frame so they don't overlap + col = _bbox_idx % 4 + row = _bbox_idx // 4 + bbox = {"x": 50 + col * 200, "y": 50 + row * 120, "w": 160, "h": 80} + _bbox_idx += 1 + push(r, key, { "event": "detection", "brand": brand, "confidence": conf, "source": source, "timestamp": timestamp, "duration": 0.5, "content_type": "soccer_broadcast", "frame_ref": frame_ref, + "bbox": bbox, }) logger.info(" [%s] %s %.2f t=%.1fs", source, brand, conf, timestamp) time.sleep(delay * 0.3) @@ -83,7 +93,9 @@ def push_detection(r, key, brand, conf, source, timestamp, frame_ref, delay): def main(): parser = argparse.ArgumentParser() - parser.add_argument("--job", default="escalation-test") + import time as _time + default_job = f"escalation-{int(_time.time()) % 100000}" + parser.add_argument("--job", default=default_job) parser.add_argument("--port", type=int, default=6382) parser.add_argument("--delay", type=float, default=0.5) args = parser.parse_args() @@ -121,6 +133,32 @@ def main(): push(r, key, {"event": "log", "level": "INFO", "stage": "YOLODetector", "msg": "Running yolov8n on 52 frames"}) time.sleep(delay) + + # Push a sample frame with YOLO boxes + import base64, io + from PIL import Image as PILImage, ImageDraw + frame_img = PILImage.new("RGB", (960, 540), "#1a1a2e") + draw = ImageDraw.Draw(frame_img) + draw.rectangle([40, 440, 900, 520], outline="#444", width=2) + draw.text((100, 460), "SPONSOR BOARD AREA", fill="#666") + draw.rectangle([350, 150, 610, 380], outline="#333", width=1) + draw.text((400, 200), "PLAYER", fill="#555") + buf = io.BytesIO() + frame_img.save(buf, "JPEG") + frame_b64 = 
base64.b64encode(buf.getvalue()).decode() + + yolo_boxes = [ + {"x": 40, "y": 440, "w": 860, "h": 80, "confidence": 0.92, + "label": "ad_board", "stage": "detect_objects", "source": "yolo"}, + {"x": 350, "y": 150, "w": 260, "h": 230, "confidence": 0.87, + "label": "person", "stage": "detect_objects", "source": "yolo"}, + {"x": 700, "y": 30, "w": 200, "h": 60, "confidence": 0.78, + "label": "scoreboard", "stage": "detect_objects", "source": "yolo"}, + ] + push(r, key, {"event": "frame_update", "frame_ref": 25, "timestamp": 12.5, + "jpeg_b64": frame_b64, "boxes": yolo_boxes}) + time.sleep(delay) + push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52, regions_detected=41, processing_time_seconds=14.2) push_graph(r, key, "detect_objects", "done", delay) diff --git a/tests/detect/manual/test_timeline_cost.py b/tests/detect/manual/test_timeline_cost.py index 6f3bd9f..5ca69f8 100644 --- a/tests/detect/manual/test_timeline_cost.py +++ b/tests/detect/manual/test_timeline_cost.py @@ -85,7 +85,7 @@ def push_stats(r, key, **overrides): def main(): parser = argparse.ArgumentParser() - parser.add_argument("--job", default="timeline-cost-test") + parser.add_argument("--job", default=f"timeline-{int(__import__('time').time()) % 100000}") parser.add_argument("--port", type=int, default=6382) parser.add_argument("--delay", type=float, default=0.4) args = parser.parse_args() diff --git a/tests/detect/test_config_endpoint.py b/tests/detect/test_config_endpoint.py new file mode 100644 index 0000000..291d92f --- /dev/null +++ b/tests/detect/test_config_endpoint.py @@ -0,0 +1,44 @@ +"""Tests for the config endpoint and stage palette.""" + +from detect.stages import list_stages, get_palette + + +def test_stage_palette_has_config_fields(): + """Every stage with config fields should be servable by the endpoint.""" + stages = list_stages() + stages_with_config = [s for s in stages if s.config_fields] + + assert len(stages_with_config) > 0 + + for stage in stages_with_config: + for field in stage.config_fields: + assert field.name + assert field.type + assert field.default is not None or field.type == "bool" + + +def test_palette_categories(): + palette = get_palette() + + expected_categories = {"preprocessing", "detection", "resolution", "escalation", "output"} + actual_categories = set(palette.keys()) + + assert actual_categories == expected_categories + + +def test_stage_config_serializable(): + """Config fields should be JSON-serializable for the API response.""" + import json + + stages = list_stages() + for stage in stages: + data = { + "name": stage.name, + "label": stage.label, + "config_fields": [ + {"name": f.name, "type": f.type, "default": f.default} + for f in stage.config_fields + ], + } + json_str = json.dumps(data) + assert len(json_str) > 0 diff --git a/tests/detect/test_preprocess.py b/tests/detect/test_preprocess.py new file mode 100644 index 0000000..d2363b9 --- /dev/null +++ b/tests/detect/test_preprocess.py @@ -0,0 +1,84 @@ +"""Tests for OpenCV preprocessing — runs without GPU.""" + +import numpy as np +import pytest + +try: + import cv2 + HAS_CV2 = True +except ImportError: + HAS_CV2 = False + +requires_cv2 = pytest.mark.skipif(not HAS_CV2, reason="opencv-python-headless not installed") + +# Add gpu/ to path so imports resolve (gpu modules use relative imports) +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "gpu")) + + +def _make_image(w: int = 200, h: int = 60) -> np.ndarray: + """White image with black text-like region.""" 
+ img = np.ones((h, w, 3), dtype=np.uint8) * 255 + img[15:45, 20:180] = 30 # dark band simulating text + return img + + +@requires_cv2 +def test_binarize(): + from gpu.models.preprocess import binarize + + img = _make_image() + result = binarize(img) + + assert result.shape == img.shape + assert result.dtype == np.uint8 + # Should be mostly black and white (no grays) + unique_values = np.unique(result) + assert len(unique_values) <= 3 # 0, 255, maybe one more from anti-aliasing + + +@requires_cv2 +def test_enhance_contrast(): + from gpu.models.preprocess import enhance_contrast + + img = _make_image() + result = enhance_contrast(img) + + assert result.shape == img.shape + assert result.dtype == np.uint8 + + +@requires_cv2 +def test_deskew_no_rotation(): + from gpu.models.preprocess import deskew + + img = _make_image() + result = deskew(img) + + assert result.shape == img.shape + # Straight image should be unchanged (angle < 0.5 deg) + assert np.allclose(result, img, atol=5) + + +@requires_cv2 +def test_preprocess_pipeline(): + from gpu.models.preprocess import preprocess + + img = _make_image() + + result = preprocess(img, do_binarize=False, do_deskew=False, do_contrast=True) + assert result.shape == img.shape + + result = preprocess(img, do_binarize=True, do_deskew=True, do_contrast=True) + assert result.shape[:2] == img.shape[:2] # h, w same; channels may differ then get converted back + + +@requires_cv2 +def test_preprocess_all_disabled(): + from gpu.models.preprocess import preprocess + + img = _make_image() + result = preprocess(img, do_binarize=False, do_deskew=False, do_contrast=False) + + assert np.array_equal(result, img)
diff --git a/tests/detect/test_stage_registry.py b/tests/detect/test_stage_registry.py index 8f9919b..52e0db0 100644 --- a/tests/detect/test_stage_registry.py +++ b/tests/detect/test_stage_registry.py @@ -4,8 +4,8 @@ from detect.stages import list_stages, get_stage, get_palette EXPECTED_STAGES = [ - "extract_frames", "filter_scenes", "detect_objects", "run_ocr", - "match_brands", "escalate_vlm", "escalate_cloud", "compile_report", + "extract_frames", "filter_scenes", "detect_objects", "preprocess", + "run_ocr", "match_brands", "escalate_vlm", "escalate_cloud", "compile_report", ]
diff --git a/ui/common/types/generated.ts b/ui/common/types/generated.ts index 45ce962..46dbf6f 100644 --- a/ui/common/types/generated.ts +++ b/ui/common/types/generated.ts @@ -7,6 +7,9 @@ export type AssetStatus = "pending" | "ready" | "error"; export type JobStatus = "pending" | "processing" | "completed" | "failed" | "cancelled"; export type ChunkJobStatus = "pending" | "chunking" | "processing" | "collecting" | "completed" | "failed" | "cancelled"; +export type DetectJobStatus = "pending" | "running" | "paused" | "completed" | "failed" | "cancelled"; +export type RunType = "initial" | "replay" | "retry"; +export type BrandSource = "ocr" | "local_vlm" | "cloud_llm" | "manual"; export interface MediaAsset { id: string; @@ -97,6 +100,75 @@ export interface ChunkJob { completed_at: string | null; } +export interface DetectJob { + id: string; + source_asset_id: string; + video_path: string; + profile_name: string; + parent_job_id: string | null; + run_type: RunType; + replay_from_stage: string | null; + config_overrides: Record<string, any>; + status: DetectJobStatus; + current_stage: string | null; + progress: number; + error_message: string | null; + total_detections: number; + brands_found: number; + cloud_llm_calls: number; + estimated_cost_usd: number; + celery_task_id: string | null; +
priority: number; + created_at: string | null; + started_at: string | null; + completed_at: string | null; +} + +export interface StageCheckpoint { + id: string; + job_id: string; + stage: string; + stage_index: number; + frames_prefix: string; + frames_manifest: Record<string, any>; + frames_meta: string[]; + filtered_frame_sequences: number[]; + boxes_by_frame: Record<string, any>; + text_candidates: string[]; + unresolved_candidates: string[]; + detections: string[]; + stats: Record<string, any>; + config_snapshot: Record<string, any>; + config_overrides: Record<string, any>; + video_path: string; + profile_name: string; + created_at: string | null; +} + +export interface KnownBrand { + id: string; + canonical_name: string; + aliases: string[]; + first_source: BrandSource; + total_occurrences: number; + confirmed: boolean; + created_at: string | null; + updated_at: string | null; +} + +export interface SourceBrandSighting { + id: string; + source_asset_id: string; + brand_id: string; + brand_name: string; + first_seen_timestamp: number; + last_seen_timestamp: number; + occurrences: number; + detection_source: BrandSource; + avg_confidence: number; + created_at: string | null; +} + export interface CreateJobRequest { source_asset_id: string; preset_id: string | null;
diff --git a/ui/detection-app/src/App.vue b/ui/detection-app/src/App.vue index 5ce2b96..09cb3aa 100644 --- a/ui/detection-app/src/App.vue +++ b/ui/detection-app/src/App.vue @@ -10,6 +10,9 @@ import BrandTablePanel from './panels/BrandTablePanel.vue' import TimelinePanel from './panels/TimelinePanel.vue' import CostStatsPanel from './panels/CostStatsPanel.vue' import type { StatsUpdate, RunContext } from './types/sse-contract' +import { usePipelineStore } from './stores/pipeline' + +const pipeline = usePipelineStore() const jobId = ref(new URLSearchParams(window.location.search).get('job') || 'test-job') const stats = ref<StatsUpdate | null>(null) @@ -89,56 +92,99 @@ source.connect()
[App.vue template hunk: markup lost in extraction — the new template switches the panel grid on the pipeline store's layoutMode (normal | bbox_editor | stage_editor) and renders the editor placeholder panes styled below]
@@ -290,4 +336,67 @@ header h1 { font-size: var(--font-size-lg); font-weight: 600; } } .empty { color: var(--text-dim); padding: var(--space-6); text-align: center; } + +/* Editor placeholders */ +.editor-placeholder { + display: flex; + height: 100%; + gap: var(--space-2); +} + +.editor-frame { + flex: 1; + min-height: 0; +} + +.editor-tools { + width: 200px; + flex-shrink: 0; + padding: var(--space-3); + background: var(--surface-2); + border-radius: var(--panel-radius); + display: flex; + flex-direction: column; + gap: var(--space-2); + font-size: var(--font-size-sm); + color: var(--text-secondary); +} + +.editor-config { + padding: var(--space-4); + font-size: var(--font-size-sm); + color: var(--text-secondary); + display: flex; + flex-direction: column; + gap: var(--space-2); +} + +.editor-close { + background: var(--surface-3); + border: 1px solid var(--surface-3); + border-radius: 4px; + padding: var(--space-2) var(--space-3); + color: var(--text-secondary); + font-family: var(--font-mono); + font-size: var(--font-size-sm); + cursor: pointer; + margin-top: auto; +} + +.editor-close:hover { + background: var(--status-error); + color: #000; +} + +.blob-viewer { + height: 100%; + overflow-x: auto; +} + +.blob-placeholder { + padding: var(--space-4); + color: var(--text-dim); + text-align: center; + font-size: var(--font-size-sm); +}
diff --git a/ui/detection-app/src/main.ts b/ui/detection-app/src/main.ts index 01433bc..27b8b78 100644 --- a/ui/detection-app/src/main.ts +++ b/ui/detection-app/src/main.ts @@ -1,4 +1,7 @@ import { createApp } from 'vue' +import { createPinia } from 'pinia' import App from './App.vue' -createApp(App).mount('#app') +const app = createApp(App) +app.use(createPinia()) +app.mount('#app')
diff --git a/ui/detection-app/src/panels/FramePanel.vue b/ui/detection-app/src/panels/FramePanel.vue index 0be7383..dc850ff 100644 --- a/ui/detection-app/src/panels/FramePanel.vue +++ b/ui/detection-app/src/panels/FramePanel.vue @@ -1,5 +1,5 @@ [hunk markup lost in extraction]
diff --git a/ui/detection-app/src/panels/PipelineGraphPanel.vue b/ui/detection-app/src/panels/PipelineGraphPanel.vue index bafa245..1c8d851 100644 --- a/ui/detection-app/src/panels/PipelineGraphPanel.vue +++ b/ui/detection-app/src/panels/PipelineGraphPanel.vue @@ -4,10 +4,11 @@ import { Panel } from 'mpr-ui-framework' import GraphRenderer from 'mpr-ui-framework/src/renderers/GraphRenderer.vue' import type { GraphNode } from 'mpr-ui-framework/src/renderers/GraphRenderer.vue' import type { DataSource } from 'mpr-ui-framework' +import { usePipelineStore } from '../stores/pipeline' const PIPELINE_NODES = [ - 'extract_frames', 'filter_scenes', 'detect_objects', 'run_ocr', - 'match_brands', 'escalate_vlm', 'escalate_cloud', 'compile_report', + 'extract_frames', 'filter_scenes', 'detect_objects', 'preprocess', + 'run_ocr', 'match_brands', 'escalate_vlm', 'escalate_cloud', 'compile_report', ] const props = defineProps<{ @@ -15,6 +16,8 @@ const props = defineProps<{ status?: 'idle' | 'live' | 'processing' | 'error' }>() +const pipeline = usePipelineStore() + const nodes = ref<GraphNode[]>( PIPELINE_NODES.map((id) => ({ id, status: 'pending' })) ) @@ -22,10 +25,22 @@ const nodes = ref<GraphNode[]>( props.source.on<{ nodes: GraphNode[] }>('graph_update', (e) => { nodes.value = e.nodes }) + +function onOpenRegionEditor(stage: string) { + pipeline.openBBoxEditor(stage) +} + +function onOpenStageEditor(stage: string) { + pipeline.openStageEditor(stage) +}
diff --git a/ui/detection-app/src/stores/config.ts b/ui/detection-app/src/stores/config.ts new file mode 100644 index 
0000000..65a5037 --- /dev/null +++ b/ui/detection-app/src/stores/config.ts @@ -0,0 +1,47 @@ +/** + * Config store — aggregated config from all panels. + * + * Panels write their own config slice (ocr, detection, etc.). + * Pipeline panel reads the full config and triggers replay. + * State shape defined in types/store-state.ts. + */ + +import { defineStore } from 'pinia' +import { ref, computed } from 'vue' +import type { ConfigOverrides } from '../types/store-state' + +export const useConfigStore = defineStore('config', () => { + const current = ref<Partial<ConfigOverrides>>({}) + const pending = ref<Partial<ConfigOverrides>>({}) + + const dirty = computed(() => JSON.stringify(pending.value) !== JSON.stringify(current.value)) + + function updatePending(section: keyof ConfigOverrides, values: Record<string, any>) { + pending.value = { + ...pending.value, + [section]: { ...(pending.value[section] as Record<string, any> || {}), ...values }, + } + } + + function apply() { + current.value = JSON.parse(JSON.stringify(pending.value)) + } + + function revert() { + pending.value = JSON.parse(JSON.stringify(current.value)) + } + + function loadFromServer(config: ConfigOverrides) { + current.value = config + pending.value = JSON.parse(JSON.stringify(config)) + } + + function getOverrides(): ConfigOverrides { + return JSON.parse(JSON.stringify(pending.value)) + } + + return { + current, pending, dirty, + updatePending, apply, revert, loadFromServer, getOverrides, + } +})
diff --git a/ui/detection-app/src/stores/data.ts b/ui/detection-app/src/stores/data.ts new file mode 100644 index 0000000..d9e5ff1 --- /dev/null +++ b/ui/detection-app/src/stores/data.ts @@ -0,0 +1,40 @@ +/** + * Data store — latest SSE data, replaces inline refs in App.vue. + * + * The SSE DataSource writes here. Panels read from here. + * Shapes come from types/sse-contract.ts. + */ + +import { defineStore } from 'pinia' +import { ref } from 'vue' +import type { StatsUpdate, Detection } from '../types/sse-contract' + +export const useDataStore = defineStore('data', () => { + const stats = ref<StatsUpdate | null>(null) + const detections = ref<Detection[]>([]) + const connectionStatus = ref<'idle' | 'connecting' | 'live' | 'error'>('idle') + + function updateStats(s: StatsUpdate) { + stats.value = s + } + + function addDetection(d: Detection) { + detections.value.push(d) + } + + function setConnectionStatus(s: 'idle' | 'connecting' | 'live' | 'error') { + connectionStatus.value = s + } + + function reset() { + stats.value = null + detections.value = [] + connectionStatus.value = 'idle' + } + + return { + stats, detections, connectionStatus, + updateStats, addDetection, setConnectionStatus, reset, + } +})
diff --git a/ui/detection-app/src/stores/index.ts b/ui/detection-app/src/stores/index.ts new file mode 100644 index 0000000..3974a18 --- /dev/null +++ b/ui/detection-app/src/stores/index.ts @@ -0,0 +1,13 @@ +/** + * Store index — re-exports all stores. + * + * State shapes are in types/store-state.ts (the contract). + * These files are the Pinia bindings (the implementation). + * Swap Pinia for anything else by replacing these files, + * keeping the same function signatures.
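+ *
+ * Typical component wiring (illustrative; adjust the import path to the app):
+ *   import { usePipelineStore } from '../stores'
+ *   const pipeline = usePipelineStore()
+ *   pipeline.openStageEditor('run_ocr')   // flips layoutMode to stage_editor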
+ */ + +export { usePipelineStore } from './pipeline' +export { useConfigStore } from './config' +export { useSelectionStore } from './selection' +export { useDataStore } from './data' diff --git a/ui/detection-app/src/stores/pipeline.ts b/ui/detection-app/src/stores/pipeline.ts new file mode 100644 index 0000000..65856da --- /dev/null +++ b/ui/detection-app/src/stores/pipeline.ts @@ -0,0 +1,96 @@ +/** + * Pipeline store — run state, transport controls, checkpoint status. + * + * State shape defined in types/store-state.ts. + * This file is just the Pinia binding. + */ + +import { defineStore } from 'pinia' +import { ref, computed } from 'vue' +import type { NodeState } from '../types/store-state' +import type { CheckpointInfo } from '../types/sse-contract' + +export const usePipelineStore = defineStore('pipeline', () => { + const jobId = ref('') + const status = ref('idle') + const nodes = ref([]) + const currentStage = ref(null) + const runId = ref(null) + const parentJobId = ref(null) + const runType = ref('initial') + const checkpoints = ref([]) + const error = ref(null) + + // Layout mode + const layoutMode = ref('normal') // normal | bbox_editor | stage_editor + const editorStage = ref(null) // which stage's editor is open + + const isRunning = computed(() => status.value === 'running') + const isPaused = computed(() => status.value === 'paused') + const canReplay = computed(() => checkpoints.value.length > 0) + const isEditing = computed(() => layoutMode.value !== 'normal') + + function setJob(id: string) { + jobId.value = id + } + + function setStatus(s: string) { + status.value = s + } + + function updateNodes(nodeList: NodeState[]) { + nodes.value = nodeList + const running = nodeList.find((n) => n.status === 'running') + currentStage.value = running?.id ?? null + } + + function setRunContext(rid: string, parentId: string, rtype: string) { + runId.value = rid + parentJobId.value = parentId + runType.value = rtype + } + + function setCheckpoints(list: CheckpointInfo[]) { + checkpoints.value = list + } + + function setError(msg: string | null) { + error.value = msg + if (msg) status.value = 'error' + } + + function openBBoxEditor(stage: string) { + layoutMode.value = 'bbox_editor' + editorStage.value = stage + } + + function openStageEditor(stage: string) { + layoutMode.value = 'stage_editor' + editorStage.value = stage + } + + function closeEditor() { + layoutMode.value = 'normal' + editorStage.value = null + } + + function reset() { + status.value = 'idle' + layoutMode.value = 'normal' + editorStage.value = null + nodes.value = [] + currentStage.value = null + runId.value = null + parentJobId.value = null + runType.value = 'initial' + error.value = null + } + + return { + jobId, status, nodes, currentStage, runId, parentJobId, runType, + checkpoints, error, layoutMode, editorStage, + isRunning, isPaused, canReplay, isEditing, + setJob, setStatus, updateNodes, setRunContext, setCheckpoints, setError, + openBBoxEditor, openStageEditor, closeEditor, reset, + } +}) diff --git a/ui/detection-app/src/stores/selection.ts b/ui/detection-app/src/stores/selection.ts new file mode 100644 index 0000000..324f8b2 --- /dev/null +++ b/ui/detection-app/src/stores/selection.ts @@ -0,0 +1,59 @@ +/** + * Selection store — cross-panel selection state. + * + * When you click a detection in the table, the frame viewer highlights it. + * When you hover on the timeline, the crosshair syncs across charts. + * When you draw a bbox, it feeds into the config store. 
+ * + * State shape defined in types/store-state.ts. + */ + +import { defineStore } from 'pinia' +import { ref } from 'vue' +import type { Detection } from '../types/sse-contract' + +export const useSelectionStore = defineStore('selection', () => { + const selectedFrame = ref<number | null>(null) + const selectedDetection = ref<Detection | null>(null) + const selectedBrand = ref<string | null>(null) + const hoveredTimestamp = ref<number | null>(null) + const bboxRegion = ref<{ x: number; y: number; w: number; h: number } | null>(null) + + function selectFrame(seq: number | null) { + selectedFrame.value = seq + } + + function selectDetection(det: Detection | null) { + selectedDetection.value = det + if (det) { + selectedBrand.value = det.brand + selectedFrame.value = det.frame_ref + } + } + + function selectBrand(brand: string | null) { + selectedBrand.value = brand + } + + function hoverTimestamp(ts: number | null) { + hoveredTimestamp.value = ts + } + + function setBbox(region: { x: number; y: number; w: number; h: number } | null) { + bboxRegion.value = region + } + + function clearAll() { + selectedFrame.value = null + selectedDetection.value = null + selectedBrand.value = null + hoveredTimestamp.value = null + bboxRegion.value = null + } + + return { + selectedFrame, selectedDetection, selectedBrand, hoveredTimestamp, bboxRegion, + selectFrame, selectDetection, selectBrand, hoverTimestamp, setBbox, clearAll, + } +})
diff --git a/ui/detection-app/src/types/sse-contract.ts b/ui/detection-app/src/types/sse-contract.ts index 78833f8..be6b447 100644 --- a/ui/detection-app/src/types/sse-contract.ts +++ b/ui/detection-app/src/types/sse-contract.ts @@ -96,16 +96,12 @@ export interface JobComplete { report: DetectionReportSummary | null; } -// --- Run context (injected into all SSE events) --- - export interface RunContext { run_id: string; parent_job_id: string; - run_type: 'initial' | 'replay' | 'retry'; + run_type: string; } -// --- Checkpoint API types --- - export interface CheckpointInfo { stage: string; } @@ -113,7 +109,7 @@ export interface CheckpointInfo { export interface ReplayRequest { job_id: string; start_stage: string; - config_overrides?: Record<string, any>; + config_overrides: Record<string, any> | null; } export interface ReplayResponse { @@ -126,9 +122,9 @@ export interface ReplayResponse { export interface RetryRequest { job_id: string; - config_overrides?: Record<string, any>; - start_stage?: string; - schedule_seconds?: number; + config_overrides: Record<string, any> | null; + start_stage: string; + schedule_seconds: number | null; } export interface RetryResponse {
diff --git a/ui/detection-app/src/types/store-state.ts b/ui/detection-app/src/types/store-state.ts new file mode 100644 index 0000000..f4f17e7 --- /dev/null +++ b/ui/detection-app/src/types/store-state.ts @@ -0,0 +1,82 @@ +/** + * TypeScript Types - GENERATED FILE + * + * Do not edit directly. Regenerate using modelgen.
+ */ + + +export interface NodeState { + id: string; + status: string; + has_checkpoint: boolean; + has_region_editor: boolean; + has_config_editor: boolean; +} + +export interface PipelineState { + job_id: string; + status: string; + layout_mode: string; + editor_stage: string | null; + nodes: NodeState[]; + current_stage: string | null; + run_id: string | null; + parent_job_id: string | null; + run_type: string; + error: string | null; +} + +export interface DetectionConfigOverrides { + model_name: string | null; + confidence_threshold: number | null; + target_classes: string[] | null; +} + +export interface OCRConfigOverrides { + languages: string[] | null; + min_confidence: number | null; +} + +export interface ResolverConfigOverrides { + fuzzy_threshold: number | null; +} + +export interface EscalationConfigOverrides { + vlm_min_confidence: number | null; + cloud_min_confidence: number | null; + cloud_provider: string | null; +} + +export interface PreprocessingConfigOverrides { + binarize: boolean | null; + deskew: boolean | null; + contrast: boolean | null; +} + +export interface ConfigOverrides { + detection: DetectionConfigOverrides | null; + ocr: OCRConfigOverrides | null; + resolver: ResolverConfigOverrides | null; + escalation: EscalationConfigOverrides | null; + preprocessing: PreprocessingConfigOverrides | null; +} + +export interface ConfigState { + current: ConfigOverrides; + pending: ConfigOverrides; + dirty: boolean; +} + +export interface BboxRegion { + x: number; + y: number; + w: number; + h: number; +} + +export interface SelectionState { + selected_frame: number | null; + selected_brand: string | null; + hovered_timestamp: number | null; + bbox_region: BboxRegion | null; +} diff --git a/ui/framework/src/index.ts b/ui/framework/src/index.ts index a9742c8..9d52836 100644 --- a/ui/framework/src/index.ts +++ b/ui/framework/src/index.ts @@ -15,3 +15,10 @@ export { default as TimeSeriesRenderer } from './renderers/TimeSeriesRenderer.vu export { default as GraphRenderer } from './renderers/GraphRenderer.vue' export { default as FrameRenderer } from './renderers/FrameRenderer.vue' export { default as TableRenderer } from './renderers/TableRenderer.vue' + +// Interaction plugins +export type { InteractionPlugin, PluginContext } from './plugins/InteractionPlugin' +export { BBoxDrawPlugin } from './plugins/BBoxDrawPlugin' +export type { BBoxResult, BBoxCallback } from './plugins/BBoxDrawPlugin' +export { CrosshairPlugin } from './plugins/CrosshairPlugin' +export type { CrosshairCallback } from './plugins/CrosshairPlugin' diff --git a/ui/framework/src/plugins/BBoxDrawPlugin.ts b/ui/framework/src/plugins/BBoxDrawPlugin.ts new file mode 100644 index 0000000..064ef6d --- /dev/null +++ b/ui/framework/src/plugins/BBoxDrawPlugin.ts @@ -0,0 +1,88 @@ +/** + * BBoxDrawPlugin — draw bounding boxes on the frame viewer. + * + * User drags on the canvas to draw a rectangle. + * On pointer up, emits the bbox coordinates via the callback. + * The frame viewer panel feeds this into the selection store. 
diff --git a/ui/framework/src/plugins/BBoxDrawPlugin.ts b/ui/framework/src/plugins/BBoxDrawPlugin.ts
new file mode 100644
index 0000000..064ef6d
--- /dev/null
+++ b/ui/framework/src/plugins/BBoxDrawPlugin.ts
@@ -0,0 +1,88 @@
+/**
+ * BBoxDrawPlugin — draw bounding boxes on the frame viewer.
+ *
+ * User drags on the canvas to draw a rectangle.
+ * On pointer up, emits the bbox coordinates via the callback.
+ * The frame viewer panel feeds this into the selection store.
+ */
+
+import type { InteractionPlugin, PluginContext } from './InteractionPlugin'
+
+export interface BBoxResult {
+  x: number
+  y: number
+  w: number
+  h: number
+}
+
+export type BBoxCallback = (bbox: BBoxResult) => void
+
+export class BBoxDrawPlugin implements InteractionPlugin {
+  name = 'bbox-draw'
+
+  private ctx: CanvasRenderingContext2D | null = null
+  private drawing = false
+  private startX = 0
+  private startY = 0
+  private currentBox: BBoxResult | null = null
+  private callback: BBoxCallback
+
+  constructor(callback: BBoxCallback) {
+    this.callback = callback
+  }
+
+  onMount(context: PluginContext): void {
+    this.ctx = context.ctx
+  }
+
+  onUnmount(): void {
+    this.ctx = null
+    this.drawing = false
+    this.currentBox = null
+  }
+
+  onPointerDown(e: PointerEvent): void {
+    this.drawing = true
+    this.startX = e.offsetX
+    this.startY = e.offsetY
+    this.currentBox = null
+  }
+
+  onPointerMove(e: PointerEvent): void {
+    if (!this.drawing) return
+
+    const x = Math.min(this.startX, e.offsetX)
+    const y = Math.min(this.startY, e.offsetY)
+    const w = Math.abs(e.offsetX - this.startX)
+    const h = Math.abs(e.offsetY - this.startY)
+
+    this.currentBox = { x, y, w, h }
+  }
+
+  onPointerUp(_e: PointerEvent): void {
+    if (!this.drawing) return
+    this.drawing = false
+
+    if (this.currentBox && this.currentBox.w > 5 && this.currentBox.h > 5) {
+      this.callback(this.currentBox)
+    }
+
+    this.currentBox = null
+  }
+
+  render(ctx: CanvasRenderingContext2D): void {
+    if (!this.currentBox) return
+
+    const box = this.currentBox
+
+    ctx.strokeStyle = '#4f9cf9'
+    ctx.lineWidth = 2
+    ctx.setLineDash([6, 3])
+    ctx.strokeRect(box.x, box.y, box.w, box.h)
+    ctx.setLineDash([])
+
+    // Semi-transparent fill
+    ctx.fillStyle = 'rgba(79, 156, 249, 0.1)'
+    ctx.fillRect(box.x, box.y, box.w, box.h)
+  }
+}
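The header comment says the frame viewer panel feeds results into the selection store; a hedged wiring sketch under that assumption, using the store's `setBbox` action from this PR (the import paths are assumed, and boxes at or under 5×5 px never reach the callback because the plugin drops them on pointer up):

```ts
// Sketch: import paths are hypothetical; the callback wiring is the point.
import { BBoxDrawPlugin } from 'ui-framework'
import { useSelectionStore } from '../stores/selection'

export function createBBoxPlugin(): BBoxDrawPlugin {
  const selection = useSelectionStore()
  // BBoxResult is structurally identical to the store's bbox region shape.
  return new BBoxDrawPlugin((bbox) => selection.setBbox(bbox))
}
```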
diff --git a/ui/framework/src/plugins/CrosshairPlugin.ts b/ui/framework/src/plugins/CrosshairPlugin.ts
new file mode 100644
index 0000000..0011b5d
--- /dev/null
+++ b/ui/framework/src/plugins/CrosshairPlugin.ts
@@ -0,0 +1,60 @@
+/**
+ * CrosshairPlugin — synchronized vertical crosshair across time-series panels.
+ *
+ * When the user hovers on any panel with this plugin, the crosshair
+ * position (as a timestamp) is written to the selection store.
+ * All panels with this plugin render a vertical line at that timestamp.
+ */
+
+import type { InteractionPlugin, PluginContext } from './InteractionPlugin'
+
+export type CrosshairCallback = (timestamp: number | null) => void
+
+export class CrosshairPlugin implements InteractionPlugin {
+  name = 'crosshair'
+
+  private width = 0
+  private callback: CrosshairCallback
+
+  /** Current crosshair X position (pixels), set externally from store */
+  public crosshairX: number | null = null
+
+  constructor(callback: CrosshairCallback) {
+    this.callback = callback
+  }
+
+  onMount(context: PluginContext): void {
+    this.width = context.width
+  }
+
+  onUnmount(): void {
+    this.crosshairX = null
+  }
+
+  onPointerMove(e: PointerEvent): void {
+    if (this.width <= 0) return
+    // Convert pixel X to a normalized position (0-1); the host panel maps
+    // this to a timestamp before writing the selection store.
+    const normalized = e.offsetX / this.width
+    this.callback(normalized)
+  }
+
+  onPointerDown(_e: PointerEvent): void {
+    // no-op for crosshair
+  }
+
+  onPointerUp(_e: PointerEvent): void {
+    this.callback(null)
+  }
+
+  render(ctx: CanvasRenderingContext2D): void {
+    if (this.crosshairX === null) return
+
+    ctx.strokeStyle = '#a78bfa'
+    ctx.lineWidth = 1
+    ctx.setLineDash([4, 4])
+    ctx.beginPath()
+    ctx.moveTo(this.crosshairX, 0)
+    ctx.lineTo(this.crosshairX, ctx.canvas.height)
+    ctx.stroke()
+    ctx.setLineDash([])
+  }
+}
diff --git a/ui/framework/src/plugins/InteractionPlugin.ts b/ui/framework/src/plugins/InteractionPlugin.ts
new file mode 100644
index 0000000..82fd944
--- /dev/null
+++ b/ui/framework/src/plugins/InteractionPlugin.ts
@@ -0,0 +1,36 @@
+/**
+ * Interaction plugin interface.
+ *
+ * Plugins attach to a Panel's overlay canvas. They receive pointer events
+ * and emit typed results via the callback. The panel handles rendering
+ * the overlay and routing events to the active plugin.
+ */
+
+export interface PluginContext {
+  /** Canvas element for drawing overlays */
+  canvas: HTMLCanvasElement
+  /** 2D rendering context */
+  ctx: CanvasRenderingContext2D
+  /** Canvas dimensions (may differ from display size) */
+  width: number
+  height: number
+}
+
+export interface InteractionPlugin {
+  /** Unique plugin name */
+  name: string
+
+  /** Called when the plugin is mounted on a panel */
+  onMount(context: PluginContext): void
+
+  /** Called when the plugin is unmounted */
+  onUnmount(): void
+
+  /** Pointer event handlers (optional) */
+  onPointerDown?(e: PointerEvent): void
+  onPointerMove?(e: PointerEvent): void
+  onPointerUp?(e: PointerEvent): void
+
+  /** Called each animation frame to render the overlay */
+  render(ctx: CanvasRenderingContext2D): void
+}
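The docstring says the panel routes events and renders the overlay, but the panel itself is not part of this diff. A minimal illustrative sketch of that host loop, assuming only the interface above; `attachPlugin` and its teardown contract are hypothetical:

```ts
// Sketch of a host panel's plugin loop; not the panel implementation from this PR.
import type { InteractionPlugin, PluginContext } from './InteractionPlugin'

export function attachPlugin(canvas: HTMLCanvasElement, plugin: InteractionPlugin): () => void {
  const ctx = canvas.getContext('2d')!
  const context: PluginContext = { canvas, ctx, width: canvas.width, height: canvas.height }
  plugin.onMount(context)

  // Route pointer events; the optional chaining mirrors the optional handlers.
  const onDown = (e: PointerEvent) => plugin.onPointerDown?.(e)
  const onMove = (e: PointerEvent) => plugin.onPointerMove?.(e)
  const onUp = (e: PointerEvent) => plugin.onPointerUp?.(e)
  canvas.addEventListener('pointerdown', onDown)
  canvas.addEventListener('pointermove', onMove)
  canvas.addEventListener('pointerup', onUp)

  // Redraw the transparent overlay each animation frame.
  let raf = 0
  const loop = () => {
    ctx.clearRect(0, 0, canvas.width, canvas.height)
    plugin.render(ctx)
    raf = requestAnimationFrame(loop)
  }
  raf = requestAnimationFrame(loop)

  // Teardown: stop the loop, detach listeners, unmount the plugin.
  return () => {
    cancelAnimationFrame(raf)
    canvas.removeEventListener('pointerdown', onDown)
    canvas.removeEventListener('pointermove', onMove)
    canvas.removeEventListener('pointerup', onUp)
    plugin.onUnmount()
  }
}
```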
diff --git a/ui/framework/src/renderers/FrameRenderer.vue b/ui/framework/src/renderers/FrameRenderer.vue
index c0b13cd..86a2b37 100644
--- a/ui/framework/src/renderers/FrameRenderer.vue
+++ b/ui/framework/src/renderers/FrameRenderer.vue
@@ -8,6 +8,10 @@ export interface FrameBBox {
   h: number
   confidence: number
   label: string
+  resolved_brand?: string | null
+  source?: string | null
+  stage?: string | null
+  ocr_text?: string | null
 }
 
 const props = defineProps<{
@@ -46,27 +50,37 @@ function draw() {
     const bw = box.w * scale
     const bh = box.h * scale
 
-    // Box outline
-    ctx.strokeStyle = confidenceColor(box.confidence)
+    const color = sourceColor(box)
+    const resolved = box.resolved_brand || box.ocr_text
+
+    // Box outline only — no labels, no percentages
+    ctx.strokeStyle = color
     ctx.lineWidth = 2
+    if (!resolved) {
+      ctx.setLineDash([4, 3])
+    }
     ctx.strokeRect(bx, by, bw, bh)
-
-    // Label background
-    const label = `${box.label} ${(box.confidence * 100).toFixed(0)}%`
-    ctx.font = '11px var(--font-mono)'
-    const metrics = ctx.measureText(label)
-    const labelH = 16
-    ctx.fillStyle = confidenceColor(box.confidence)
-    ctx.fillRect(bx, by - labelH, metrics.width + 8, labelH)
-
-    // Label text
-    ctx.fillStyle = '#000'
-    ctx.fillText(label, bx + 4, by - 4)
+    ctx.setLineDash([])
   }
 }
 
 img.src = `data:image/jpeg;base64,${props.imageSrc}`
}
+
+const SOURCE_COLORS: Record<string, string> = {
+  yolo: '#f5a623',        // yellow — raw detection
+  ocr: '#ff8c42',         // orange — text extracted
+  ocr_matched: '#3ecf8e', // green — brand resolved
+  local_vlm: '#4f9cf9',   // blue — VLM resolved
+  cloud_llm: '#a78bfa',   // purple — cloud resolved
+  unresolved: '#e05252',  // red — nothing matched
+}
+
+function sourceColor(box: FrameBBox): string {
+  if (box.resolved_brand) return SOURCE_COLORS.ocr_matched
+  if (box.source && box.source in SOURCE_COLORS) return SOURCE_COLORS[box.source]
+  return confidenceColor(box.confidence)
+}
+
 function confidenceColor(conf: number): string {
   if (conf >= 0.7) return 'var(--conf-high)'
   if (conf >= 0.4) return 'var(--conf-mid)'
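To make the `sourceColor` precedence concrete: `resolved_brand` always wins (green), then a recognized `source` color, then the confidence fallback; boxes with neither `resolved_brand` nor `ocr_text` render dashed. A few hypothetical `FrameBBox` values and the outline each would get:

```ts
// Hypothetical sample data; only the FrameBBox shape comes from this diff.
import type { FrameBBox } from './FrameRenderer.vue'

const examples: FrameBBox[] = [
  // resolved_brand set -> solid green (ocr_matched), regardless of source
  { x: 10, y: 10, w: 80, h: 40, confidence: 0.91, label: 'logo', resolved_brand: 'Acme', source: 'ocr' },
  // ocr_text but no resolved brand -> solid outline in the source color (orange)
  { x: 120, y: 10, w: 60, h: 30, confidence: 0.55, label: 'text', ocr_text: 'ACM', source: 'ocr' },
  // nothing matched -> dashed red ('unresolved')
  { x: 200, y: 10, w: 50, h: 25, confidence: 0.35, label: 'object', source: 'unresolved' },
  // no source at all -> dashed, confidenceColor() fallback
  { x: 270, y: 10, w: 50, h: 25, confidence: 0.8, label: 'object' },
]
```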
diff --git a/ui/framework/src/renderers/GraphRenderer.vue b/ui/framework/src/renderers/GraphRenderer.vue
index f0f7bee..60f12b7 100644
--- a/ui/framework/src/renderers/GraphRenderer.vue
+++ b/ui/framework/src/renderers/GraphRenderer.vue
@@ -11,8 +11,19 @@ export interface GraphNode {
 
 const props = defineProps<{
   nodes: GraphNode[]
+  /** Stages that have a region editor (bbox/polygon) */
+  regionStages?: string[]
 }>()
 
+const emit = defineEmits<{
+  'open-region-editor': [stage: string]
+  'open-stage-editor': [stage: string]
+}>()
+
+const regionStageSet = computed(() => new Set(props.regionStages ?? [
+  'detect_objects', 'run_ocr', 'match_brands', 'escalate_vlm', 'escalate_cloud',
+]))
+
 const statusColors: Record<string, string> = {
   pending: 'var(--status-idle)',
   running: 'var(--status-processing)',
@@ -23,17 +34,15 @@ const statusColors: Record<string, string> = {
 const flowNodes = computed(() =>
   props.nodes.map((n, i) => ({
     id: n.id,
-    label: n.id.replace(/_/g, ' '),
-    position: { x: 20, y: i * 70 },
-    style: {
-      background: statusColors[n.status] ?? statusColors.pending,
-      color: n.status === 'pending' ? '#ccc' : '#000',
-      border: 'none',
-      borderRadius: 'var(--panel-radius)',
-      fontFamily: 'var(--font-mono)',
-      fontSize: 'var(--font-size-sm)',
-      fontWeight: '600',
-      padding: '8px 16px',
+    type: 'stage',
+    position: { x: 20, y: i * 80 },
+    data: {
+      label: n.id.replace(/_/g, ' '),
+      status: n.status,
+      color: statusColors[n.status] ?? statusColors.pending,
+      textColor: n.status === 'pending' ? '#888' : '#000',
+      hasRegionEditor: regionStageSet.value.has(n.id),
+      isRunning: n.status === 'running',
    },
   }))
 )
@@ -63,7 +72,38 @@ const flowEdges = computed(() => {
     :nodes-connectable="false"
     :zoom-on-scroll="false"
     :pan-on-scroll="false"
-  />
+  >
+    <template #node-stage="{ id, data }">
+      <div
+        class="stage-node"
+        :class="{ running: data.isRunning }"
+        :style="{ background: data.color, color: data.textColor }"
+      >
+        <span class="stage-label">{{ data.label }}</span>
+        <div class="stage-actions">
+          <button
+            v-if="data.hasRegionEditor"
+            class="stage-btn"
+            @click.stop="emit('open-region-editor', id)"
+          >
+            ◱
+          </button>
+          <button
+            class="stage-btn"
+            @click.stop="emit('open-stage-editor', id)"
+          >
+            ⚙
+          </button>
+        </div>
+      </div>
+    </template>
+  </VueFlow>
@@ -77,4 +117,66 @@ const flowEdges = computed(() => {
 .graph-renderer :deep(.vue-flow__background) {
   background: transparent;
 }
+
+/* Hide default node styling — we use custom template */
+.graph-renderer :deep(.vue-flow__node-stage) {
+  padding: 0;
+  border: none;
+  background: transparent;
+  border-radius: 0;
+}
+
+.stage-node {
+  display: flex;
+  align-items: center;
+  gap: 6px;
+  padding: 6px 10px;
+  border-radius: var(--panel-radius);
+  font-family: var(--font-mono);
+  font-size: var(--font-size-sm);
+  font-weight: 600;
+  min-width: 180px;
+}
+
+.stage-node.running {
+  animation: node-pulse 1.5s infinite;
+}
+
+.stage-label {
+  flex: 1;
+}
+
+.stage-actions {
+  display: flex;
+  gap: 2px;
+  opacity: 0;
+  transition: opacity 0.15s;
+}
+
+.stage-node:hover .stage-actions {
+  opacity: 1;
+}
+
+.stage-btn {
+  background: rgba(0, 0, 0, 0.15);
+  border: none;
+  border-radius: 3px;
+  width: 20px;
+  height: 20px;
+  font-size: 11px;
+  cursor: pointer;
+  display: flex;
+  align-items: center;
+  justify-content: center;
+  color: inherit;
+}
+
+.stage-btn:hover {
+  background: rgba(0, 0, 0, 0.3);
+}
+
+@keyframes node-pulse {
+  0%, 100% { opacity: 1; }
+  50% { opacity: 0.7; }
+}
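A parent view consumes the two new events by switching the pipeline layout into an editor for the chosen stage. A hedged sketch of those handlers against the `PipelineState` fields from store-state.ts; the `'editor'`/`'config'` layout-mode values and the import path are assumptions, since only the event names and state fields appear in this PR:

```ts
// Sketch: handler bodies are illustrative, not the app's actual pipeline store.
import type { PipelineState } from '../types/store-state'

export function onOpenRegionEditor(state: PipelineState, stage: string): void {
  state.editor_stage = stage
  state.layout_mode = 'editor' // assumed layout-mode value
}

export function onOpenStageEditor(state: PipelineState, stage: string): void {
  state.editor_stage = stage
  state.layout_mode = 'config' // assumed layout-mode value
}
```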