phase cv 0
This commit is contained in:
@@ -26,6 +26,7 @@ class OverrideProfile:
|
||||
{
|
||||
"frame_extraction": {"fps": 1.0},
|
||||
"scene_filter": {"hamming_threshold": 12},
|
||||
"region_analysis": {"edge_canny_low": 30, "edge_canny_high": 120},
|
||||
"detection": {"confidence_threshold": 0.5},
|
||||
"ocr": {"languages": ["en", "es"], "min_confidence": 0.3},
|
||||
"resolver": {"fuzzy_threshold": 60},
|
||||
@@ -52,6 +53,9 @@ class OverrideProfile:
|
||||
def scene_filter_config(self):
    # Base profile's scene-filter config with any "scene_filter" overrides applied via _patch.
    return self._patch(self._base.scene_filter_config(), "scene_filter")
|
||||
|
||||
def region_analysis_config(self):
    # Base profile's region-analysis config with any "region_analysis" overrides applied via _patch.
    return self._patch(self._base.region_analysis_config(), "region_analysis")
|
||||
|
||||
def detection_config(self):
    # Base profile's detection config with any "detection" overrides applied via _patch.
    return self._patch(self._base.detection_config(), "detection")
|
||||
|
||||
@@ -130,3 +134,137 @@ def replay_from(
|
||||
emit.clear_run_context()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def replay_single_stage(
    job_id: str,
    stage: str,
    frame_refs: list[int] | None = None,
    config_overrides: dict | None = None,
    debug: bool = False,
) -> dict:
    """
    Replay one stage in isolation, on specific frames (or all frames from checkpoint).

    Fast path for interactive parameter tuning: only the target stage function
    runs — not the full pipeline tail — and its output is returned directly.

    When debug=True and stage is detect_edges, additional overlay data
    (Canny edges, Hough lines) is returned for visual feedback in the editor.

    For detect_edges: returns {"edge_regions_by_frame": {seq: [box, ...]}}.
    With debug=True, also returns
    {"debug": {seq: {edge_overlay_b64, lines_overlay_b64, ...}}}.

    Raises ValueError for an unknown stage, the first stage, a missing
    upstream checkpoint, or a stage with no single-stage implementation yet.
    """
    if stage not in NODES:
        raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}")

    position = NODES.index(stage)
    if position == 0:
        raise ValueError("Cannot replay the first stage — just run the full pipeline")
    upstream = NODES[position - 1]

    # The upstream stage's checkpoint supplies the input state for this stage.
    saved = list_checkpoints(job_id)
    if upstream not in saved:
        raise ValueError(
            f"No checkpoint for stage {upstream!r} (job {job_id}). "
            f"Available: {saved}"
        )

    logger.info("Single-stage replay: job %s, stage %s (loading checkpoint: %s, debug=%s)",
                job_id, stage, upstream, debug)

    state = load_checkpoint(job_id, upstream)

    # Rebuild the profile, layering any interactive overrides on top.
    from detect.profiles import get_profile
    profile = get_profile(state.get("profile_name", "soccer_broadcast"))
    if config_overrides:
        profile = OverrideProfile(profile, config_overrides)

    # Dispatch straight to the stage function, bypassing the graph.
    if stage != "detect_edges":
        raise ValueError(
            f"Single-stage replay not yet implemented for {stage!r}. "
            f"Use replay_from() for full pipeline replay."
        )
    return _replay_detect_edges(state, profile, frame_refs, job_id, debug)
|
||||
|
||||
|
||||
def _replay_detect_edges(
    state: dict,
    profile,
    frame_refs: list[int] | None,
    job_id: str,
    debug: bool,
) -> dict:
    """Re-run edge detection on checkpointed frames, optionally attaching debug overlays."""
    import os
    from detect.stages.edge_detector import detect_edge_regions

    config = profile.region_analysis_config()
    selected = state.get("filtered_frames", [])
    if frame_refs:
        wanted = set(frame_refs)
        selected = [frame for frame in selected if frame.sequence in wanted]

    inference_url = os.environ.get("INFERENCE_URL")

    # The plain detection pass always runs — it produces the boxes.
    regions = detect_edge_regions(
        frames=selected,
        config=config,
        inference_url=inference_url,
        job_id=job_id,
    )
    output = {"edge_regions_by_frame": regions}

    # Debug overlays: remote debug endpoint when an inference URL is set,
    # otherwise the local debug function.
    if debug and selected:
        overlays = {}
        if inference_url:
            from detect.inference import InferenceClient
            client = InferenceClient(base_url=inference_url, job_id=job_id)
            for frame in selected:
                dr = client.detect_edges_debug(
                    image=frame.image,
                    edge_canny_low=config.edge_canny_low,
                    edge_canny_high=config.edge_canny_high,
                    edge_hough_threshold=config.edge_hough_threshold,
                    edge_hough_min_length=config.edge_hough_min_length,
                    edge_hough_max_gap=config.edge_hough_max_gap,
                    edge_pair_max_distance=config.edge_pair_max_distance,
                    edge_pair_min_distance=config.edge_pair_min_distance,
                )
                overlays[frame.sequence] = {
                    "edge_overlay_b64": dr.edge_overlay_b64,
                    "lines_overlay_b64": dr.lines_overlay_b64,
                    "horizontal_count": dr.horizontal_count,
                    "pair_count": dr.pair_count,
                }
        else:
            # Local mode — import GPU module directly; it returns plain dicts.
            from detect.stages.edge_detector import _load_cv_edges
            debug_fn = _load_cv_edges().detect_edges_debug
            for frame in selected:
                dr = debug_fn(
                    frame.image,
                    canny_low=config.edge_canny_low,
                    canny_high=config.edge_canny_high,
                    hough_threshold=config.edge_hough_threshold,
                    hough_min_length=config.edge_hough_min_length,
                    hough_max_gap=config.edge_hough_max_gap,
                    pair_max_distance=config.edge_pair_max_distance,
                    pair_min_distance=config.edge_pair_min_distance,
                )
                overlays[frame.sequence] = {
                    "edge_overlay_b64": dr["edge_overlay_b64"],
                    "lines_overlay_b64": dr["lines_overlay_b64"],
                    "horizontal_count": dr["horizontal_count"],
                    "pair_count": dr["pair_count"],
                }
        output["debug"] = overlays

    return output
|
||||
|
||||
@@ -2,39 +2,19 @@
|
||||
Checkpoint storage — save/load stage state.
|
||||
|
||||
Binary data (frame images) → S3/MinIO via frames.py
|
||||
Structured data (boxes, detections, stats, config) → Postgres via Django ORM
|
||||
|
||||
Until the Django model is generated by modelgen, checkpoint data is stored
|
||||
as JSON in S3 as a fallback. Once DetectJob/StageCheckpoint models exist,
|
||||
this module switches to Postgres.
|
||||
Structured data (stage output, stats, config) → Postgres
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from .frames import save_frames, load_frames, BUCKET, CHECKPOINT_PREFIX
|
||||
from .frames import save_frames, load_frames, CHECKPOINT_PREFIX
|
||||
from .serializer import serialize_state, deserialize_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _has_db() -> bool:
|
||||
"""Check if Postgres is reachable."""
|
||||
try:
|
||||
from core.db.connection import get_session
|
||||
from sqlmodel import text
|
||||
with get_session() as session:
|
||||
session.exec(text("SELECT 1"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Save
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -45,34 +25,24 @@ def save_checkpoint(
|
||||
stage_index: int,
|
||||
state: dict,
|
||||
frames_manifest: dict[int, str] | None = None,
|
||||
is_scenario: bool = False,
|
||||
scenario_label: str = "",
|
||||
) -> str:
|
||||
"""
|
||||
Save a stage checkpoint.
|
||||
|
||||
Saves frame images to S3 (if not already saved), then persists
|
||||
structured state to Postgres (or S3 JSON fallback).
|
||||
structured state to Postgres.
|
||||
|
||||
Returns the checkpoint identifier (DB id or S3 key).
|
||||
Returns the checkpoint DB id.
|
||||
"""
|
||||
# Save frames to S3 if no manifest provided
|
||||
from core.db.detect import save_stage_checkpoint
|
||||
|
||||
if frames_manifest is None:
|
||||
all_frames = state.get("frames", [])
|
||||
frames_manifest = save_frames(job_id, all_frames)
|
||||
|
||||
checkpoint_data = serialize_state(state, frames_manifest)
|
||||
|
||||
if _has_db():
|
||||
checkpoint_id = _save_to_db(job_id, stage, stage_index, checkpoint_data)
|
||||
else:
|
||||
checkpoint_id = _save_to_s3(job_id, stage, checkpoint_data)
|
||||
|
||||
return checkpoint_id
|
||||
|
||||
|
||||
def _save_to_db(job_id: str, stage: str, stage_index: int, data: dict) -> str:
|
||||
"""Save checkpoint structured data to Postgres."""
|
||||
from core.db.detect import save_stage_checkpoint
|
||||
|
||||
frames_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/frames/"
|
||||
|
||||
checkpoint = save_stage_checkpoint(
|
||||
@@ -80,44 +50,24 @@ def _save_to_db(job_id: str, stage: str, stage_index: int, data: dict) -> str:
|
||||
stage=stage,
|
||||
stage_index=stage_index,
|
||||
frames_prefix=frames_prefix,
|
||||
frames_manifest=data.get("frames_manifest", {}),
|
||||
frames_meta=data.get("frames_meta", []),
|
||||
filtered_frame_sequences=data.get("filtered_frame_sequences", []),
|
||||
boxes_by_frame=data.get("boxes_by_frame", {}),
|
||||
text_candidates=data.get("text_candidates", []),
|
||||
unresolved_candidates=data.get("unresolved_candidates", []),
|
||||
detections=data.get("detections", []),
|
||||
stats=data.get("stats", {}),
|
||||
config_snapshot=data.get("config_overrides", {}),
|
||||
config_overrides=data.get("config_overrides", {}),
|
||||
video_path=data.get("video_path", ""),
|
||||
profile_name=data.get("profile_name", ""),
|
||||
frames_manifest=checkpoint_data.get("frames_manifest", {}),
|
||||
frames_meta=checkpoint_data.get("frames_meta", []),
|
||||
filtered_frame_sequences=checkpoint_data.get("filtered_frame_sequences", []),
|
||||
stage_output_key=checkpoint_data.get("stage_output_key", ""),
|
||||
stats=checkpoint_data.get("stats", {}),
|
||||
config_snapshot=checkpoint_data.get("config_overrides", {}),
|
||||
config_overrides=checkpoint_data.get("config_overrides", {}),
|
||||
video_path=checkpoint_data.get("video_path", ""),
|
||||
profile_name=checkpoint_data.get("profile_name", ""),
|
||||
is_scenario=is_scenario,
|
||||
scenario_label=scenario_label,
|
||||
)
|
||||
|
||||
logger.info("Checkpoint saved to DB: %s/%s (id=%s)", job_id, stage, checkpoint.id)
|
||||
logger.info("Checkpoint saved: %s/%s (id=%s, scenario=%s)",
|
||||
job_id, stage, checkpoint.id, is_scenario)
|
||||
return str(checkpoint.id)
|
||||
|
||||
|
||||
def _save_to_s3(job_id: str, stage: str, data: dict) -> str:
    """Fallback: save checkpoint as JSON to S3 (before modelgen generates DB models)."""
    from core.storage.s3 import upload_file

    key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"
    payload = json.dumps(data, default=str)

    # upload_file wants a path on disk, so stage the JSON in a temp file first.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
        tmp.write(payload)
        tmp_path = tmp.name
    try:
        upload_file(tmp_path, BUCKET, key)
    finally:
        os.unlink(tmp_path)

    logger.info("Checkpoint saved to S3: s3://%s/%s", BUCKET, key)
    return key
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -125,30 +75,12 @@ def _save_to_s3(job_id: str, stage: str, data: dict) -> str:
|
||||
def load_checkpoint(job_id: str, stage: str) -> dict:
    """
    Load a stage checkpoint and reconstitute full DetectState.

    Tries Postgres first, falls back to S3 JSON.
    """
    if _has_db():
        data = _load_from_db(job_id, stage)
    else:
        data = _load_from_s3(job_id, stage)

    # JSON round-trips dict keys as strings; the frame manifest is keyed by
    # integer frame sequence, so coerce the keys back.
    raw_manifest = data.get("frames_manifest", {})
    manifest = {int(k): v for k, v in raw_manifest.items()}
    frame_metadata = data.get("frames_meta", [])
    frames = load_frames(manifest, frame_metadata)

    state = deserialize_state(data, frames)

    logger.info("Checkpoint loaded: %s/%s (%d frames)", job_id, stage, len(frames))
    return state
|
||||
|
||||
|
||||
def _load_from_db(job_id: str, stage: str) -> dict:
|
||||
"""Load checkpoint data from Postgres via core/db."""
|
||||
from core.db.detect import get_stage_checkpoint
|
||||
|
||||
checkpoint = get_stage_checkpoint(job_id, stage)
|
||||
if not checkpoint:
|
||||
raise ValueError(f"No checkpoint for {job_id}/{stage}")
|
||||
|
||||
data = {
|
||||
"job_id": str(checkpoint.job_id),
|
||||
@@ -158,28 +90,20 @@ def _load_from_db(job_id: str, stage: str) -> dict:
|
||||
"frames_manifest": checkpoint.frames_manifest,
|
||||
"frames_meta": checkpoint.frames_meta,
|
||||
"filtered_frame_sequences": checkpoint.filtered_frame_sequences,
|
||||
"boxes_by_frame": checkpoint.boxes_by_frame,
|
||||
"text_candidates": checkpoint.text_candidates,
|
||||
"unresolved_candidates": checkpoint.unresolved_candidates,
|
||||
"detections": checkpoint.detections,
|
||||
"stage_output_key": checkpoint.stage_output_key,
|
||||
"stats": checkpoint.stats,
|
||||
}
|
||||
return data
|
||||
|
||||
raw_manifest = data.get("frames_manifest", {})
|
||||
manifest = {int(k): v for k, v in raw_manifest.items()}
|
||||
frame_metadata = data.get("frames_meta", [])
|
||||
frames = load_frames(manifest, frame_metadata)
|
||||
|
||||
def _load_from_s3(job_id: str, stage: str) -> dict:
|
||||
"""Fallback: load checkpoint JSON from S3."""
|
||||
from core.storage.s3 import download_to_temp
|
||||
state = deserialize_state(data, frames)
|
||||
|
||||
key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"
|
||||
tmp_path = download_to_temp(BUCKET, key)
|
||||
try:
|
||||
with open(tmp_path) as f:
|
||||
data = json.load(f)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
return data
|
||||
logger.info("Checkpoint loaded: %s/%s (%d frames, scenario=%s)",
|
||||
job_id, stage, len(frames), checkpoint.is_scenario)
|
||||
return state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -188,25 +112,5 @@ def _load_from_s3(job_id: str, stage: str) -> dict:
|
||||
|
||||
def list_checkpoints(job_id: str) -> list[str]:
    """List available checkpoint stages for a job."""
    # Pick the backend that matches where checkpoints were saved.
    backend = _list_from_db if _has_db() else _list_from_s3
    return backend(job_id)
|
||||
|
||||
|
||||
def _list_from_db(job_id: str) -> list[str]:
    """Stage names for *job_id* from the Postgres checkpoint table."""
    from core.db.detect import list_stage_checkpoints

    return list_stage_checkpoints(job_id)
|
||||
|
||||
|
||||
def _list_from_s3(job_id: str) -> list[str]:
    """Stage names for *job_id* derived from S3 checkpoint object keys.

    Checkpoints live at {CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json,
    so each object's filename stem is the stage name.
    """
    from core.storage.s3 import list_objects

    prefix = f"{CHECKPOINT_PREFIX}/{job_id}/stages/"
    objects = list_objects(BUCKET, prefix)

    # Comprehension instead of the manual append loop.
    return [Path(obj["key"]).stem for obj in objects]
|
||||
|
||||
@@ -17,6 +17,7 @@ from detect.profiles import SoccerBroadcastProfile
|
||||
from detect.state import DetectState
|
||||
from detect.stages.frame_extractor import extract_frames
|
||||
from detect.stages.scene_filter import scene_filter
|
||||
from detect.stages.edge_detector import detect_edge_regions
|
||||
from detect.stages.yolo_detector import detect_objects
|
||||
from detect.stages.preprocess import preprocess_regions
|
||||
from detect.stages.ocr_stage import run_ocr
|
||||
@@ -31,6 +32,7 @@ INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
|
||||
NODES = [
|
||||
"extract_frames",
|
||||
"filter_scenes",
|
||||
"detect_edges",
|
||||
"detect_objects",
|
||||
"preprocess",
|
||||
"run_ocr",
|
||||
@@ -119,6 +121,28 @@ def node_filter_scenes(state: DetectState) -> dict:
|
||||
return {"filtered_frames": kept, "stats": stats}
|
||||
|
||||
|
||||
def node_detect_edges(state: DetectState) -> dict:
    """Graph node: run edge detection on the scene-filtered frames.

    Reads filtered_frames from state; returns edge_regions_by_frame plus
    updated stats (cv_regions_detected).
    """
    _emit_transition(state, "detect_edges", "running")

    with trace_node(state, "detect_edges") as span:
        profile = _get_profile(state)
        config = profile.region_analysis_config()
        frames = state.get("filtered_frames", [])
        job_id = state.get("job_id")

        regions = detect_edge_regions(
            frames, config, inference_url=INFERENCE_URL, job_id=job_id,
        )
        # Total region count across all frames, for tracing and funnel stats.
        total = sum(len(r) for r in regions.values())
        span.set_output({"frames": len(frames), "edge_regions": total})

        stats = state.get("stats", PipelineStats())
        stats.cv_regions_detected = total

    _emit_transition(state, "detect_edges", "done")
    return {"edge_regions_by_frame": regions, "stats": stats}
|
||||
|
||||
|
||||
def node_detect_objects(state: DetectState) -> dict:
|
||||
_emit_transition(state, "detect_objects", "running")
|
||||
|
||||
@@ -359,6 +383,7 @@ def _checkpointing_node(node_name: str, node_fn):
|
||||
NODE_FUNCTIONS = [
|
||||
("extract_frames", node_extract_frames),
|
||||
("filter_scenes", node_filter_scenes),
|
||||
("detect_edges", node_detect_edges),
|
||||
("detect_objects", node_detect_objects),
|
||||
("preprocess", node_preprocess),
|
||||
("run_ocr", node_run_ocr),
|
||||
|
||||
@@ -16,7 +16,7 @@ import numpy as np
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from .types import DetectResult, OCRResult, ServerStatus, VLMResult
|
||||
from .types import DetectResult, OCRResult, RegionDebugResult, RegionResult, ServerStatus, VLMResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -145,6 +145,92 @@ class InferenceClient:
|
||||
reasoning=data.get("reasoning", ""),
|
||||
)
|
||||
|
||||
def detect_edges(
    self,
    image: np.ndarray,
    edge_canny_low: int = 50,
    edge_canny_high: int = 150,
    edge_hough_threshold: int = 80,
    edge_hough_min_length: int = 100,
    edge_hough_max_gap: int = 10,
    edge_pair_max_distance: int = 200,
    edge_pair_min_distance: int = 15,
) -> list[RegionResult]:
    """Run edge detection on an image via the server's /detect_edges endpoint.

    Args:
        image: Frame pixels; base64-encoded before sending.
        edge_*: Canny / Hough / line-pairing thresholds forwarded verbatim.

    Returns:
        One RegionResult per region in the server response.

    Raises:
        requests.HTTPError: on a non-2xx response (via raise_for_status).
    """
    payload = {
        "image": _encode_image(image),
        "edge_canny_low": edge_canny_low,
        "edge_canny_high": edge_canny_high,
        "edge_hough_threshold": edge_hough_threshold,
        "edge_hough_min_length": edge_hough_min_length,
        "edge_hough_max_gap": edge_hough_max_gap,
        "edge_pair_max_distance": edge_pair_max_distance,
        "edge_pair_min_distance": edge_pair_min_distance,
    }

    resp = self.session.post(
        f"{self.base_url}/detect_edges",
        json=payload,
        timeout=self.timeout,
    )
    resp.raise_for_status()

    # Comprehension instead of the manual append loop.
    return [
        RegionResult(
            x=r["x"], y=r["y"], w=r["w"], h=r["h"],
            confidence=r["confidence"], label=r["label"],
        )
        for r in resp.json().get("regions", [])
    ]
|
||||
|
||||
def detect_edges_debug(
    self,
    image: np.ndarray,
    edge_canny_low: int = 50,
    edge_canny_high: int = 150,
    edge_hough_threshold: int = 80,
    edge_hough_min_length: int = 100,
    edge_hough_max_gap: int = 10,
    edge_pair_max_distance: int = 200,
    edge_pair_min_distance: int = 15,
) -> RegionDebugResult:
    """Run edge detection with debug overlays via /detect_edges/debug.

    Same parameters as detect_edges(); additionally returns base64 overlay
    images and line counts wrapped in a RegionDebugResult.

    Raises:
        requests.HTTPError: on a non-2xx response (via raise_for_status).
    """
    payload = {
        "image": _encode_image(image),
        "edge_canny_low": edge_canny_low,
        "edge_canny_high": edge_canny_high,
        "edge_hough_threshold": edge_hough_threshold,
        "edge_hough_min_length": edge_hough_min_length,
        "edge_hough_max_gap": edge_hough_max_gap,
        "edge_pair_max_distance": edge_pair_max_distance,
        "edge_pair_min_distance": edge_pair_min_distance,
    }

    resp = self.session.post(
        f"{self.base_url}/detect_edges/debug",
        json=payload,
        timeout=self.timeout,
    )
    resp.raise_for_status()

    data = resp.json()
    # Comprehension instead of the manual append loop.
    regions = [
        RegionResult(
            x=r["x"], y=r["y"], w=r["w"], h=r["h"],
            confidence=r["confidence"], label=r["label"],
        )
        for r in data.get("regions", [])
    ]

    return RegionDebugResult(
        regions=regions,
        edge_overlay_b64=data.get("edge_overlay_b64", ""),
        lines_overlay_b64=data.get("lines_overlay_b64", ""),
        horizontal_count=data.get("horizontal_count", 0),
        pair_count=data.get("pair_count", 0),
    )
|
||||
|
||||
def load_model(self, model: str, quantization: str = "fp16") -> None:
|
||||
"""Request the server to load a model into VRAM."""
|
||||
self.session.post(
|
||||
|
||||
@@ -38,6 +38,27 @@ class VLMResult:
|
||||
reasoning: str
|
||||
|
||||
|
||||
@dataclass
class RegionResult:
    """A candidate region from CV analysis."""
    x: int  # bounding-box x coordinate (px)
    y: int  # bounding-box y coordinate (px)
    w: int  # bounding-box width (px)
    h: int  # bounding-box height (px)
    confidence: float  # detector confidence score
    label: str  # region label assigned by the CV stage
|
||||
|
||||
|
||||
@dataclass
class RegionDebugResult:
    """CV region analysis with debug overlays."""
    regions: list[RegionResult] = field(default_factory=list)  # detected candidate regions
    edge_overlay_b64: str = ""  # base64 image overlay of detected edges
    lines_overlay_b64: str = ""  # base64 image overlay of detected lines
    horizontal_count: int = 0  # number of horizontal lines found
    pair_count: int = 0  # number of line pairs found
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""Info about a loaded model."""
|
||||
|
||||
@@ -9,6 +9,19 @@ from .base import (
|
||||
)
|
||||
from .soccer import SoccerBroadcastProfile
|
||||
|
||||
# Registry of content-type profile classes, keyed by the name get_profile() accepts.
_PROFILES: dict[str, type] = {
    "soccer_broadcast": SoccerBroadcastProfile,
}
|
||||
|
||||
|
||||
def get_profile(name: str) -> ContentTypeProfile:
    """Get a profile instance by name.

    Raises ValueError when the name is not in the profile registry.
    """
    cls = _PROFILES.get(name)
    if cls is not None:
        return cls()
    raise ValueError(f"Unknown profile: {name!r}. Available: {list(_PROFILES)}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ContentTypeProfile",
|
||||
"CropContext",
|
||||
@@ -18,4 +31,5 @@ __all__ = [
|
||||
"ResolverConfig",
|
||||
"SceneFilterConfig",
|
||||
"SoccerBroadcastProfile",
|
||||
"get_profile",
|
||||
]
|
||||
|
||||
@@ -44,6 +44,19 @@ class ResolverConfig:
|
||||
fuzzy_threshold: int = 75
|
||||
|
||||
|
||||
@dataclass
class RegionAnalysisConfig:
    """Tunables for the CV region-analysis stage."""
    enabled: bool = True  # master switch; the stage is skipped when False
    # Edge detection (Canny + HoughLinesP)
    edge_canny_low: int = 50  # Canny hysteresis low threshold
    edge_canny_high: int = 150  # Canny hysteresis high threshold
    edge_hough_threshold: int = 80  # HoughLinesP accumulator threshold
    edge_hough_min_length: int = 100  # minimum line length (px)
    edge_hough_max_gap: int = 10  # maximum gap within a single line (px)
    edge_pair_max_distance: int = 200  # max spacing between a paired line couple (px)
    edge_pair_min_distance: int = 15  # min spacing between a paired line couple (px)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CropContext:
|
||||
image: bytes
|
||||
@@ -56,6 +69,7 @@ class ContentTypeProfile(Protocol):
|
||||
|
||||
def frame_extraction_config(self) -> FrameExtractionConfig: ...
|
||||
def scene_filter_config(self) -> SceneFilterConfig: ...
|
||||
def region_analysis_config(self) -> RegionAnalysisConfig: ...
|
||||
def detection_config(self) -> DetectionConfig: ...
|
||||
def ocr_config(self) -> OCRConfig: ...
|
||||
def resolver_config(self) -> ResolverConfig: ...
|
||||
|
||||
@@ -9,6 +9,7 @@ from .base import (
|
||||
DetectionConfig,
|
||||
FrameExtractionConfig,
|
||||
OCRConfig,
|
||||
RegionAnalysisConfig,
|
||||
ResolverConfig,
|
||||
SceneFilterConfig,
|
||||
)
|
||||
@@ -23,6 +24,17 @@ class SoccerBroadcastProfile:
|
||||
def scene_filter_config(self) -> SceneFilterConfig:
    # Scene-dedup settings for this profile: hamming threshold 8, filtering on.
    return SceneFilterConfig(hamming_threshold=8, enabled=True)
|
||||
|
||||
def region_analysis_config(self) -> RegionAnalysisConfig:
    # Values are spelled out explicitly; they mirror RegionAnalysisConfig's defaults.
    return RegionAnalysisConfig(
        edge_canny_low=50,
        edge_canny_high=150,
        edge_hough_threshold=80,
        edge_hough_min_length=100,
        edge_hough_max_gap=10,
        edge_pair_max_distance=200,
        edge_pair_min_distance=15,
    )
|
||||
|
||||
def detection_config(self) -> DetectionConfig:
|
||||
return DetectionConfig(
|
||||
model_name="yolov8n.pt",
|
||||
|
||||
@@ -34,6 +34,7 @@ class BoundingBoxEvent(BaseModel):
|
||||
label: str
|
||||
resolved_brand: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
stage: Optional[str] = None
|
||||
|
||||
class BrandSummary(BaseModel):
|
||||
"""Per-brand stats in the final report."""
|
||||
@@ -54,6 +55,7 @@ class StatsUpdate(BaseModel):
|
||||
"""Funnel statistics snapshot. SSE event: stats_update"""
|
||||
frames_extracted: int = 0
|
||||
frames_after_scene_filter: int = 0
|
||||
cv_regions_detected: int = 0
|
||||
regions_detected: int = 0
|
||||
regions_resolved_by_ocr: int = 0
|
||||
regions_escalated_to_local_vlm: int = 0
|
||||
|
||||
174
detect/stages/edge_detector.py
Normal file
174
detect/stages/edge_detector.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
Stage — Edge Detection
|
||||
|
||||
Canny + HoughLinesP to find horizontal line pairs that bound
|
||||
advertising hoardings. Pure OpenCV, no ML models.
|
||||
|
||||
Two modes:
|
||||
- Remote: calls GPU inference server over HTTP
|
||||
- Local: imports cv2 directly (OpenCV on same machine)
|
||||
|
||||
Emits frame_update events with bounding boxes for the frame viewer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from detect import emit
|
||||
from detect.models import BoundingBox, Frame
|
||||
from detect.profiles.base import RegionAnalysisConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _frame_to_b64(frame: Frame) -> str:
    """Encode frame as base64 JPEG (quality 70) for SSE frame_update events."""
    buffer = io.BytesIO()
    Image.fromarray(frame.image).save(buffer, format="JPEG", quality=70)
    return base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
|
||||
def _detect_remote(
    frame: Frame,
    config: RegionAnalysisConfig,
    inference_url: str,
    job_id: str = "",
    log_level: str = "INFO",
) -> list[BoundingBox]:
    """Call the inference server over HTTP.

    Args:
        frame: Frame whose image array is sent to the server.
        config: Edge-detection tunables forwarded as request parameters.
        inference_url: Base URL of the GPU inference server.
        job_id: Job identifier propagated for server-side logging.
        log_level: Log level propagated to the client.

    Returns:
        One BoundingBox per region returned by the server.
    """
    from detect.inference import InferenceClient

    client = InferenceClient(
        base_url=inference_url, job_id=job_id, log_level=log_level,
    )
    results = client.detect_edges(
        image=frame.image,
        edge_canny_low=config.edge_canny_low,
        edge_canny_high=config.edge_canny_high,
        edge_hough_threshold=config.edge_hough_threshold,
        edge_hough_min_length=config.edge_hough_min_length,
        edge_hough_max_gap=config.edge_hough_max_gap,
        edge_pair_max_distance=config.edge_pair_max_distance,
        edge_pair_min_distance=config.edge_pair_min_distance,
    )
    # Comprehension instead of the manual append loop.
    return [
        BoundingBox(
            x=r.x, y=r.y, w=r.w, h=r.h,
            confidence=r.confidence, label=r.label,
        )
        for r in results
    ]
|
||||
|
||||
|
||||
_cv_edges_mod = None
|
||||
|
||||
|
||||
def _load_cv_edges():
    """Load edges module directly — gpu/models/__init__.py has GPU-container-only imports.

    Loads gpu/models/cv/edges.py by file path via importlib (so the package
    __init__ never runs) and caches it in the module-level _cv_edges_mod.
    """
    global _cv_edges_mod
    if _cv_edges_mod is None:
        import importlib.util
        from pathlib import Path

        # NOTE(review): path is CWD-relative — assumes the process runs from
        # the repo root; confirm against deployment layout.
        spec = importlib.util.spec_from_file_location("cv_edges", Path("gpu/models/cv/edges.py"))
        _cv_edges_mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(_cv_edges_mod)
    return _cv_edges_mod
|
||||
|
||||
|
||||
def _detect_local(frame: Frame, config: RegionAnalysisConfig) -> list[BoundingBox]:
    """Run edge detection in-process (requires opencv-python).

    Args:
        frame: Frame whose image array is analysed.
        config: Edge-detection tunables (Canny / Hough / pairing thresholds).

    Returns:
        One BoundingBox per region dict returned by the CV module.
    """
    detect_edges_fn = _load_cv_edges().detect_edges

    edge_results = detect_edges_fn(
        frame.image,
        canny_low=config.edge_canny_low,
        canny_high=config.edge_canny_high,
        hough_threshold=config.edge_hough_threshold,
        hough_min_length=config.edge_hough_min_length,
        hough_max_gap=config.edge_hough_max_gap,
        pair_max_distance=config.edge_pair_max_distance,
        pair_min_distance=config.edge_pair_min_distance,
    )

    # Comprehension instead of the manual append loop.
    return [
        BoundingBox(
            x=r["x"], y=r["y"], w=r["w"], h=r["h"],
            confidence=r["confidence"], label=r["label"],
        )
        for r in edge_results
    ]
|
||||
|
||||
|
||||
def detect_edge_regions(
    frames: list[Frame],
    config: RegionAnalysisConfig,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> dict[int, list[BoundingBox]]:
    """
    Run edge detection on all frames.

    Dispatches per frame to the remote inference server when *inference_url*
    is set, otherwise to the in-process OpenCV path. Emits log, frame_update,
    and stats events along the way.

    Returns a dict mapping frame sequence → list of bounding boxes
    (empty dict when config.enabled is False).
    """
    if not config.enabled:
        emit.log(job_id, "EdgeDetection", "INFO", "Edge detection disabled, skipping")
        return {}

    mode = "remote" if inference_url else "local"
    emit.log(job_id, "EdgeDetection", "INFO",
             f"Detecting edges in {len(frames)} frames (mode={mode})")

    all_boxes: dict[int, list[BoundingBox]] = {}
    total_regions = 0

    for i, frame in enumerate(frames):
        # Time each frame's analysis for the per-frame debug log line.
        t0 = time.monotonic()
        if inference_url:
            # Propagate the current run's log level to the remote client.
            from detect.emit import _run_log_level
            boxes = _detect_remote(
                frame, config, inference_url,
                job_id=job_id or "", log_level=_run_log_level,
            )
        else:
            boxes = _detect_local(frame, config)
        analysis_ms = (time.monotonic() - t0) * 1000

        all_boxes[frame.sequence] = boxes
        total_regions += len(boxes)

        emit.log(job_id, "EdgeDetection", "DEBUG",
                 f"Frame {frame.sequence}: {len(boxes)} regions in {analysis_ms:.0f}ms"
                 + (f" [{', '.join(b.label for b in boxes)}]" if boxes else ""))

        # Push an annotated frame to the viewer only when boxes were found
        # and a job context exists.
        if boxes and job_id:
            box_dicts = [
                {
                    "x": b.x, "y": b.y, "w": b.w, "h": b.h,
                    "confidence": b.confidence, "label": b.label,
                    "stage": "detect_edges",
                }
                for b in boxes
            ]
            emit.frame_update(
                job_id,
                frame_ref=frame.sequence,
                timestamp=frame.timestamp,
                jpeg_b64=_frame_to_b64(frame),
                boxes=box_dicts,
            )

    emit.log(job_id, "EdgeDetection", "INFO",
             f"Found {total_regions} edge regions across {len(frames)} frames")
    emit.stats(job_id, cv_regions_detected=total_regions)

    return all_boxes
|
||||
@@ -3,6 +3,7 @@ Stage registry — registers all built-in stages.
|
||||
|
||||
Split by category:
|
||||
preprocessing.py — extract_frames, filter_scenes
|
||||
cv_analysis.py — detect_edges (+ future: detect_contours, detect_color, merge_regions)
|
||||
detection.py — detect_objects, run_ocr
|
||||
resolution.py — match_brands
|
||||
escalation.py — escalate_vlm, escalate_cloud
|
||||
@@ -11,6 +12,7 @@ Split by category:
|
||||
"""
|
||||
|
||||
from . import preprocessing
|
||||
from . import cv_analysis
|
||||
from . import detection
|
||||
from . import resolution
|
||||
from . import escalation
|
||||
@@ -19,6 +21,7 @@ from . import output
|
||||
|
||||
def register_all():
|
||||
preprocessing.register()
|
||||
cv_analysis.register()
|
||||
detection.register()
|
||||
resolution.register()
|
||||
escalation.register()
|
||||
|
||||
45
detect/stages/registry/cv_analysis.py
Normal file
45
detect/stages/registry/cv_analysis.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Registration for CV analysis stages: edge detection."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||
from ._serializers import serialize_dataclass_list, deserialize_bounding_box
|
||||
|
||||
|
||||
def _ser_regions(state: dict, job_id: str) -> dict:
    """Serialize edge_regions_by_frame: int frame keys → str, box dataclasses → dicts."""
    by_frame = state.get("edge_regions_by_frame", {})
    return {
        "edge_regions_by_frame": {
            str(seq): serialize_dataclass_list(boxes) for seq, boxes in by_frame.items()
        }
    }
|
||||
|
||||
|
||||
def _deser_regions(data: dict, job_id: str) -> dict:
    """Inverse of _ser_regions: str frame keys → int, box dicts → BoundingBox objects."""
    raw = data.get("edge_regions_by_frame", {})
    return {
        "edge_regions_by_frame": {
            int(seq_str): [deserialize_bounding_box(b) for b in box_dicts]
            for seq_str, box_dicts in raw.items()
        }
    }
|
||||
|
||||
|
||||
def register():
    """Register the edge-detection stage definition with the stage registry."""
    edge_detection = StageDefinition(
        name="detect_edges",
        label="Edge Detection",
        description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
        category="cv_analysis",
        # What this stage reads from / writes to pipeline state.
        io=StageIO(
            reads=["filtered_frames"],
            writes=["edge_regions_by_frame"],
        ),
        # Editor-exposed tunables; min/max bound the UI controls.
        config_fields=[
            StageConfigField("enabled", "bool", True, "Enable region analysis"),
            StageConfigField("edge_canny_low", "int", 50, "Canny low threshold", min=0, max=255),
            StageConfigField("edge_canny_high", "int", 150, "Canny high threshold", min=0, max=255),
            StageConfigField("edge_hough_threshold", "int", 80, "Hough accumulator threshold", min=1, max=500),
            StageConfigField("edge_hough_min_length", "int", 100, "Min line length (px)", min=10, max=2000),
            StageConfigField("edge_hough_max_gap", "int", 10, "Max line gap (px)", min=1, max=100),
            StageConfigField("edge_pair_max_distance", "int", 200, "Max distance between line pair (px)", min=10, max=500),
            StageConfigField("edge_pair_min_distance", "int", 15, "Min distance between line pair (px)", min=5, max=200),
        ],
        # Checkpoint (de)serialization hooks for this stage's output.
        serialize_fn=_ser_regions,
        deserialize_fn=_deser_regions,
    )
    register_stage(edge_detection)
|
||||
@@ -22,6 +22,7 @@ class DetectState(TypedDict, total=False):
|
||||
# Stage outputs
|
||||
frames: list[Frame]
|
||||
filtered_frames: list[Frame]
|
||||
edge_regions_by_frame: dict[int, list[BoundingBox]]
|
||||
boxes_by_frame: dict[int, list[BoundingBox]]
|
||||
preprocessed_crops: dict # "{frame_seq}_{box_idx}" → np.ndarray
|
||||
text_candidates: list[TextCandidate]
|
||||
|
||||
Reference in New Issue
Block a user