phase 7
This commit is contained in:
@@ -18,6 +18,8 @@ from detect.state import DetectState
|
|||||||
from detect.stages.frame_extractor import extract_frames
|
from detect.stages.frame_extractor import extract_frames
|
||||||
from detect.stages.scene_filter import scene_filter
|
from detect.stages.scene_filter import scene_filter
|
||||||
from detect.stages.yolo_detector import detect_objects
|
from detect.stages.yolo_detector import detect_objects
|
||||||
|
from detect.stages.ocr_stage import run_ocr
|
||||||
|
from detect.stages.brand_resolver import resolve_brands
|
||||||
|
|
||||||
INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
|
INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
|
||||||
|
|
||||||
@@ -101,23 +103,43 @@ def node_detect_objects(state: DetectState) -> dict:
|
|||||||
stats.regions_detected = sum(len(boxes) for boxes in all_boxes.values())
|
stats.regions_detected = sum(len(boxes) for boxes in all_boxes.values())
|
||||||
|
|
||||||
_emit_transition(state, "detect_objects", "done")
|
_emit_transition(state, "detect_objects", "done")
|
||||||
return {"stats": stats}
|
return {"boxes_by_frame": all_boxes, "stats": stats}
|
||||||
|
|
||||||
|
|
||||||
def node_run_ocr(state: DetectState) -> dict:
|
def node_run_ocr(state: DetectState) -> dict:
|
||||||
_emit_transition(state, "run_ocr", "running")
|
_emit_transition(state, "run_ocr", "running")
|
||||||
|
|
||||||
|
profile = _get_profile(state)
|
||||||
|
config = profile.ocr_config()
|
||||||
|
frames = state.get("filtered_frames", [])
|
||||||
|
boxes = state.get("boxes_by_frame", {})
|
||||||
job_id = state.get("job_id")
|
job_id = state.get("job_id")
|
||||||
emit.log(job_id, "OCRStage", "INFO", "Stub: OCR not yet implemented")
|
|
||||||
|
candidates = run_ocr(frames, boxes, config, inference_url=INFERENCE_URL, job_id=job_id)
|
||||||
|
|
||||||
|
stats = state.get("stats", PipelineStats())
|
||||||
|
stats.regions_resolved_by_ocr = len(candidates)
|
||||||
|
|
||||||
_emit_transition(state, "run_ocr", "done")
|
_emit_transition(state, "run_ocr", "done")
|
||||||
return {}
|
return {"text_candidates": candidates, "stats": stats}
|
||||||
|
|
||||||
|
|
||||||
def node_match_brands(state: DetectState) -> dict:
|
def node_match_brands(state: DetectState) -> dict:
|
||||||
_emit_transition(state, "match_brands", "running")
|
_emit_transition(state, "match_brands", "running")
|
||||||
|
|
||||||
|
profile = _get_profile(state)
|
||||||
|
dictionary = profile.brand_dictionary()
|
||||||
|
resolver_config = profile.resolver_config()
|
||||||
|
candidates = state.get("text_candidates", [])
|
||||||
job_id = state.get("job_id")
|
job_id = state.get("job_id")
|
||||||
emit.log(job_id, "BrandResolver", "INFO", "Stub: brand matching not yet implemented")
|
|
||||||
|
matched, unresolved = resolve_brands(
|
||||||
|
candidates, dictionary, resolver_config,
|
||||||
|
content_type=profile.name, job_id=job_id,
|
||||||
|
)
|
||||||
|
|
||||||
_emit_transition(state, "match_brands", "done")
|
_emit_transition(state, "match_brands", "done")
|
||||||
return {"detections": []}
|
return {"detections": matched, "unresolved_candidates": unresolved}
|
||||||
|
|
||||||
|
|
||||||
def node_escalate_vlm(state: DetectState) -> dict:
|
def node_escalate_vlm(state: DetectState) -> dict:
|
||||||
|
|||||||
121
detect/stages/brand_resolver.py
Normal file
121
detect/stages/brand_resolver.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""
|
||||||
|
Stage 5 — Brand Resolver
|
||||||
|
|
||||||
|
Matches OCR text against the profile's brand dictionary.
|
||||||
|
Uses exact matching first, then fuzzy matching (rapidfuzz) as fallback.
|
||||||
|
Emits detection events for confirmed brands.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from rapidfuzz import fuzz
|
||||||
|
|
||||||
|
from detect import emit
|
||||||
|
from detect.models import BrandDetection, TextCandidate
|
||||||
|
from detect.profiles.base import BrandDictionary, ResolverConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize(text: str) -> str:
    """Return *text* lower-cased with surrounding whitespace removed.

    This is the comparison key used for both OCR text and brand names.
    """
    trimmed = text.strip()
    return trimmed.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _exact_match(text: str, dictionary: BrandDictionary) -> str | None:
    """Return the canonical brand whose name or alias equals *text*.

    Both sides are compared after ``_normalize``. Brands are visited in
    dictionary order and the canonical name is tested before its aliases;
    ``None`` is returned when nothing matches.
    """
    wanted = _normalize(text)
    for canonical, aliases in dictionary.brands.items():
        # ``or`` short-circuits, so aliases are only scanned when the
        # canonical name itself did not match.
        if wanted == _normalize(canonical) or any(
            wanted == _normalize(alias) for alias in aliases
        ):
            return canonical
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _fuzzy_match(text: str, dictionary: BrandDictionary, threshold: float) -> tuple[str | None, float]:
    """Find the brand whose name or alias is most similar to *text*.

    Similarity is ``fuzz.ratio`` computed on normalized strings. A name only
    qualifies when its score is at or above *threshold*; among qualifying
    names the highest score wins, and on ties the first one encountered is
    kept.

    Returns:
        ``(canonical_brand, score)`` for the best qualifying match, or
        ``(None, 0)`` when no name reaches the threshold.
    """
    normalized = _normalize(text)
    best_brand = None
    best_score = 0

    for canonical, aliases in dictionary.brands.items():
        # Unpack rather than ``[canonical] + aliases`` so alias collections
        # that are tuples or sets work here as they do in ``_exact_match``.
        for name in (canonical, *aliases):
            score = fuzz.ratio(normalized, _normalize(name))
            if score >= threshold and score > best_score:
                best_score = score
                best_brand = canonical

    return best_brand, best_score
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_brands(
    candidates: list[TextCandidate],
    dictionary: BrandDictionary,
    config: ResolverConfig,
    content_type: str = "",
    job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
    """
    Match text candidates against the brand dictionary.

    Each candidate is tried with an exact match first and, failing that, a
    fuzzy match using ``config.fuzzy_threshold``. Every match is recorded as
    a ``BrandDetection`` and also emitted as a detection event.

    Returns:
        - matched: list of BrandDetection for confirmed brands
        - unresolved: list of TextCandidate that couldn't be matched
    """
    start_message = (
        f"Matching {len(candidates)} candidates against "
        f"{len(dictionary.brands)} brands (fuzzy_threshold={config.fuzzy_threshold})"
    )
    emit.log(job_id, "BrandResolver", "INFO", start_message)

    matched: list[BrandDetection] = []
    unresolved: list[TextCandidate] = []
    exact_count = 0
    fuzzy_count = 0
    # Exact and fuzzy hits are both reported with this source label.
    source = "ocr"

    for cand in candidates:
        brand = _exact_match(cand.text, dictionary)
        if brand:
            exact_count += 1
        else:
            # The similarity score is not used beyond the threshold check.
            brand, _score = _fuzzy_match(cand.text, dictionary, config.fuzzy_threshold)
            if brand:
                fuzzy_count += 1

        if not brand:
            unresolved.append(cand)
            continue

        frame = cand.frame
        matched.append(
            BrandDetection(
                brand=brand,
                timestamp=frame.timestamp,
                duration=0.5,
                confidence=cand.ocr_confidence,
                source=source,
                bbox=cand.bbox,
                frame_ref=frame.sequence,
                content_type=content_type,
            )
        )
        emit.detection(
            job_id,
            brand=brand,
            confidence=cand.ocr_confidence,
            source=source,
            timestamp=frame.timestamp,
            content_type=content_type,
            frame_ref=frame.sequence,
        )

    summary = (
        f"Exact: {exact_count}, Fuzzy: {fuzzy_count}, "
        f"Unresolved: {len(unresolved)} → escalating to VLM"
    )
    emit.log(job_id, "BrandResolver", "INFO", summary)

    return matched, unresolved
|
||||||
130
detect/stages/ocr_stage.py
Normal file
130
detect/stages/ocr_stage.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
"""
|
||||||
|
Stage 4 — OCR
|
||||||
|
|
||||||
|
Reads text from detected regions (YOLO bounding box crops).
|
||||||
|
Two modes:
|
||||||
|
- remote: calls inference server over HTTP (separate GPU box, or localhost)
|
||||||
|
- local: runs PaddleOCR in-process (single-box setup with enough VRAM)
|
||||||
|
|
||||||
|
The mode is selected by whether inference_url is provided.
|
||||||
|
Model instances are cached at module level so they survive across pipeline runs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from detect import emit
|
||||||
|
from detect.models import BoundingBox, Frame, TextCandidate
|
||||||
|
from detect.profiles.base import OCRConfig
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Module-level cache — avoids reloading the model for every crop or pipeline run
|
||||||
|
_local_ocr_cache: dict[str, object] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray:
    """Return the part of ``frame.image`` covered by *box*, clipped to the image.

    The box is read as a top-left corner (``x``, ``y``) plus ``w`` / ``h``.
    All four edges are clamped to the image bounds, so a box lying entirely
    outside the image yields an empty array (``size == 0``).
    """
    height, width = frame.image.shape[:2]
    # Clamp both ends of each axis into [0, size]. If x2/y2 were only capped
    # from above they could go negative for boxes left of / above the image,
    # and a negative slice stop counts from the end of the array, returning
    # unrelated pixels instead of an empty crop.
    x1 = min(max(0, box.x), width)
    y1 = min(max(0, box.y), height)
    x2 = min(max(0, box.x + box.w), width)
    y2 = min(max(0, box.y + box.h), height)
    return frame.image[y1:y2, x1:x2]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_local_model(lang: str):
    """Return the cached PaddleOCR instance for *lang*, creating it on first use.

    ``paddleocr`` is imported inside the function, and only on a cache miss.
    """
    try:
        return _local_ocr_cache[lang]
    except KeyError:
        pass

    from paddleocr import PaddleOCR

    logger.info("Loading PaddleOCR locally (lang=%s)", lang)
    model = PaddleOCR(lang=lang)
    _local_ocr_cache[lang] = model
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_ocr_raw(raw, min_confidence: float) -> list[dict]:
    """Parse PaddleOCR 3.x result — handles dict-based and nested-list layouts.

    Two layouts are understood:

    - dict pages carrying parallel ``rec_texts`` / ``rec_scores`` lists
    - nested lists where each line is ``[points, [text, score], ...]``

    Returns a list of ``{"text": ..., "confidence": float}`` dicts for every
    entry whose confidence is at least *min_confidence*. ``None`` or empty
    input, empty pages, empty lines and lines without a ``[text, score]``
    pair contribute nothing.
    """
    results = []
    for page in (raw or []):
        if not page:
            continue
        if isinstance(page, dict):
            pairs = zip(page.get("rec_texts", []), page.get("rec_scores", []))
        else:
            # Check the length before indexing: a line that holds only the
            # polygon would otherwise raise IndexError on ``line[1]``.
            recs = (line[1] for line in page if line and len(line) >= 2)
            pairs = (
                (rec[0], rec[1])
                for rec in recs
                if isinstance(rec, (list, tuple)) and len(rec) >= 2
            )
        for text, confidence in pairs:
            score = float(confidence)
            if score >= min_confidence:
                results.append({"text": text, "confidence": score})
    return results
|
||||||
|
|
||||||
|
|
||||||
|
def run_ocr(
    frames: list[Frame],
    boxes_by_frame: dict[int, list[BoundingBox]],
    config: OCRConfig,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> list[TextCandidate]:
    """
    Run OCR on cropped regions from YOLO detections.

    inference_url=None → local in-process PaddleOCR (single-box)
    inference_url=str  → remote inference server (split or localhost)

    Args:
        frames: frames to read from; looked up by their ``sequence``.
        boxes_by_frame: detected boxes keyed by frame sequence. Entries whose
            sequence has no matching frame are skipped.
        config: supplies ``languages`` and ``min_confidence``.
        inference_url: base URL of the inference server, or ``None`` for
            local mode.
        job_id: passed through to the emitted log/stats events.

    Returns:
        One ``TextCandidate`` per recognized text line whose confidence is at
        least ``config.min_confidence``, in both modes.
    """
    total_regions = sum(len(boxes) for boxes in boxes_by_frame.values())
    mode = "remote" if inference_url else "local"

    emit.log(job_id, "OCRStage", "INFO",
             f"Running OCR on {total_regions} regions (mode={mode})")

    # Build these once per pipeline run, not per crop
    client = None
    model = None
    if inference_url:
        from detect.inference import InferenceClient
        client = InferenceClient(base_url=inference_url)
    else:
        model = _get_local_model(config.languages[0])

    frame_map = {f.sequence: f for f in frames}
    candidates: list[TextCandidate] = []

    for seq, boxes in boxes_by_frame.items():
        frame = frame_map.get(seq)
        # Explicit None test: only a missing frame should be skipped, not a
        # frame object that happens to evaluate as falsy.
        if frame is None:
            continue

        for box in boxes:
            crop = _crop_region(frame, box)
            if crop.size == 0:
                continue

            if client is not None:
                raw_results = client.ocr(image=crop, languages=config.languages)
                # Apply the configured threshold here as well; the request
                # does not carry it, so without this filter remote mode
                # would ignore config.min_confidence while local mode
                # honours it.
                texts = [
                    {"text": r.text, "confidence": r.confidence}
                    for r in raw_results
                    if r.confidence >= config.min_confidence
                ]
            else:
                raw = model.ocr(crop)
                texts = _parse_ocr_raw(raw, config.min_confidence)

            candidates.extend(
                TextCandidate(
                    frame=frame,
                    bbox=box,
                    text=t["text"],
                    ocr_confidence=t["confidence"],
                )
                for t in texts
            )

    emit.log(job_id, "OCRStage", "INFO",
             f"Extracted text from {len(candidates)} regions")
    emit.stats(job_id, regions_resolved_by_ocr=len(candidates))

    return candidates
|
||||||
@@ -9,7 +9,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
from detect.models import BrandDetection, DetectionReport, Frame, PipelineStats
|
from detect.models import BoundingBox, BrandDetection, DetectionReport, Frame, PipelineStats, TextCandidate
|
||||||
|
|
||||||
|
|
||||||
class DetectState(TypedDict, total=False):
|
class DetectState(TypedDict, total=False):
|
||||||
@@ -21,6 +21,9 @@ class DetectState(TypedDict, total=False):
|
|||||||
# Stage outputs
|
# Stage outputs
|
||||||
frames: list[Frame]
|
frames: list[Frame]
|
||||||
filtered_frames: list[Frame]
|
filtered_frames: list[Frame]
|
||||||
|
boxes_by_frame: dict[int, list[BoundingBox]]
|
||||||
|
text_candidates: list[TextCandidate]
|
||||||
|
unresolved_candidates: list[TextCandidate]
|
||||||
detections: list[BrandDetection]
|
detections: list[BrandDetection]
|
||||||
report: DetectionReport
|
report: DetectionReport
|
||||||
|
|
||||||
|
|||||||
@@ -10,5 +10,9 @@ STRATEGY=sequential # sequential | concurrent | auto
|
|||||||
YOLO_MODEL=yolov8n.pt
|
YOLO_MODEL=yolov8n.pt
|
||||||
YOLO_CONFIDENCE=0.3
|
YOLO_CONFIDENCE=0.3
|
||||||
|
|
||||||
|
# OCR
|
||||||
|
OCR_LANGUAGES=en,es
|
||||||
|
OCR_MIN_CONFIDENCE=0.5
|
||||||
|
|
||||||
# Device
|
# Device
|
||||||
DEVICE=auto # auto | cpu | cuda | cuda:0
|
DEVICE=auto # auto | cpu | cuda | cuda:0
|
||||||
|
|||||||
0
gpu/__init__.py
Normal file
0
gpu/__init__.py
Normal file
39
gpu/config.py
Normal file
39
gpu/config.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""
|
||||||
|
Runtime config — loaded from env, mutable via API.
|
||||||
|
|
||||||
|
The UI config panel is just a visual editor for these same values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Runtime configuration: read from the environment once at import time and
# mutated in place afterwards by update_config().
_config = {
    "device": os.environ.get("DEVICE", "auto"),
    "yolo_model": os.environ.get("YOLO_MODEL", "yolov8n.pt"),
    "yolo_confidence": float(os.environ.get("YOLO_CONFIDENCE", "0.3")),
    "vram_budget_mb": int(os.environ.get("VRAM_BUDGET_MB", "10240")),
    "strategy": os.environ.get("STRATEGY", "sequential"),
    # Comma-separated list. Strip each entry and drop empty ones so values
    # such as "en, es" or "en," never produce " es" or "" as a language
    # code; fall back to English when nothing usable remains.
    "ocr_languages": [
        lang.strip()
        for lang in os.environ.get("OCR_LANGUAGES", "en").split(",")
        if lang.strip()
    ] or ["en"],
    "ocr_min_confidence": float(os.environ.get("OCR_MIN_CONFIDENCE", "0.5")),
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_config() -> dict:
    """Return the live runtime configuration.

    The module-level dict itself is returned, not a copy.
    """
    return _config
|
||||||
|
|
||||||
|
|
||||||
|
def update_config(changes: dict) -> dict:
    """Merge *changes* into the runtime configuration and return it.

    Keys are applied exactly as given, and the returned value is the shared
    module-level dict.
    """
    _config.update(changes)
    return _config
|
||||||
|
|
||||||
|
|
||||||
|
def get_device() -> str:
    """Resolve the compute device from the ``device`` setting.

    Any value other than ``"auto"`` is returned unchanged. For ``"auto"``
    the result is ``"cuda"`` when torch can be imported and reports CUDA as
    available, otherwise ``"cpu"``.
    """
    configured = _config["device"]
    if configured == "auto":
        try:
            import torch
            has_cuda = torch.cuda.is_available()
        except ImportError:
            return "cpu"
        return "cuda" if has_cuda else "cpu"
    return configured
|
||||||
5
gpu/models/__init__.py
Normal file
5
gpu/models/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from . import registry
|
||||||
|
from .yolo import detect
|
||||||
|
from .ocr import ocr
|
||||||
|
|
||||||
|
__all__ = ["registry", "detect", "ocr"]
|
||||||
105
gpu/models/ocr.py
Normal file
105
gpu/models/ocr.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""PaddleOCR 3.x text extraction wrapper."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from models import registry
|
||||||
|
from config import get_config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _load(languages: list[str]):
    """Create a PaddleOCR model for *languages* and store it in the registry.

    The model is constructed with ``lang=languages[0]``; further entries do
    not influence which model is loaded. The registry key is still built
    from the whole list so that it matches the key computed in ``_get``.
    """
    from paddleocr import PaddleOCR

    if len(languages) > 1:
        # Make the truncation visible instead of silently ignoring the
        # remaining languages.
        logger.warning(
            "OCR model is loaded for a single language; using %r, ignoring %s",
            languages[0], languages[1:],
        )
    key = f"ocr_{'_'.join(languages)}"
    model = PaddleOCR(lang=languages[0])
    registry.put(key, model)
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def _get(languages: list[str] | None = None):
    """Return the registered OCR model for *languages*, loading it on a miss.

    When *languages* is ``None`` or empty, the configured ``ocr_languages``
    are used instead.
    """
    langs = languages if languages else get_config()["ocr_languages"]
    cached = registry.get("ocr_" + "_".join(langs))
    if cached is not None:
        return cached
    return _load(langs)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_raw(raw) -> list[tuple[list, str, float]]:
    """
    Parse PaddleOCR output into (points, text, confidence) tuples.

    PaddleOCR 3.x changed the result format. Two known layouts:

    Layout A — dict-based (new pipeline API):
        raw = [{'rec_texts': [...], 'rec_scores': [...], 'dt_polys': [...]}]

    Layout B — nested list (2.x compat / some 3.x builds):
        raw = [[ [points, [text, score]], ... ]]
        raw = [[ [points, [text, score], [cls, cls_score]], ... ]]  # with angle cls

    ``None`` or empty input yields an empty list. Layout B lines that do not
    carry a ``[text, score]`` pair are logged and skipped.
    """
    results = []

    for page in (raw or []):
        if not page:
            continue

        # Layout A: dict with parallel lists
        if isinstance(page, dict):
            texts = page.get("rec_texts", [])
            scores = page.get("rec_scores", [])
            # Prefer ``rec_polys`` when the result provides it: it lists the
            # polygons of the recognized texts, so it stays index-aligned
            # with rec_texts/rec_scores even if some detections were dropped
            # during recognition. ``dt_polys`` (all detected polygons) is the
            # fallback. Compare against None explicitly rather than using
            # ``or`` so array-like values are never truth-tested.
            polys = page.get("rec_polys")
            if polys is None:
                polys = page.get("dt_polys", [])
            for points, text, confidence in zip(polys, texts, scores):
                results.append((points, str(text), float(confidence)))
            continue

        # Layout B: list of per-line entries
        for line in page:
            if not line:
                continue

            # line[0] is the polygon, line[1] is [text, score]; any extra
            # elements (angle cls etc.) are ignored. A line too short to
            # hold line[1] is treated like any other malformed line.
            rec = line[1] if len(line) >= 2 else None
            if not (isinstance(rec, (list, tuple)) and len(rec) >= 2):
                logger.warning("Unexpected OCR line format: %s", line)
                continue

            results.append((line[0], str(rec[0]), float(rec[1])))

    return results
|
||||||
|
|
||||||
|
|
||||||
|
def ocr(image, languages: list[str] | None = None, min_confidence: float | None = None) -> list[dict]:
    """Run OCR on *image* and return the recognized text lines.

    Each result is a dict with ``text``, ``confidence`` and ``bbox``, where
    ``bbox`` is ``[x, y, width, height]`` of the axis-aligned box around the
    line's polygon, converted with ``int()``. Lines scoring below the
    threshold are left out; the threshold is *min_confidence* or, when that
    is ``None``, the configured ``ocr_min_confidence``.
    """
    settings = get_config()
    threshold = settings["ocr_min_confidence"] if min_confidence is None else min_confidence
    engine = _get(languages)

    raw = engine.ocr(image)
    logger.debug("OCR raw: %s", raw)

    results = []
    for points, text, confidence in _parse_raw(raw):
        if confidence < threshold:
            continue

        xs = [pt[0] for pt in points]
        ys = [pt[1] for pt in points]
        left, top = min(xs), min(ys)
        bbox = [int(left), int(top), int(max(xs) - left), int(max(ys) - top)]
        results.append({"text": text, "confidence": confidence, "bbox": bbox})

    return results
|
||||||
37
gpu/models/registry.py
Normal file
37
gpu/models/registry.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""
|
||||||
|
Model registry — manages loaded models and VRAM lifecycle.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_models: dict[str, object] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get(name: str) -> object | None:
    """Return the model registered under *name*, or ``None`` if there is none."""
    try:
        return _models[name]
    except KeyError:
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def put(name: str, model: object) -> None:
    """Register *model* under *name*, replacing any existing entry, and log it."""
    _models.update({name: model})
    logger.info("Loaded %s", name)
|
||||||
|
|
||||||
|
|
||||||
|
def unload(name: str) -> bool:
    """Remove *name* from the registry.

    Returns ``True`` when an entry was removed and ``False`` when nothing
    was registered under that name. Only the registry's own reference to the
    model is dropped here.
    """
    try:
        del _models[name]
    except KeyError:
        return False
    logger.info("Unloaded %s", name)
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def loaded() -> list[str]:
    """Return the names of all registered models as a new list."""
    return [*_models]
|
||||||
|
|
||||||
|
|
||||||
|
def clear() -> None:
    """Drop every entry from the registry and log that it was emptied."""
    _models.clear()
    logger.info("All models unloaded")
|
||||||
54
gpu/models/yolo.py
Normal file
54
gpu/models/yolo.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""YOLO object detection model wrapper."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from models import registry
|
||||||
|
from config import get_config, get_device
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _load(model_name: str):
    """Instantiate YOLO *model_name*, move it to the resolved device and register it.

    ``ultralytics`` is imported inside the function. The model is stored in
    the registry under *model_name* and also returned.
    """
    from ultralytics import YOLO

    target_device = get_device()
    yolo = YOLO(model_name)
    yolo.to(target_device)
    registry.put(model_name, yolo)
    return yolo
|
||||||
|
|
||||||
|
|
||||||
|
def _get(model_name: str | None = None):
    """Return the registered YOLO model, loading it on a registry miss.

    When *model_name* is ``None`` or empty, the configured ``yolo_model`` is
    used.
    """
    name = model_name if model_name else get_config()["yolo_model"]
    cached = registry.get(name)
    if cached is not None:
        return cached
    return _load(name)
|
||||||
|
|
||||||
|
|
||||||
|
def detect(image, model_name: str | None = None, confidence: float | None = None, target_classes: list[str] | None = None) -> list[dict]:
    """Run YOLO detection on *image* and return one dict per kept box.

    Each dict has ``x``, ``y``, ``w``, ``h`` (from the box's xyxy corners,
    converted with ``int()``), ``confidence`` and ``label``. The confidence
    threshold is *confidence* or, when that is ``None``, the configured
    ``yolo_confidence``. If *target_classes* is non-empty, boxes whose label
    is not listed are dropped.
    """
    settings = get_config()
    threshold = settings["yolo_confidence"] if confidence is None else confidence
    model = _get(model_name)

    found = []
    for result in model(image, conf=threshold, verbose=False):
        for box in result.boxes:
            label = result.names[int(box.cls[0])]
            if target_classes and label not in target_classes:
                continue

            left, top, right, bottom = box.xyxy[0].tolist()
            found.append({
                "x": int(left),
                "y": int(top),
                "w": int(right - left),
                "h": int(bottom - top),
                "confidence": float(box.conf[0]),
                "label": label,
            })

    return found
|
||||||
@@ -1,4 +1,21 @@
|
|||||||
fastapi>=0.109.0
|
fastapi>=0.109.0
|
||||||
uvicorn[standard]>=0.27.0
|
uvicorn[standard]>=0.27.0
|
||||||
ultralytics>=8.0.0
|
rapidfuzz>=3.0.0
|
||||||
Pillow>=10.0.0
|
Pillow>=10.0.0
|
||||||
|
|
||||||
|
# --- GPU-specific installs (mcrn: RTX 3080, CUDA toolkit 12.8) ---
|
||||||
|
#
|
||||||
|
# torch: must be installed from the PyTorch index, NOT from PyPI.
|
||||||
|
# cu126 is the closest build to CUDA 12.8 (no cu128 wheel yet; cu126 is forward-compatible).
|
||||||
|
# Install with:
|
||||||
|
# uv pip install --reinstall torch torchvision --index-url https://download.pytorch.org/whl/cu126
|
||||||
|
#
|
||||||
|
# ultralytics pulls torch as a dependency — reinstall torch after ultralytics to ensure
|
||||||
|
# the correct CUDA build. Mixing the PyPI torch with CUDA 12.8 causes NCCL symbol errors.
|
||||||
|
ultralytics>=8.0.0
|
||||||
|
|
||||||
|
# paddlepaddle-gpu: NOT available on PyPI. Install from PaddlePaddle's package index.
|
||||||
|
# cu126 build works on CUDA 12.8.
|
||||||
|
# Install with:
|
||||||
|
# uv pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||||
|
paddleocr>=3.0.0
|
||||||
|
|||||||
169
gpu/server.py
169
gpu/server.py
@@ -1,16 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
Inference server — thin HTTP wrapper around ML models.
|
Inference server — thin HTTP routes over model wrappers.
|
||||||
|
|
||||||
Runs on the GPU machine. The detection pipeline calls this over HTTP,
|
Config lives in config.py, model logic in models/.
|
||||||
or imports the same logic locally if GPU is on the same machine.
|
This file is just the FastAPI glue.
|
||||||
|
|
||||||
Config is loaded from env on startup, then editable at runtime via
|
|
||||||
GET/PUT /config. The UI config panel is just a visual editor for these
|
|
||||||
same values.
|
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
cd gpu && uvicorn server:app --host 0.0.0.0 --port 8000
|
|
||||||
# or
|
|
||||||
cd gpu && python server.py
|
cd gpu && python server.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -27,45 +21,13 @@ from fastapi import FastAPI, HTTPException
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from config import get_config, get_device, update_config
|
||||||
|
from models import registry
|
||||||
|
from models.yolo import detect as yolo_detect
|
||||||
|
from models.ocr import ocr as ocr_run
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# --- Runtime config (loaded from env, mutable via API) ---
|
|
||||||
_config = {
|
|
||||||
"device": os.environ.get("DEVICE", "auto"),
|
|
||||||
"yolo_model": os.environ.get("YOLO_MODEL", "yolov8n.pt"),
|
|
||||||
"yolo_confidence": float(os.environ.get("YOLO_CONFIDENCE", "0.3")),
|
|
||||||
"vram_budget_mb": int(os.environ.get("VRAM_BUDGET_MB", "10240")),
|
|
||||||
"strategy": os.environ.get("STRATEGY", "sequential"),
|
|
||||||
}
|
|
||||||
|
|
||||||
# --- Model registry ---
|
|
||||||
_models: dict[str, object] = {}
|
|
||||||
|
|
||||||
|
|
||||||
# --- Helpers ---
|
|
||||||
|
|
||||||
def _get_device() -> str:
|
|
||||||
device = _config["device"]
|
|
||||||
if device != "auto":
|
|
||||||
return device
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
return "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
except ImportError:
|
|
||||||
return "cpu"
|
|
||||||
|
|
||||||
|
|
||||||
def _get_yolo(model_name: str | None = None):
|
|
||||||
name = model_name or _config["yolo_model"]
|
|
||||||
if name not in _models:
|
|
||||||
from ultralytics import YOLO
|
|
||||||
device = _get_device()
|
|
||||||
logger.info("Loading %s on %s", name, device)
|
|
||||||
model = YOLO(name)
|
|
||||||
model.to(device)
|
|
||||||
_models[name] = model
|
|
||||||
return _models[name]
|
|
||||||
|
|
||||||
|
|
||||||
def _decode_image(b64: str) -> np.ndarray:
|
def _decode_image(b64: str) -> np.ndarray:
|
||||||
data = base64.b64decode(b64)
|
data = base64.b64decode(b64)
|
||||||
@@ -76,9 +38,9 @@ def _decode_image(b64: str) -> np.ndarray:
|
|||||||
# --- Request/Response models ---
|
# --- Request/Response models ---
|
||||||
|
|
||||||
class DetectRequest(BaseModel):
|
class DetectRequest(BaseModel):
|
||||||
image: str # base64 JPEG
|
image: str
|
||||||
model: str | None = None # defaults to config yolo_model
|
model: str | None = None
|
||||||
confidence: float | None = None # defaults to config yolo_confidence
|
confidence: float | None = None
|
||||||
target_classes: list[str] | None = None
|
target_classes: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
@@ -95,23 +57,39 @@ class DetectResponse(BaseModel):
|
|||||||
detections: list[BBox]
|
detections: list[BBox]
|
||||||
|
|
||||||
|
|
||||||
|
class OCRRequest(BaseModel):
|
||||||
|
image: str
|
||||||
|
languages: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OCRTextResult(BaseModel):
|
||||||
|
text: str
|
||||||
|
confidence: float
|
||||||
|
bbox: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
class OCRResponse(BaseModel):
|
||||||
|
results: list[OCRTextResult]
|
||||||
|
|
||||||
|
|
||||||
class ConfigUpdate(BaseModel):
|
class ConfigUpdate(BaseModel):
|
||||||
"""Partial config update — only provided fields are changed."""
|
|
||||||
device: str | None = None
|
device: str | None = None
|
||||||
yolo_model: str | None = None
|
yolo_model: str | None = None
|
||||||
yolo_confidence: float | None = None
|
yolo_confidence: float | None = None
|
||||||
vram_budget_mb: int | None = None
|
vram_budget_mb: int | None = None
|
||||||
strategy: str | None = None
|
strategy: str | None = None
|
||||||
|
ocr_languages: list[str] | None = None
|
||||||
|
ocr_min_confidence: float | None = None
|
||||||
|
|
||||||
|
|
||||||
# --- App ---
|
# --- App ---
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
logger.info("Inference server starting (device=%s)", _get_device())
|
logger.info("Inference server starting (device=%s)", get_device())
|
||||||
yield
|
yield
|
||||||
logger.info("Inference server shutting down")
|
logger.info("Shutting down")
|
||||||
_models.clear()
|
registry.clear()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
|
app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
|
||||||
@@ -119,82 +97,77 @@ app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
|
|||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
def health():
|
def health():
|
||||||
|
cfg = get_config()
|
||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
"device": _get_device(),
|
"device": get_device(),
|
||||||
"loaded_models": list(_models.keys()),
|
"loaded_models": registry.loaded(),
|
||||||
"vram_budget_mb": _config["vram_budget_mb"],
|
"vram_budget_mb": cfg["vram_budget_mb"],
|
||||||
"strategy": _config["strategy"],
|
"strategy": cfg["strategy"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/config")
|
@app.get("/config")
|
||||||
def get_config():
|
def read_config():
|
||||||
"""Current runtime config. Same values the .env sets at startup."""
|
return {**get_config(), "device_resolved": get_device()}
|
||||||
return {**_config, "device_resolved": _get_device()}
|
|
||||||
|
|
||||||
|
|
||||||
@app.put("/config")
|
@app.put("/config")
|
||||||
def update_config(update: ConfigUpdate):
|
def write_config(update: ConfigUpdate):
|
||||||
"""Update runtime config. Only provided fields are changed."""
|
|
||||||
changes = update.model_dump(exclude_none=True)
|
changes = update.model_dump(exclude_none=True)
|
||||||
if not changes:
|
if not changes:
|
||||||
return _config
|
return get_config()
|
||||||
|
|
||||||
# If model changed, unload the old one so it gets reloaded on next request
|
# Unload model if it changed
|
||||||
if "yolo_model" in changes and changes["yolo_model"] != _config["yolo_model"]:
|
old_model = get_config().get("yolo_model")
|
||||||
old = _config["yolo_model"]
|
if "yolo_model" in changes and changes["yolo_model"] != old_model:
|
||||||
if old in _models:
|
registry.unload(old_model)
|
||||||
del _models[old]
|
|
||||||
logger.info("Unloaded %s (model changed)", old)
|
|
||||||
|
|
||||||
_config.update(changes)
|
update_config(changes)
|
||||||
logger.info("Config updated: %s", changes)
|
logger.info("Config updated: %s", changes)
|
||||||
return {**_config, "device_resolved": _get_device()}
|
return {**get_config(), "device_resolved": get_device()}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/models/unload")
|
@app.post("/models/unload")
|
||||||
def unload_model(body: dict):
|
def unload_model(body: dict):
|
||||||
"""Unload a model from memory to free VRAM."""
|
|
||||||
name = body.get("model", "")
|
name = body.get("model", "")
|
||||||
if name in _models:
|
unloaded = registry.unload(name)
|
||||||
del _models[name]
|
return {"status": "unloaded" if unloaded else "not_loaded", "model": name}
|
||||||
logger.info("Unloaded %s", name)
|
|
||||||
return {"status": "unloaded", "model": name}
|
|
||||||
return {"status": "not_loaded", "model": name}
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/detect", response_model=DetectResponse)
|
@app.post("/detect", response_model=DetectResponse)
|
||||||
def detect(req: DetectRequest):
|
def detect(req: DetectRequest):
|
||||||
model_name = req.model or _config["yolo_model"]
|
try:
|
||||||
confidence = req.confidence if req.confidence is not None else _config["yolo_confidence"]
|
image = _decode_image(req.image)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model = _get_yolo(model_name)
|
results = yolo_detect(
|
||||||
|
image,
|
||||||
|
model_name=req.model,
|
||||||
|
confidence=req.confidence,
|
||||||
|
target_classes=req.target_classes,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to load model: {e}")
|
raise HTTPException(status_code=500, detail=f"Detection failed: {e}")
|
||||||
|
|
||||||
image = _decode_image(req.image)
|
return DetectResponse(detections=[BBox(**r) for r in results])
|
||||||
results = model(image, conf=confidence, verbose=False)
|
|
||||||
|
|
||||||
detections = []
|
|
||||||
for r in results:
|
|
||||||
for box in r.boxes:
|
|
||||||
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
|
||||||
label = r.names[int(box.cls[0])]
|
|
||||||
|
|
||||||
if req.target_classes and label not in req.target_classes:
|
@app.post("/ocr", response_model=OCRResponse)
|
||||||
continue
|
def ocr(req: OCRRequest):
|
||||||
|
try:
|
||||||
|
image = _decode_image(req.image)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||||
|
|
||||||
det = BBox(
|
try:
|
||||||
x=int(x1), y=int(y1),
|
results = ocr_run(image, languages=req.languages)
|
||||||
w=int(x2 - x1), h=int(y2 - y1),
|
except Exception as e:
|
||||||
confidence=float(box.conf[0]),
|
raise HTTPException(status_code=500, detail=f"OCR failed: {e}")
|
||||||
label=label,
|
|
||||||
)
|
|
||||||
detections.append(det)
|
|
||||||
|
|
||||||
return DetectResponse(detections=detections)
|
return OCRResponse(results=[OCRTextResult(**r) for r in results])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
176
tests/detect/manual/test_brand_table_e2e.py
Normal file
176
tests/detect/manual/test_brand_table_e2e.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Push OCR + brand detection events to test the BrandTablePanel live.
|
||||||
|
|
||||||
|
Simulates what the OCR and BrandResolver stages emit: detection events
|
||||||
|
with brand names, confidence scores, sources, and frame refs. Watch
|
||||||
|
the BrandTablePanel in the UI populate and sort in real time.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python tests/detect/manual/test_brand_table_e2e.py [--job JOB_ID] [--port PORT] [--delay SECS]
|
||||||
|
|
||||||
|
Opens: http://mpr.local.ar/detection/?job=<JOB_ID>
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import redis
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DETECTIONS = [
|
||||||
|
# (brand, confidence, source, timestamp, frame_ref) — simulates a real match
|
||||||
|
("Nike", 0.97, "ocr", 2.0, 4),
|
||||||
|
("Nike", 0.95, "ocr", 3.5, 7),
|
||||||
|
("Emirates", 0.92, "ocr", 5.0, 10),
|
||||||
|
("Adidas", 0.89, "ocr", 7.5, 15),
|
||||||
|
("Coca-Cola", 0.85, "ocr", 10.0, 20),
|
||||||
|
("Nike", 0.94, "ocr", 12.5, 25),
|
||||||
|
("Emirates", 0.88, "ocr", 15.0, 30),
|
||||||
|
("Mastercard", 0.78, "local_vlm", 18.0, 36),
|
||||||
|
("Heineken", 0.72, "cloud_llm", 22.5, 45),
|
||||||
|
("Adidas", 0.91, "ocr", 25.0, 50),
|
||||||
|
("Nike", 0.96, "ocr", 27.5, 55),
|
||||||
|
("Emirates", 0.90, "ocr", 30.0, 60),
|
||||||
|
("Unknown Brand", 0.65, "cloud_llm", 33.0, 66),
|
||||||
|
("Coca-Cola", 0.87, "ocr", 35.5, 71),
|
||||||
|
("Nike", 0.93, "ocr", 38.0, 76),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def ts():
    """Return the current UTC time as an ISO-8601 string."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def push(r, key, event):
    """Append *event* to the Redis list *key* as JSON and return it.

    A ``"ts"`` field is added (in place) when the event does not carry one.
    """
    event.setdefault("ts", ts())
    payload = json.dumps(event)
    r.rpush(key, payload)
    return event
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Replay a scripted detection run into Redis for the BrandTablePanel.

    Clears ``detect_events:<job>``, waits for Enter, then pushes graph,
    stats, detection and log events with ``--delay`` seconds between steps.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--job", default="brand-table-test")
    parser.add_argument("--port", type=int, default=6382)
    parser.add_argument("--delay", type=float, default=0.6)
    args = parser.parse_args()

    r = redis.Redis(port=args.port, decode_responses=True)
    key = f"detect_events:{args.job}"

    # Start from an empty event list so earlier runs do not show up.
    r.delete(key)

    logger.info("Pushing %d detections to %s", len(DETECTIONS), key)
    logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job)
    input("\nPress Enter to start...")

    def stats(**overrides):
        """Build a stats_update event; fields not overridden stay at zero."""
        event = {
            "event": "stats_update",
            "frames_extracted": 120,
            "frames_after_scene_filter": 0,
            "regions_detected": 0,
            "regions_resolved_by_ocr": 0,
            "regions_escalated_to_local_vlm": 0,
            "regions_escalated_to_cloud_llm": 0,
            "cloud_llm_calls": 0,
            "processing_time_seconds": 0,
            "estimated_cloud_cost_usd": 0,
        }
        event.update(overrides)
        return event

    # Progressive stats — mimics real pipeline stages so the funnel chart draws lines
    STATS_PROGRESSION = [
        stats(processing_time_seconds=3.2),
        stats(frames_after_scene_filter=45, processing_time_seconds=5.1),
        stats(frames_after_scene_filter=45, regions_detected=32,
              processing_time_seconds=12.4),
    ]
    FINAL_STATS = stats(
        frames_after_scene_filter=45,
        regions_detected=32,
        regions_resolved_by_ocr=24,
        regions_escalated_to_local_vlm=6,
        regions_escalated_to_cloud_llm=2,
        cloud_llm_calls=2,
        processing_time_seconds=31.4,
        estimated_cloud_cost_usd=0.0038,
    )

    NODES = ["extract_frames", "filter_scenes", "detect_objects", "run_ocr",
             "match_brands", "escalate_vlm", "escalate_cloud", "compile_report"]

    def push_graph(r, key, active_node, status, delay):
        """Push a graph_update: earlier nodes done, later ones pending."""
        # Look the position up once instead of once per node.
        active_idx = NODES.index(active_node)
        nodes = []
        for idx, node in enumerate(NODES):
            if idx == active_idx:
                node_status = status
            elif idx < active_idx:
                node_status = "done"
            else:
                node_status = "pending"
            nodes.append({"id": node, "status": node_status})
        push(r, key, {"event": "graph_update", "nodes": nodes})
        time.sleep(delay)

    def run_stage(node, stats_event=None):
        """Mark *node* running, optionally push its stats, then mark it done."""
        push_graph(r, key, node, "running", args.delay)
        if stats_event is not None:
            push(r, key, stats_event)
            time.sleep(args.delay)
        push_graph(r, key, node, "done", args.delay)

    push(r, key, {"event": "log", "level": "INFO", "stage": "BrandResolver",
                  "msg": f"Starting brand matching — {len(DETECTIONS)} candidates"})
    time.sleep(args.delay)

    # Simulate pipeline progression: extract → filter → detect
    run_stage("extract_frames", STATS_PROGRESSION[0])
    run_stage("filter_scenes", STATS_PROGRESSION[1])
    run_stage("detect_objects", STATS_PROGRESSION[2])

    # Detections stream in while the run_ocr node is shown as running.
    push_graph(r, key, "run_ocr", "running", args.delay)
    total = len(DETECTIONS)
    for i, (brand, conf, source, timestamp, frame_ref) in enumerate(DETECTIONS):
        push(r, key, {"event": "detection",
                      "brand": brand,
                      "confidence": conf,
                      "source": source,
                      "timestamp": timestamp,
                      "duration": 0.5,
                      "content_type": "soccer_broadcast",
                      "frame_ref": frame_ref})
        logger.info("[%d/%d] %s conf=%.2f source=%s t=%.1fs frame=%d",
                    i + 1, total, brand, conf, source, timestamp, frame_ref)
        time.sleep(args.delay)
    push_graph(r, key, "run_ocr", "done", args.delay)

    for node in ("match_brands", "escalate_vlm", "escalate_cloud"):
        run_stage(node)

    # Final stats are pushed while compile_report is running.
    run_stage("compile_report", FINAL_STATS)

    push(r, key, {"event": "log", "level": "INFO", "stage": "BrandResolver",
                  "msg": "Brand matching complete — "
                         f"{len(DETECTIONS)} detections, "
                         f"{len(set(d[0] for d in DETECTIONS))} unique brands"})

    logger.info("Done. Watch the BrandTablePanel — try sorting by confidence and brand.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
135
tests/detect/manual/test_ocr_e2e.py
Normal file
135
tests/detect/manual/test_ocr_e2e.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test OCR stage end-to-end — sends real images to the inference server.
|
||||||
|
|
||||||
|
Creates test images with known text, sends them through the /ocr endpoint,
|
||||||
|
verifies the text comes back. Tests both the inference server and the
|
||||||
|
ocr_stage module's remote path.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python tests/detect/manual/test_ocr_e2e.py [--url URL]
|
||||||
|
|
||||||
|
Requires: inference server running (gpu/server.py)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def make_text_image(text: str, width: int = 300, height: int = 80) -> np.ndarray:
    """Render *text* in black on a white RGB canvas and return it as an array.

    Uses DejaVu Sans Bold at size 36 when that font file can be opened,
    otherwise PIL's default font.
    """
    canvas = Image.new("RGB", (width, height), "white")
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 36
        )
    except OSError:  # IOError is an alias of OSError on Python 3
        font = ImageFont.load_default()
    ImageDraw.Draw(canvas).text((10, 15), text, fill="black", font=font)
    return np.array(canvas)
|
||||||
|
|
||||||
|
|
||||||
|
def image_to_b64(image: np.ndarray) -> str:
    """Encode *image* as JPEG and return the bytes as a base64 string."""
    jpeg = io.BytesIO()
    Image.fromarray(image).save(jpeg, "JPEG")
    encoded = base64.b64encode(jpeg.getvalue())
    return encoded.decode()
|
||||||
|
|
||||||
|
|
||||||
|
def test_health(url: str, timeout: float = 10.0):
    """Call ``GET {url}/health`` and log the reported status and device.

    Args:
        url: Base URL of the inference server, without a trailing slash.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely on an unreachable host.

    Returns:
        True when the request succeeded.

    Raises:
        requests.HTTPError: On a non-2xx response.
        requests.Timeout: When the server does not answer within *timeout*.
    """
    logger.info("--- Health check ---")
    resp = requests.get(f"{url}/health", timeout=timeout)
    resp.raise_for_status()
    data = resp.json()
    logger.info("Status: %s, device: %s", data["status"], data["device"])
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_ocr_endpoint(url: str, text: str, timeout: float = 30.0):
    """Send a rendered image of *text* to ``POST {url}/ocr`` and log the result.

    A miss is logged as a warning rather than raised, because the rendered
    font may simply not be readable by the OCR model.

    Args:
        url: Base URL of the inference server, without a trailing slash.
        text: Text to render into the test image and look for in the output.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely on a stalled server.

    Returns:
        The list under ``"results"`` in the response (empty when absent).

    Raises:
        requests.HTTPError: On a non-2xx response.
        requests.Timeout: When the server does not answer within *timeout*.
    """
    logger.info("--- OCR endpoint: '%s' ---", text)
    image = make_text_image(text)
    b64 = image_to_b64(image)

    resp = requests.post(f"{url}/ocr", json={"image": b64}, timeout=timeout)
    resp.raise_for_status()
    data = resp.json()

    results = data.get("results", [])
    logger.info("Results: %d text regions", len(results))

    # Case-insensitive substring match against every returned region.
    needle = text.lower()
    found = False
    for r in results:
        logger.info(" text=%r confidence=%.3f bbox=%s", r["text"], r["confidence"], r["bbox"])
        if needle in r["text"].lower():
            found = True

    if found:
        logger.info("PASS — found '%s' in OCR output", text)
    else:
        logger.warning("MISS — '%s' not found (may be font/rendering issue, check results above)", text)

    return results
|
||||||
|
|
||||||
|
|
||||||
|
def test_ocr_stage_remote(url: str):
    """Test the detect/stages/ocr_stage.py remote path."""
    logger.info("--- OCR stage (remote mode) ---")

    # Project modules are imported here, after "." has been put on sys.path.
    sys.path.insert(0, ".")
    from detect.models import BoundingBox, Frame
    from detect.profiles.base import OCRConfig
    from detect.stages.ocr_stage import run_ocr

    # One frame with the text baked in, and a single box covering all of it.
    image = make_text_image("EMIRATES")
    img_h, img_w = image.shape[0], image.shape[1]
    frame = Frame(sequence=0, chunk_id=0, timestamp=1.0, image=image)
    box = BoundingBox(x=0, y=0, w=img_w, h=img_h, confidence=0.9, label="text")
    config = OCRConfig(languages=["en"], min_confidence=0.3)

    candidates = run_ocr(
        frames=[frame],
        boxes_by_frame={0: [box]},
        config=config,
        inference_url=url,
    )

    logger.info("Candidates: %d", len(candidates))
    for c in candidates:
        logger.info(" text=%r confidence=%.3f", c.text, c.ocr_confidence)

    if not candidates:
        logger.warning("MISS — no candidates returned (check inference server logs)")
    else:
        logger.info("PASS — ocr_stage remote path returned results")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the health check, three endpoint checks and the stage check in order."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", default="http://mcrndeb:8000")
    url = parser.parse_args().url.rstrip("/")

    logger.info("Inference server: %s", url)
    input("\nPress Enter to start...")

    test_health(url)
    for sample in ("NIKE", "Coca-Cola", "EMIRATES"):
        test_ocr_endpoint(url, sample)
    test_ocr_stage_remote(url)

    logger.info("All OCR tests complete.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
92
tests/detect/test_brand_resolver.py
Normal file
92
tests/detect/test_brand_resolver.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Tests for BrandResolver stage."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from detect.models import BoundingBox, Frame, TextCandidate
|
||||||
|
from detect.profiles.base import BrandDictionary, ResolverConfig
|
||||||
|
from detect.stages.brand_resolver import resolve_brands, _exact_match, _fuzzy_match
|
||||||
|
|
||||||
|
|
||||||
|
DICTIONARY = BrandDictionary(brands={
|
||||||
|
"Nike": ["nike", "NIKE", "swoosh"],
|
||||||
|
"Adidas": ["adidas", "ADIDAS"],
|
||||||
|
"Coca-Cola": ["coca-cola", "coca cola", "coke", "COCA-COLA"],
|
||||||
|
"Emirates": ["emirates", "fly emirates", "EMIRATES"],
|
||||||
|
})
|
||||||
|
|
||||||
|
CONFIG = ResolverConfig(fuzzy_threshold=75)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_candidate(text: str, confidence: float = 0.9) -> TextCandidate:
    """Wrap *text* in a TextCandidate backed by a blank 10x10 frame and box."""
    blank = np.zeros((10, 10, 3), dtype=np.uint8)
    return TextCandidate(
        frame=Frame(sequence=0, chunk_id=0, timestamp=1.0, image=blank),
        bbox=BoundingBox(x=0, y=0, w=10, h=10, confidence=0.8, label="text"),
        text=text,
        ocr_confidence=confidence,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_exact_match():
    """Known spellings resolve to their canonical brand; unknown text gives None."""
    expected = [
        ("Nike", "Nike"),
        ("nike", "Nike"),
        ("COCA-COLA", "Coca-Cola"),
        ("fly emirates", "Emirates"),
    ]
    for text, canonical in expected:
        assert _exact_match(text, DICTIONARY) == canonical
    assert _exact_match("unknown brand", DICTIONARY) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_fuzzy_match():
    """Near-miss spellings match at threshold 75; unrelated text does not."""
    nike_brand, nike_score = _fuzzy_match("Nik3", DICTIONARY, threshold=75)
    assert nike_brand == "Nike"
    assert nike_score >= 75

    adidas_brand, _ = _fuzzy_match("adldas", DICTIONARY, threshold=75)
    assert adidas_brand == "Adidas"

    no_brand, _ = _fuzzy_match("xyzxyzxyz", DICTIONARY, threshold=75)
    assert no_brand is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_exact():
    """Two exact candidates are both matched, in input order."""
    texts = ["Nike", "EMIRATES"]
    matched, unresolved = resolve_brands(
        [_make_candidate(t) for t in texts], DICTIONARY, CONFIG
    )
    assert len(matched) == 2
    assert len(unresolved) == 0
    assert matched[0].brand == "Nike"
    assert matched[1].brand == "Emirates"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_fuzzy():
    """An OCR misread ("coca coIa", capital I for l) still resolves to Coca-Cola."""
    misread = _make_candidate("coca coIa")
    matched, unresolved = resolve_brands([misread], DICTIONARY, CONFIG)
    assert len(matched) == 1
    assert matched[0].brand == "Coca-Cola"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_unresolved():
    """Text that matches no brand is returned in the unresolved list."""
    noise = _make_candidate("random garbage text")
    matched, unresolved = resolve_brands([noise], DICTIONARY, CONFIG)
    assert len(matched) == 0
    assert len(unresolved) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_mixed():
    """A batch with an exact hit, a fuzzy hit and a miss is split 2 / 1."""
    texts = ("Nike", "unknown", "adldas")
    candidates = [_make_candidate(t) for t in texts]
    matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
    assert len(matched) == 2  # Nike exact + Adidas fuzzy
    assert len(unresolved) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_events_emitted(monkeypatch):
    """With a job_id, resolving a match emits both log and detection events."""
    captured = []

    def record(job_id, etype, data):
        captured.append((etype, data))

    monkeypatch.setattr("detect.emit.push_detect_event", record)

    resolve_brands([_make_candidate("Nike")], DICTIONARY, CONFIG, job_id="test-job")

    seen_types = [etype for etype, _ in captured]
    assert "log" in seen_types
    assert "detection" in seen_types
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Tests for the LangGraph detection pipeline."""
|
"""Tests for the LangGraph detection pipeline."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from detect.graph import NODES, build_graph, get_pipeline
|
from detect.graph import NODES, build_graph, get_pipeline
|
||||||
@@ -9,6 +11,22 @@ from detect.state import DetectState
|
|||||||
VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4"
|
VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4"
|
||||||
|
|
||||||
|
|
||||||
|
def _has_inference() -> bool:
|
||||||
|
if os.environ.get("INFERENCE_URL"):
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
import ultralytics
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
requires_inference = pytest.mark.skipif(
|
||||||
|
not _has_inference(),
|
||||||
|
reason="Needs INFERENCE_URL or ultralytics installed",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_graph_compiles():
|
def test_graph_compiles():
|
||||||
pipeline = get_pipeline()
|
pipeline = get_pipeline()
|
||||||
assert pipeline is not None
|
assert pipeline is not None
|
||||||
@@ -20,6 +38,7 @@ def test_graph_has_all_nodes():
|
|||||||
assert node in graph.nodes
|
assert node in graph.nodes
|
||||||
|
|
||||||
|
|
||||||
|
@requires_inference
|
||||||
def test_graph_runs_end_to_end(monkeypatch):
|
def test_graph_runs_end_to_end(monkeypatch):
|
||||||
"""Run the full graph with mocked event emission."""
|
"""Run the full graph with mocked event emission."""
|
||||||
events = []
|
events = []
|
||||||
@@ -52,6 +71,7 @@ def test_graph_runs_end_to_end(monkeypatch):
|
|||||||
assert len(complete_events) == 1
|
assert len(complete_events) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@requires_inference
|
||||||
def test_graph_node_transitions(monkeypatch):
|
def test_graph_node_transitions(monkeypatch):
|
||||||
"""Verify each node emits running → done transitions."""
|
"""Verify each node emits running → done transitions."""
|
||||||
events = []
|
events = []
|
||||||
|
|||||||
141
tests/detect/test_ocr_stage.py
Normal file
141
tests/detect/test_ocr_stage.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""Tests for OCR stage."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from detect.models import BoundingBox, Frame
|
||||||
|
from detect.profiles.base import OCRConfig
|
||||||
|
from detect.stages.ocr_stage import _crop_region, _parse_ocr_raw, run_ocr
|
||||||
|
|
||||||
|
|
||||||
|
def _has_paddleocr() -> bool:
|
||||||
|
try:
|
||||||
|
import paddleocr
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
    """Build an all-black ``h`` x ``w`` frame whose timestamp equals *seq*."""
    black = np.zeros((h, w, 3), dtype=np.uint8)
    return Frame(
        sequence=seq,
        chunk_id=0,
        timestamp=float(seq),
        image=black,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _make_box(x=10, y=10, w=30, h=20) -> BoundingBox:
    """Build a "text" box with confidence 0.9 at the given position and size."""
    geometry = {"x": x, "y": y, "w": w, "h": h}
    return BoundingBox(**geometry, confidence=0.9, label="text")
|
||||||
|
|
||||||
|
|
||||||
|
# --- _crop_region ---
|
||||||
|
|
||||||
|
def test_crop_basic():
    """A box fully inside the frame yields a crop of shape (h, w, 3)."""
    crop = _crop_region(_make_frame(), _make_box(x=10, y=20, w=30, h=15))
    assert crop.shape == (15, 30, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_crop_clamps_to_frame():
    """A box running past the right/bottom edge is cut off at the frame edge."""
    small_frame = _make_frame(w=50, h=40)
    oversized = _make_box(x=30, y=25, w=100, h=100)
    rows, cols = _crop_region(small_frame, oversized).shape[:2]
    assert rows == 15  # 40 - 25
    assert cols == 20  # 50 - 30
|
||||||
|
|
||||||
|
|
||||||
|
def test_crop_negative_origin():
    """A box starting above/left of the frame keeps only the part inside it."""
    offscreen = _make_box(x=-5, y=-5, w=20, h=20)
    rows, cols = _crop_region(_make_frame(), offscreen).shape[:2]
    assert rows == 15  # min(80, -5+20) - 0
    assert cols == 15  # min(100, -5+20) - 0
|
||||||
|
|
||||||
|
|
||||||
|
# --- _parse_ocr_raw ---
|
||||||
|
|
||||||
|
def test_parse_nested_list_layout():
    """Nested-list output: only entries at or above min_confidence survive."""
    quad = [[0, 0], [10, 0], [10, 10], [0, 10]]
    raw = [[
        [quad, ["hello", 0.95]],
        [quad, ["low", 0.2]],
    ]]
    parsed = _parse_ocr_raw(raw, min_confidence=0.5)
    assert len(parsed) == 1
    only = parsed[0]
    assert only["text"] == "hello"
    assert only["confidence"] == 0.95
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_dict_layout():
    """Dict output (rec_texts / rec_scores): the low-score text is dropped."""
    page = {"rec_texts": ["brand", "noise"], "rec_scores": [0.9, 0.3]}
    parsed = _parse_ocr_raw([page], min_confidence=0.5)
    assert len(parsed) == 1
    assert parsed[0]["text"] == "brand"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_empty():
    """None, an empty list and a list holding one empty page all parse to []."""
    for empty_raw in (None, [], [[]]):
        assert _parse_ocr_raw(empty_raw, 0.5) == []
|
||||||
|
|
||||||
|
|
||||||
|
# --- run_ocr (remote, mocked) ---
|
||||||
|
|
||||||
|
def test_run_ocr_remote(monkeypatch):
    """With an inference_url, run_ocr turns the client's result into a candidate.

    The inference client is replaced by a fake that always returns one
    "NIKE" result, so no server is contacted.
    """
    events = []
    monkeypatch.setattr("detect.emit.push_detect_event",
                        lambda job_id, etype, data: events.append((etype, data)))

    class FakeResult:
        def __init__(self, text, confidence):
            self.text = text
            self.confidence = confidence

    class FakeClient:
        def __init__(self, base_url):
            pass

        def ocr(self, image, languages):
            return [FakeResult("NIKE", 0.92)]

    # Patch both places the client can be looked up. raising=False covers the
    # case where ocr_stage has no module-level InferenceClient attribute.
    monkeypatch.setattr("detect.stages.ocr_stage.InferenceClient", FakeClient,
                        raising=False)
    monkeypatch.setattr("detect.inference.InferenceClient", FakeClient)

    frame = _make_frame()
    box = _make_box()
    config = OCRConfig(languages=["en"], min_confidence=0.5)

    candidates = run_ocr(
        frames=[frame],
        boxes_by_frame={0: [box]},
        config=config,
        inference_url="http://fake:8000",
        job_id="test",
    )

    assert len(candidates) == 1
    assert candidates[0].text == "NIKE"
    assert candidates[0].ocr_confidence == 0.92
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not _has_paddleocr(),
    reason="Needs paddleocr installed (GPU box)",
)
def test_run_ocr_skips_empty_crop(monkeypatch):
    """A box lying entirely outside the frame produces no candidates (local mode)."""
    captured = []

    def record(job_id, etype, data):
        captured.append((etype, data))

    monkeypatch.setattr("detect.emit.push_detect_event", record)

    tiny_frame = _make_frame(w=10, h=10)
    outside = _make_box(x=100, y=100, w=5, h=5)  # outside frame → empty crop

    candidates = run_ocr(
        frames=[tiny_frame],
        boxes_by_frame={0: [outside]},
        config=OCRConfig(languages=["en"], min_confidence=0.5),
        inference_url=None,
        job_id="test",
    )

    assert len(candidates) == 0
|
||||||
@@ -22,7 +22,7 @@ def test_soccer_frame_extraction_config():
|
|||||||
def test_soccer_detection_config():
|
def test_soccer_detection_config():
|
||||||
cfg = SoccerBroadcastProfile().detection_config()
|
cfg = SoccerBroadcastProfile().detection_config()
|
||||||
assert 0 < cfg.confidence_threshold < 1
|
assert 0 < cfg.confidence_threshold < 1
|
||||||
assert len(cfg.target_classes) > 0
|
assert isinstance(cfg.target_classes, list)
|
||||||
|
|
||||||
|
|
||||||
def test_soccer_brand_dictionary_non_empty():
|
def test_soccer_brand_dictionary_non_empty():
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import LogPanel from './panels/LogPanel.vue'
|
|||||||
import FunnelPanel from './panels/FunnelPanel.vue'
|
import FunnelPanel from './panels/FunnelPanel.vue'
|
||||||
import PipelineGraphPanel from './panels/PipelineGraphPanel.vue'
|
import PipelineGraphPanel from './panels/PipelineGraphPanel.vue'
|
||||||
import FramePanel from './panels/FramePanel.vue'
|
import FramePanel from './panels/FramePanel.vue'
|
||||||
|
import BrandTablePanel from './panels/BrandTablePanel.vue'
|
||||||
import type { StatsUpdate } from './types/sse-contract'
|
import type { StatsUpdate } from './types/sse-contract'
|
||||||
|
|
||||||
const jobId = ref(new URLSearchParams(window.location.search).get('job') || 'test-job')
|
const jobId = ref(new URLSearchParams(window.location.search).get('job') || 'test-job')
|
||||||
@@ -42,7 +43,7 @@ source.connect()
|
|||||||
<span class="job-id">job: {{ jobId }}</span>
|
<span class="job-id">job: {{ jobId }}</span>
|
||||||
</header>
|
</header>
|
||||||
|
|
||||||
<LayoutGrid :columns="3" :rows="2" gap="var(--space-2)">
|
<LayoutGrid :columns="3" :rows="3" gap="var(--space-2)">
|
||||||
<Panel title="Stats" :status="status">
|
<Panel title="Stats" :status="status">
|
||||||
<div class="stats" v-if="stats">
|
<div class="stats" v-if="stats">
|
||||||
<div class="stat" v-for="s in [
|
<div class="stat" v-for="s in [
|
||||||
@@ -66,6 +67,8 @@ source.connect()
|
|||||||
|
|
||||||
<PipelineGraphPanel :source="source" :status="status" />
|
<PipelineGraphPanel :source="source" :status="status" />
|
||||||
|
|
||||||
|
<BrandTablePanel :source="source" :status="status" />
|
||||||
|
|
||||||
<LogPanel :source="source" :status="status" />
|
<LogPanel :source="source" :status="status" />
|
||||||
</LayoutGrid>
|
</LayoutGrid>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
57
ui/detection-app/src/panels/BrandTablePanel.vue
Normal file
57
ui/detection-app/src/panels/BrandTablePanel.vue
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
<script setup lang="ts">
import { ref } from 'vue'
import { Panel } from 'mpr-ui-framework'
import TableRenderer from 'mpr-ui-framework/src/renderers/TableRenderer.vue'
import type { TableColumn } from 'mpr-ui-framework/src/renderers/TableRenderer.vue'
import type { DataSource } from 'mpr-ui-framework'

type DetectionEvent = Record<string, unknown>
type SortDirection = 'asc' | 'desc'

const props = defineProps<{
  source: DataSource
  status?: 'idle' | 'live' | 'processing' | 'error'
}>()

// Keys match the fields copied out of each 'detection' event below.
const columns: TableColumn[] = [
  { key: 'brand', label: 'Brand', width: '120px' },
  { key: 'confidence', label: 'Conf', width: '60px' },
  { key: 'source', label: 'Source', width: '80px' },
  { key: 'timestamp', label: 'Time', width: '60px' },
  { key: 'content_type', label: 'Type', width: '100px' },
  { key: 'frame_ref', label: 'Frame', width: '50px' },
]

const rows = ref<DetectionEvent[]>([])
const sortKey = ref('timestamp')
const sortDir = ref<SortDirection>('desc')

// Numbers are rendered with a fixed number of decimals; any other value passes through as-is.
function fixed(value: unknown, digits: number): unknown {
  return typeof value === 'number' ? value.toFixed(digits) : value
}

// Every detection event appends one row; rows are never removed.
props.source.on<DetectionEvent>('detection', (event) => {
  const { brand, source, content_type, frame_ref } = event
  rows.value.push({
    brand,
    confidence: fixed(event.confidence, 2),
    source,
    timestamp: fixed(event.timestamp, 1),
    content_type,
    frame_ref,
  })
})

// Clicking the active column flips direction; a new column starts descending.
function onSort(key: string) {
  if (sortKey.value !== key) {
    sortKey.value = key
    sortDir.value = 'desc'
    return
  }
  sortDir.value = sortDir.value === 'asc' ? 'desc' : 'asc'
}
</script>

<template>
  <Panel title="Detections" :status="status">
    <TableRenderer
      :columns="columns"
      :rows="rows"
      :sort-key="sortKey"
      :sort-dir="sortDir"
      @sort="onSort"
    />
  </Panel>
</template>
|
||||||
@@ -13,3 +13,4 @@ export { default as LogRenderer } from './renderers/LogRenderer.vue'
|
|||||||
export { default as TimeSeriesRenderer } from './renderers/TimeSeriesRenderer.vue'
|
export { default as TimeSeriesRenderer } from './renderers/TimeSeriesRenderer.vue'
|
||||||
export { default as GraphRenderer } from './renderers/GraphRenderer.vue'
|
export { default as GraphRenderer } from './renderers/GraphRenderer.vue'
|
||||||
export { default as FrameRenderer } from './renderers/FrameRenderer.vue'
|
export { default as FrameRenderer } from './renderers/FrameRenderer.vue'
|
||||||
|
export { default as TableRenderer } from './renderers/TableRenderer.vue'
|
||||||
|
|||||||
119
ui/framework/src/renderers/TableRenderer.vue
Normal file
119
ui/framework/src/renderers/TableRenderer.vue
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
<script setup lang="ts">
|
||||||
|
import { computed } from 'vue'
|
||||||
|
|
||||||
|
/** Describes one column of the table rendered by TableRenderer. */
export interface TableColumn {
|
||||||
|
// Property name used to read this column's cell from each row; also the value emitted with the 'sort' event.
key: string
|
||||||
|
// Text shown in the column header.
label: string
|
||||||
|
// Optional CSS width applied to the header cell's inline style (e.g. '120px').
width?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Inputs accepted by the table renderer. */
interface TableRendererProps {
  /** Column definitions, rendered left to right in array order. */
  columns: TableColumn[]
  /** Row objects; each cell is read as `row[column.key]`. */
  rows: Record<string, unknown>[]
  /** Key of the column to sort by; when omitted, rows keep their given order. */
  sortKey?: string
  /** Sort direction; anything other than 'desc' is treated as ascending. */
  sortDir?: 'asc' | 'desc'
}

const props = defineProps<TableRendererProps>()
|
||||||
|
|
||||||
|
// 'sort' is emitted with the clicked column's key; the parent owns the sort state.
const emits = defineEmits<{
  (e: 'sort', key: string): void
}>()
|
||||||
|
|
||||||
|
/**
 * Numeric reading of a cell value.
 *
 * Finite numbers and strings that parse to a finite number (e.g. "12.50")
 * yield that number; everything else yields null.
 */
function numericValue(cell: unknown): number | null {
  if (typeof cell === 'number') return Number.isFinite(cell) ? cell : null
  if (typeof cell !== 'string' || cell.trim() === '') return null
  const parsed = Number(cell)
  return Number.isFinite(parsed) ? parsed : null
}

/**
 * Ascending comparator for two cell values.
 *
 * When both cells are numeric (including pre-formatted numeric strings such as
 * "9.5" and "10.0") they are compared as numbers, so "9.5" sorts before "10.0"
 * instead of after it as plain string comparison would do. Otherwise the
 * values are compared with the relational operators.
 */
function compareCells(a: unknown, b: unknown): number {
  const an = numericValue(a)
  const bn = numericValue(b)
  if (an !== null && bn !== null) return an - bn

  const av = a as number | string
  const bv = b as number | string
  if (av < bv) return -1
  if (av > bv) return 1
  return 0
}

/**
 * Rows in display order.
 *
 * Without a sort key the rows prop is returned as-is. Otherwise a sorted copy
 * is returned (the prop array itself is never mutated), ordered by the sort-key
 * column and reversed when the direction is 'desc'.
 */
const sorted = computed(() => {
  const key = props.sortKey
  if (!key) return props.rows

  const dir = props.sortDir === 'desc' ? -1 : 1
  return [...props.rows].sort((a, b) => dir * compareCells(a[key], b[key]))
})
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<template>
|
||||||
|
<div class="table-renderer">
|
||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th
|
||||||
|
v-for="col in columns"
|
||||||
|
:key="col.key"
|
||||||
|
:style="{ width: col.width }"
|
||||||
|
@click="emits('sort', col.key)"
|
||||||
|
class="sortable"
|
||||||
|
>
|
||||||
|
{{ col.label }}
|
||||||
|
<span v-if="sortKey === col.key" class="sort-indicator">
|
||||||
|
{{ sortDir === 'desc' ? '▼' : '▲' }}
|
||||||
|
</span>
|
||||||
|
</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
<tr v-for="(row, i) in sorted" :key="i">
|
||||||
|
<td v-for="col in columns" :key="col.key">
|
||||||
|
{{ row[col.key] }}
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
<tr v-if="rows.length === 0">
|
||||||
|
<td :colspan="columns.length" class="empty">No detections yet</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.table-renderer {
|
||||||
|
overflow: auto;
|
||||||
|
height: 100%;
|
||||||
|
font-family: var(--font-mono);
|
||||||
|
font-size: var(--font-size-sm);
|
||||||
|
}
|
||||||
|
|
||||||
|
table {
|
||||||
|
width: 100%;
|
||||||
|
border-collapse: collapse;
|
||||||
|
}
|
||||||
|
|
||||||
|
th {
|
||||||
|
position: sticky;
|
||||||
|
top: 0;
|
||||||
|
background: var(--surface-2);
|
||||||
|
color: var(--text-secondary);
|
||||||
|
font-weight: 600;
|
||||||
|
text-align: left;
|
||||||
|
padding: var(--space-2) var(--space-3);
|
||||||
|
border-bottom: var(--panel-border);
|
||||||
|
cursor: pointer;
|
||||||
|
user-select: none;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
th:hover {
|
||||||
|
color: var(--text-primary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.sort-indicator {
|
||||||
|
font-size: 9px;
|
||||||
|
margin-left: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
td {
|
||||||
|
padding: var(--space-1) var(--space-3);
|
||||||
|
border-bottom: 1px solid var(--surface-3);
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
tr:hover td {
|
||||||
|
background: var(--surface-3);
|
||||||
|
}
|
||||||
|
|
||||||
|
.empty {
|
||||||
|
color: var(--text-dim);
|
||||||
|
text-align: center;
|
||||||
|
padding: var(--space-6);
|
||||||
|
}
|
||||||
|
</style>
|
||||||
Reference in New Issue
Block a user