This commit is contained in:
2026-03-30 07:22:14 -03:00
parent d0707333fd
commit 4220b0418e
182 changed files with 3668 additions and 5231 deletions

0
core/detect/__init__.py Normal file
View File

View File

@@ -0,0 +1,19 @@
"""
Checkpoint system — Timeline + Checkpoint tree.
detect/checkpoint/
frames.py — frame image S3 upload/download
storage.py — Timeline + Checkpoint (Postgres + MinIO)
replay.py — replay (TODO: migrate to new model)
runner_bridge.py — checkpoint hook for PipelineRunner
"""
from .storage import (
create_timeline,
get_timeline_frames,
get_timeline_frames_b64,
save_stage_output,
load_stage_output,
)
from .frames import save_frames, load_frames
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state, get_timeline_id

View File

@@ -0,0 +1,111 @@
"""Frame image storage — save/load to S3/MinIO as JPEGs."""
from __future__ import annotations
import logging
import os
import tempfile
import numpy as np
from PIL import Image
from core.detect.models import Frame
logger = logging.getLogger(__name__)
BUCKET = os.environ.get("S3_BUCKET", "mpr")
CHECKPOINT_PREFIX = "checkpoints"
def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]:
"""
Save frame images to S3 as JPEGs.
Returns manifest: {sequence: s3_key}
"""
from core.storage.s3 import upload_file
manifest = {}
for frame in frames:
key = f"{CHECKPOINT_PREFIX}/{job_id}/frames/{frame.sequence}.jpg"
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
img = Image.fromarray(frame.image)
img.save(tmp, format="JPEG", quality=85)
tmp_path = tmp.name
try:
upload_file(tmp_path, BUCKET, key)
finally:
os.unlink(tmp_path)
manifest[frame.sequence] = key
logger.info("Saved %d frames to s3://%s/%s/%s/frames/",
len(frames), BUCKET, CHECKPOINT_PREFIX, job_id)
return manifest
def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Frame]:
"""
Load frame images from S3 and reconstitute Frame objects.
frame_metadata: list of dicts with sequence, chunk_id, timestamp, perceptual_hash.
"""
from core.storage.s3 import download_to_temp
meta_map = {m["sequence"]: m for m in frame_metadata}
frames = []
for seq, key in manifest.items():
tmp_path = download_to_temp(BUCKET, key)
try:
img = Image.open(tmp_path).convert("RGB")
image_array = np.array(img)
finally:
os.unlink(tmp_path)
meta = meta_map.get(seq, {})
frame = Frame(
sequence=seq,
chunk_id=meta.get("chunk_id", 0),
timestamp=meta.get("timestamp", 0.0),
image=image_array,
perceptual_hash=meta.get("perceptual_hash", ""),
)
frames.append(frame)
frames.sort(key=lambda f: f.sequence)
return frames
def load_frames_b64(manifest: dict[int, str], frame_metadata: list[dict]) -> list[dict]:
"""
Load frame images from S3 as base64 JPEG — lightweight, no numpy.
Returns list of dicts: {seq, timestamp, jpeg_b64}
"""
import base64
from core.storage.s3 import download_to_temp
meta_map = {m["sequence"]: m for m in frame_metadata}
frames = []
for seq, key in manifest.items():
tmp_path = download_to_temp(BUCKET, key)
try:
with open(tmp_path, "rb") as f:
jpeg_bytes = f.read()
finally:
os.unlink(tmp_path)
meta = meta_map.get(seq, {})
frames.append({
"seq": seq,
"timestamp": meta.get("timestamp", 0.0),
"jpeg_b64": base64.b64encode(jpeg_bytes).decode(),
})
frames.sort(key=lambda f: f["seq"])
return frames

View File

@@ -0,0 +1,232 @@
"""
Pipeline replay — re-run from any stage with different config.
Loads a checkpoint, applies config overrides, builds a subgraph
starting from the target stage, and invokes it.
"""
from __future__ import annotations
import logging
import uuid
from core.detect import emit
# TODO: migrate to Timeline/Branch/Checkpoint model
# These old functions no longer exist — replay needs rework
def _not_migrated(*args, **kwargs):
raise NotImplementedError("Replay not yet migrated to Timeline/Branch/Checkpoint model")
load_checkpoint = _not_migrated
list_checkpoints = _not_migrated
from core.detect.graph import NODES, build_graph
logger = logging.getLogger(__name__)
# OverrideProfile removed — config overrides are now handled by dict merging
# in _load_profile() (nodes.py) and replay_single_stage (below).
def replay_from(
job_id: str,
start_stage: str,
config_overrides: dict | None = None,
checkpoint: bool = True,
) -> dict:
"""
Replay the pipeline from a specific stage.
Loads the checkpoint from the stage immediately before start_stage,
applies config overrides, and runs the subgraph from start_stage onward.
Returns the final state dict.
"""
if start_stage not in NODES:
raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")
start_idx = NODES.index(start_stage)
# Load checkpoint from the stage before start_stage
if start_idx == 0:
raise ValueError("Cannot replay from the first stage — just run the full pipeline")
previous_stage = NODES[start_idx - 1]
available = list_checkpoints(job_id)
if previous_stage not in available:
raise ValueError(
f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
f"Available: {available}"
)
logger.info("Replaying job %s from %s (loading checkpoint: %s)",
job_id, start_stage, previous_stage)
state = load_checkpoint(job_id, previous_stage)
# Apply config overrides
if config_overrides:
state["config_overrides"] = config_overrides
# Set run context for SSE events
run_id = str(uuid.uuid4())[:8]
emit.set_run_context(
run_id=run_id,
parent_job_id=job_id,
run_type="replay",
)
# Build subgraph starting from start_stage
graph = build_graph(checkpoint=checkpoint, start_from=start_stage)
pipeline = graph.compile()
try:
result = pipeline.invoke(state)
finally:
emit.clear_run_context()
return result
def replay_single_stage(
job_id: str,
stage: str,
frame_refs: list[int] | None = None,
config_overrides: dict | None = None,
debug: bool = False,
) -> dict:
"""
Replay a single stage on specific frames (or all frames from checkpoint).
Fast path for interactive parameter tuning — runs only the target stage
function, not the full pipeline tail. Returns the stage output directly.
When debug=True and stage is detect_edges, returns additional overlay
data (Canny edges, Hough lines) for visual feedback in the editor.
For detect_edges: returns {"edge_regions_by_frame": {seq: [box, ...]}}
With debug=True, also returns {"debug": {seq: {edge_overlay_b64, lines_overlay_b64, ...}}}
"""
if stage not in NODES:
raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}")
stage_idx = NODES.index(stage)
if stage_idx == 0:
raise ValueError("Cannot replay the first stage — just run the full pipeline")
previous_stage = NODES[stage_idx - 1]
available = list_checkpoints(job_id)
if previous_stage not in available:
raise ValueError(
f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
f"Available: {available}"
)
logger.info("Single-stage replay: job %s, stage %s (loading checkpoint: %s, debug=%s)",
job_id, stage, previous_stage, debug)
state = load_checkpoint(job_id, previous_stage)
# Build profile with overrides
from core.detect.profile import get_profile, get_stage_config
profile = get_profile(state.get("profile_name", "soccer_broadcast"))
if config_overrides:
merged_configs = dict(profile.get("configs", {}))
for sname, soverrides in config_overrides.items():
if sname in merged_configs:
merged_configs[sname] = {**merged_configs[sname], **soverrides}
else:
merged_configs[sname] = soverrides
profile = {**profile, "configs": merged_configs}
# Run the stage function directly (not through the graph)
if stage == "detect_edges":
return _replay_detect_edges(state, profile, frame_refs, job_id, debug)
else:
raise ValueError(
f"Single-stage replay not yet implemented for {stage!r}. "
f"Use replay_from() for full pipeline replay."
)
def _replay_detect_edges(
state: dict,
profile,
frame_refs: list[int] | None,
job_id: str,
debug: bool,
) -> dict:
"""Run edge detection on checkpoint frames, optionally with debug overlays."""
import os
from core.detect.stages.edge_detector import detect_edge_regions
from core.detect.profile import get_stage_config
from core.detect.stages.models import RegionAnalysisConfig
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
frames = state.get("filtered_frames", [])
if frame_refs:
ref_set = set(frame_refs)
frames = [f for f in frames if f.sequence in ref_set]
inference_url = os.environ.get("INFERENCE_URL")
# Normal run — always needed for the boxes
result = detect_edge_regions(
frames=frames,
config=config,
inference_url=inference_url,
job_id=job_id,
)
output = {"edge_regions_by_frame": result}
# Debug overlays — call debug endpoint (remote) or local debug function
if debug and frames:
debug_data = {}
if inference_url:
from core.detect.inference import InferenceClient
client = InferenceClient(base_url=inference_url, job_id=job_id)
for frame in frames:
dr = client.detect_edges_debug(
image=frame.image,
edge_canny_low=config.edge_canny_low,
edge_canny_high=config.edge_canny_high,
edge_hough_threshold=config.edge_hough_threshold,
edge_hough_min_length=config.edge_hough_min_length,
edge_hough_max_gap=config.edge_hough_max_gap,
edge_pair_max_distance=config.edge_pair_max_distance,
edge_pair_min_distance=config.edge_pair_min_distance,
)
debug_data[frame.sequence] = {
"edge_overlay_b64": dr.edge_overlay_b64,
"lines_overlay_b64": dr.lines_overlay_b64,
"horizontal_count": dr.horizontal_count,
"pair_count": dr.pair_count,
}
else:
# Local mode — import GPU module directly
from core.detect.stages.edge_detector import _load_cv_edges
edges_mod = _load_cv_edges()
for frame in frames:
dr = edges_mod.detect_edges_debug(
frame.image,
canny_low=config.edge_canny_low,
canny_high=config.edge_canny_high,
hough_threshold=config.edge_hough_threshold,
hough_min_length=config.edge_hough_min_length,
hough_max_gap=config.edge_hough_max_gap,
pair_max_distance=config.edge_pair_max_distance,
pair_min_distance=config.edge_pair_min_distance,
)
debug_data[frame.sequence] = {
"edge_overlay_b64": dr["edge_overlay_b64"],
"lines_overlay_b64": dr["lines_overlay_b64"],
"horizontal_count": dr["horizontal_count"],
"pair_count": dr["pair_count"],
}
output["debug"] = debug_data
return output

View File

@@ -0,0 +1,97 @@
"""
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
Owns the per-job state (timeline, frame manifest, checkpoint chain) that
the runner shouldn't know about.
Timeline and Job are independent entities:
- One Timeline can serve multiple Jobs (re-run with different params)
- One Job operates on one Timeline (set after frame extraction)
- Checkpoints belong to Timeline, tagged with the Job that created them
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
# Per-job state
_timeline_id: dict[str, str] = {}
_frames_manifest: dict[str, dict[int, str]] = {}
_latest_checkpoint: dict[str, str] = {}
def reset_checkpoint_state(job_id: str):
"""Clean up per-job checkpoint state. Called when pipeline finishes."""
_timeline_id.pop(job_id, None)
_frames_manifest.pop(job_id, None)
_latest_checkpoint.pop(job_id, None)
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
"""
Save a checkpoint after a stage completes.
Called by the runner. Handles:
- Timeline creation (once, on extract_frames)
- Frame upload (via create_timeline)
- Stage output serialization (via stage registry)
- Checkpoint chain (parent → child)
"""
if not job_id:
return
from .storage import create_timeline, save_stage_output
from core.detect.stages.base import _REGISTRY
merged = {**state, **result}
# On extract_frames: create Timeline + upload frames + root checkpoint
if stage_name == "extract_frames" and job_id not in _timeline_id:
frames = merged.get("frames", [])
video_path = merged.get("video_path", "")
profile_name = merged.get("profile_name", "")
tid, cid = create_timeline(
source_video=video_path,
profile_name=profile_name,
frames=frames,
)
_timeline_id[job_id] = tid
_latest_checkpoint[job_id] = cid
logger.info("Job %s → Timeline %s (root checkpoint %s)", job_id, tid, cid)
# Emit timeline_id via SSE so the UI can use it for checkpoint loads
from core.detect import emit
emit.log(job_id, "Checkpoint", "INFO", f"timeline_id={tid}")
return
# For subsequent stages: save checkpoint on the timeline
tid = _timeline_id.get(job_id)
if not tid:
logger.warning("No timeline for job %s, skipping checkpoint", job_id)
return
# Serialize stage output using the stage's serialize_fn if available
stage_cls = _REGISTRY.get(stage_name)
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
if serialize_fn:
output_json = serialize_fn(merged, job_id)
else:
output_json = {}
parent_id = _latest_checkpoint.get(job_id)
new_checkpoint_id = save_stage_output(
timeline_id=tid,
parent_checkpoint_id=parent_id,
stage_name=stage_name,
output_json=output_json,
job_id=job_id,
)
_latest_checkpoint[job_id] = new_checkpoint_id
def get_timeline_id(job_id: str) -> str | None:
"""Get the timeline_id for a running job. Used by the UI to load checkpoints."""
return _timeline_id.get(job_id)

View File

@@ -0,0 +1,108 @@
"""
State serialization — DetectState ↔ JSON-compatible dict.
Delegates to each stage's serialize_fn/deserialize_fn via the registry.
This file has no model-specific knowledge — stages own their data format.
The only things serialized here are the "envelope" fields (job_id, video_path, etc.)
that don't belong to any stage.
"""
from __future__ import annotations
from core.schema.serializers._common import serialize_dataclass
from core.schema.serializers.pipeline import (
deserialize_pipeline_stats,
deserialize_text_candidates,
)
# Envelope fields — not owned by any stage, always present
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "config_overrides"]
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
"""
Serialize DetectState to a JSON-compatible dict.
Calls each registered stage's serialize_fn for stage-owned data.
Envelope fields (job_id, etc.) are copied directly.
"""
from core.detect.stages.base import _REGISTRY
checkpoint = {}
# Envelope
for key in ENVELOPE_KEYS:
default = {} if key == "config_overrides" else ""
checkpoint[key] = state.get(key, default)
# Frames manifest (needed by frame-loading stages)
checkpoint["frames_manifest"] = {str(k): v for k, v in frames_manifest.items()}
# Stats (shared across stages, not owned by one)
stats = state.get("stats")
if stats is not None:
checkpoint["stats"] = serialize_dataclass(stats)
else:
checkpoint["stats"] = {}
# Per-stage data
for name, stage_def in _REGISTRY.items():
if stage_def.serialize_fn is None:
continue
job_id = state.get("job_id", "")
stage_data = stage_def.serialize_fn(state, job_id)
checkpoint[f"stage_{name}"] = stage_data
return checkpoint
def deserialize_state(checkpoint: dict, frames: list) -> dict:
"""
Reconstitute DetectState from a checkpoint dict + loaded frames.
Calls each stage's deserialize_fn to restore stage-owned data.
"""
from core.detect.stages.base import _REGISTRY
frame_map = {f.sequence: f for f in frames}
state = {}
# Envelope
for key in ENVELOPE_KEYS:
default = {} if key == "config_overrides" else ""
state[key] = checkpoint.get(key, default)
# Frames (always present, loaded externally)
state["frames"] = frames
# Stats
state["stats"] = deserialize_pipeline_stats(checkpoint.get("stats", {}))
# Per-stage data
for name, stage_def in _REGISTRY.items():
if stage_def.deserialize_fn is None:
continue
stage_key = f"stage_{name}"
if stage_key not in checkpoint:
continue
job_id = state.get("job_id", "")
stage_data = stage_def.deserialize_fn(checkpoint[stage_key], job_id)
for k, v in stage_data.items():
if k == "_filtered_sequences":
# Reconnect filtered frames from sequence list
seq_set = set(v)
state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
elif k.endswith("_raw"):
# Raw text candidates need frame reference reconnection
real_key = k.removeprefix("_").removesuffix("_raw")
state[real_key] = deserialize_text_candidates(v, frame_map)
else:
state[k] = v
return state

View File

@@ -0,0 +1,177 @@
"""
Checkpoint storage — Timeline + Checkpoint (tree of snapshots).
Timeline: frame sequence from source video (frames in MinIO)
Checkpoint: snapshot of pipeline state (stage outputs as JSONB in Postgres)
parent_id forms a tree — multiple children = different config tries
"""
from __future__ import annotations
import logging
from uuid import UUID
from .frames import save_frames, load_frames, CHECKPOINT_PREFIX
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Timeline
# ---------------------------------------------------------------------------
def create_timeline(
source_video: str,
profile_name: str,
frames: list,
fps: float = 2.0,
source_asset_id: UUID | None = None,
) -> tuple[str, str]:
"""
Create a timeline from frames. Uploads frame images to MinIO,
creates Timeline + root Checkpoint in Postgres.
Returns (timeline_id, checkpoint_id).
"""
from core.db.models import Timeline, Checkpoint
from core.db.connection import get_session
with get_session() as session:
timeline = Timeline(
source_video=source_video,
profile_name=profile_name,
source_asset_id=source_asset_id,
fps=fps,
)
session.add(timeline)
session.flush()
tid = str(timeline.id)
# Upload frames to MinIO
manifest = save_frames(tid, frames)
frames_meta = [
{
"sequence": f.sequence,
"chunk_id": getattr(f, "chunk_id", 0),
"timestamp": f.timestamp,
"perceptual_hash": getattr(f, "perceptual_hash", ""),
}
for f in frames
]
timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
timeline.frames_meta = frames_meta
checkpoint = Checkpoint(
timeline_id=timeline.id,
parent_id=None,
stage_outputs={},
stats={"frames_extracted": len(frames)},
)
session.add(checkpoint)
session.commit()
session.refresh(checkpoint)
cid = str(checkpoint.id)
logger.info("Timeline created: %s (%d frames, root checkpoint %s)", tid, len(frames), cid)
return tid, cid
def get_timeline_frames(timeline_id: str) -> list:
"""Load frames from a timeline (from MinIO) as Frame objects."""
from core.db.models import Timeline
from core.db.connection import get_session
with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
if not timeline:
raise ValueError(f"Timeline not found: {timeline_id}")
raw_manifest = timeline.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()}
return load_frames(manifest, timeline.frames_meta or [])
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
"""Load frames as base64 JPEG (lightweight, no numpy)."""
from core.db.models import Timeline
from core.db.connection import get_session
from .frames import load_frames_b64
with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
if not timeline:
raise ValueError(f"Timeline not found: {timeline_id}")
raw_manifest = timeline.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()}
return load_frames_b64(manifest, timeline.frames_meta or [])
# ---------------------------------------------------------------------------
# Checkpoint
# ---------------------------------------------------------------------------
def save_stage_output(
timeline_id: str,
parent_checkpoint_id: str | None,
stage_name: str,
output_json: dict,
config_overrides: dict | None = None,
stats: dict | None = None,
is_scenario: bool = False,
scenario_label: str = "",
job_id: str | None = None,
) -> str:
"""
Save a stage's output as a new checkpoint (child of parent).
Carries forward stage outputs from parent + adds the new one.
Returns the new checkpoint ID.
"""
from core.db.models import Checkpoint
from core.db.connection import get_session
with get_session() as session:
parent_outputs = {}
parent_stats = {}
parent_config = {}
if parent_checkpoint_id:
parent = session.get(Checkpoint, UUID(parent_checkpoint_id))
if parent:
parent_outputs = dict(parent.stage_outputs or {})
parent_stats = dict(parent.stats or {})
parent_config = dict(parent.config_overrides or {})
checkpoint = Checkpoint(
timeline_id=UUID(timeline_id),
job_id=UUID(job_id) if job_id else None,
parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
stage_outputs={**parent_outputs, stage_name: output_json},
config_overrides={**parent_config, **(config_overrides or {})},
stats={**parent_stats, **(stats or {})},
is_scenario=is_scenario,
scenario_label=scenario_label,
)
session.add(checkpoint)
session.commit()
session.refresh(checkpoint)
cid = str(checkpoint.id)
logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
cid, timeline_id, stage_name, parent_checkpoint_id)
return cid
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
"""Load a stage's output from a checkpoint."""
from core.db.models import Checkpoint
from core.db.connection import get_session
with get_session() as session:
checkpoint = session.get(Checkpoint, UUID(checkpoint_id))
if not checkpoint:
return None
return (checkpoint.stage_outputs or {}).get(stage_name)

159
core/detect/emit.py Normal file
View File

@@ -0,0 +1,159 @@
"""
Event emission helpers for detection pipeline stages.
Single place that knows how to build event payloads.
Stages call these instead of constructing dicts or dataclasses directly.
Run context (run_id, parent_job_id) is set once at pipeline start via
set_run_context() and automatically injected into all events.
Log level is set per-run with optional per-stage overrides.
DEBUG events are only pushed when the run (or stage) log level allows it.
"""
from __future__ import annotations
import dataclasses
from datetime import datetime, timezone
from core.detect.events import push_detect_event
from core.detect.models import PipelineStats
# Log level ordering for comparison
_LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARN": 2, "ERROR": 3}
# Module-level run context — set once per pipeline invocation
_run_context: dict = {}
_run_log_level: str = "INFO"
_stage_log_levels: dict[str, str] = {} # stage_name → level override
def set_run_context(
run_id: str = "",
parent_job_id: str = "",
run_type: str = "initial",
log_level: str = "INFO",
):
"""Set the run context for all subsequent events in this pipeline invocation."""
global _run_context, _run_log_level
_run_context = {
"run_id": run_id,
"parent_job_id": parent_job_id,
"run_type": run_type,
}
_run_log_level = log_level.upper()
_stage_log_levels.clear()
def set_stage_log_level(stage: str, level: str):
"""Override log level for a specific stage."""
_stage_log_levels[stage] = level.upper()
def clear_stage_log_level(stage: str):
"""Remove per-stage log level override."""
_stage_log_levels.pop(stage, None)
def clear_run_context():
global _run_context, _run_log_level
_run_context = {}
_run_log_level = "INFO"
_stage_log_levels.clear()
def _should_emit(level: str, stage: str) -> bool:
"""Check if this log level should be emitted given run/stage settings."""
effective = _stage_log_levels.get(stage, _run_log_level)
return _LEVEL_ORDER.get(level.upper(), 1) >= _LEVEL_ORDER.get(effective, 1)
def _inject_context(payload: dict) -> dict:
"""Add run context fields to an event payload."""
if _run_context:
payload.update(_run_context)
return payload
def log(job_id: str | None, stage: str, level: str, msg: str) -> None:
if not job_id:
return
if not _should_emit(level, stage):
return
payload = {
"level": level,
"stage": stage,
"msg": msg,
"ts": datetime.now(timezone.utc).isoformat(),
}
_inject_context(payload)
push_detect_event(job_id, "log", payload)
def stats(job_id: str | None, **kwargs) -> None:
if not job_id:
return
s = PipelineStats(**kwargs)
payload = dataclasses.asdict(s)
_inject_context(payload)
push_detect_event(job_id, "stats_update", payload)
def frame_update(
job_id: str | None,
frame_ref: int,
timestamp: float,
jpeg_b64: str,
boxes: list[dict],
) -> None:
if not job_id:
return
payload = {
"frame_ref": frame_ref,
"timestamp": timestamp,
"jpeg_b64": jpeg_b64,
"boxes": boxes,
}
_inject_context(payload)
push_detect_event(job_id, "frame_update", payload)
def graph_update(job_id: str | None, nodes: list[dict]) -> None:
if not job_id:
return
payload = {"nodes": nodes}
_inject_context(payload)
push_detect_event(job_id, "graph_update", payload)
def detection(
job_id: str | None,
brand: str,
confidence: float,
source: str,
timestamp: float,
duration: float = 0.0,
content_type: str = "",
frame_ref: int | None = None,
) -> None:
if not job_id:
return
payload = {
"brand": brand,
"confidence": confidence,
"source": source,
"timestamp": timestamp,
"duration": duration,
"content_type": content_type,
"frame_ref": frame_ref,
}
_inject_context(payload)
push_detect_event(job_id, "detection", payload)
def job_complete(job_id: str | None, report: dict) -> None:
if not job_id:
return
payload = {"job_id": job_id, "report": report}
_inject_context(payload)
push_detect_event(job_id, "job_complete", payload)

42
core/detect/events.py Normal file
View File

@@ -0,0 +1,42 @@
"""
Detection pipeline event helpers.
Non-generated runtime code for pushing SSE events.
The event payload types are in sse_contract.py (generated by modelgen).
"""
from pydantic import BaseModel
from core.events import push_event
DETECT_EVENTS_PREFIX = "detect_events"
# SSE event type names
EVENT_GRAPH_UPDATE = "graph_update"
EVENT_STATS_UPDATE = "stats_update"
EVENT_FRAME_UPDATE = "frame_update"
EVENT_DETECTION = "detection"
EVENT_LOG = "log"
EVENT_JOB_COMPLETE = "job_complete"
ALL_EVENT_TYPES = [
EVENT_GRAPH_UPDATE,
EVENT_STATS_UPDATE,
EVENT_FRAME_UPDATE,
EVENT_DETECTION,
EVENT_LOG,
EVENT_JOB_COMPLETE,
]
TERMINAL_EVENTS = [EVENT_JOB_COMPLETE]
def push_detect_event(job_id: str, event_type: str, data: BaseModel | dict) -> None:
"""Push a detection event to Redis. Accepts Pydantic models or plain dicts."""
payload = data.model_dump(mode="json") if isinstance(data, BaseModel) else data
push_event(
job_id=job_id,
event_type=event_type,
data=payload,
prefix=DETECT_EVENTS_PREFIX,
)

View File

@@ -0,0 +1,45 @@
"""
Detection pipeline graph.
detect/graph/
nodes.py — node functions (one per stage)
events.py — graph_update SSE emission
runner.py — PipelineRunner (config-driven, checkpoint, cancel, pause)
"""
from .nodes import NODES, NODE_FUNCTIONS
from .runner import (
PipelineCancelled,
PipelineRunner,
build_graph,
clear_cancel_check,
clear_pause,
get_pipeline,
init_pause,
is_paused,
pause_pipeline,
resume_pipeline,
set_cancel_check,
set_pause_after_stage,
step_pipeline,
)
from .events import _node_states
__all__ = [
"NODES",
"NODE_FUNCTIONS",
"PipelineCancelled",
"PipelineRunner",
"build_graph",
"get_pipeline",
"set_cancel_check",
"clear_cancel_check",
"init_pause",
"clear_pause",
"pause_pipeline",
"resume_pipeline",
"step_pipeline",
"set_pause_after_stage",
"is_paused",
"_node_states",
]

View File

@@ -0,0 +1,27 @@
"""
Graph event emission — node state tracking + SSE graph_update events.
"""
from __future__ import annotations
from core.detect import emit
from core.detect.state import DetectState
# Track node states across pipeline runs
_node_states: dict[str, dict[str, str]] = {}
def emit_transition(state: DetectState, node: str, status: str, node_list: list[str]):
"""Update node status and emit graph_update SSE event."""
job_id = state.get("job_id")
if not job_id:
return
if job_id not in _node_states:
_node_states[job_id] = {n: "pending" for n in node_list}
_node_states[job_id][node] = status
nodes = [{"id": n, "status": _node_states[job_id][n]} for n in node_list]
emit.graph_update(job_id, nodes)

366
core/detect/graph/nodes.py Normal file
View File

@@ -0,0 +1,366 @@
"""
Pipeline node functions — one per stage.
Each node: reads state, gets config from profile dict, runs stage logic,
emits transitions, returns output dict.
"""
from __future__ import annotations
import os
from core.detect import emit
from core.detect.models import CropContext, PipelineStats
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections
from core.detect.stages.models import (
DetectionConfig,
FieldSegmentationConfig,
FrameExtractionConfig,
OCRConfig,
RegionAnalysisConfig,
ResolverConfig,
SceneFilterConfig,
)
from core.detect.state import DetectState
from core.detect.stages.frame_extractor import extract_frames
from core.detect.stages.scene_filter import scene_filter
from core.detect.stages.field_segmentation import run_field_segmentation
from core.detect.stages.edge_detector import detect_edge_regions
from core.detect.stages.yolo_detector import detect_objects
from core.detect.stages.preprocess import preprocess_regions
from core.detect.stages.ocr_stage import run_ocr
from core.detect.stages.brand_resolver import resolve_brands
from core.detect.stages.vlm_local import escalate_vlm
from core.detect.stages.vlm_cloud import escalate_cloud
from core.detect.stages.aggregator import compile_report
from core.detect.tracing import trace_node, flush as flush_traces
from .events import emit_transition
INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
NODES = [
"extract_frames",
"filter_scenes",
"field_segmentation",
"detect_edges",
"detect_objects",
"preprocess",
"run_ocr",
"match_brands",
"escalate_vlm",
"escalate_cloud",
"compile_report",
]
def _load_profile(state: DetectState) -> dict:
"""Load profile dict, apply config overrides if present."""
name = state.get("profile_name", "soccer_broadcast")
profile = get_profile(name)
overrides = state.get("config_overrides")
if overrides:
# Merge overrides into a copy of the profile configs
merged_configs = dict(profile.get("configs", {}))
for stage_name, stage_overrides in overrides.items():
if stage_name in merged_configs:
merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides}
else:
merged_configs[stage_name] = stage_overrides
profile = {**profile, "configs": merged_configs}
return profile
def _emit(state, node, status):
emit_transition(state, node, status, NODES)
# --- Node functions ---
def node_extract_frames(state: DetectState) -> dict:
job_id = state.get("job_id", "")
if job_id and not emit._run_context:
emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")
source_asset_id = state.get("source_asset_id")
if source_asset_id and not state.get("session_brands"):
from core.detect.stages.brand_resolver import build_session_dict
session_brands = build_session_dict(source_asset_id)
state["session_brands"] = session_brands
_emit(state, "extract_frames", "running")
with trace_node(state, "extract_frames") as span:
profile = _load_profile(state)
config = FrameExtractionConfig(**get_stage_config(profile, "extract_frames"))
frames = extract_frames(state["video_path"], config, job_id=job_id)
span.set_output({"frames_extracted": len(frames)})
_emit(state, "extract_frames", "done")
return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))}
def node_filter_scenes(state: DetectState) -> dict:
_emit(state, "filter_scenes", "running")
with trace_node(state, "filter_scenes") as span:
profile = _load_profile(state)
config = SceneFilterConfig(**get_stage_config(profile, "filter_scenes"))
frames = state.get("frames", [])
kept = scene_filter(frames, config, job_id=state.get("job_id"))
span.set_output({"frames_in": len(frames), "frames_kept": len(kept)})
stats = state.get("stats", PipelineStats())
stats.frames_after_scene_filter = len(kept)
_emit(state, "filter_scenes", "done")
return {"filtered_frames": kept, "stats": stats}
def node_field_segmentation(state: DetectState) -> dict:
_emit(state, "field_segmentation", "running")
with trace_node(state, "field_segmentation") as span:
profile = _load_profile(state)
config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))
frames = state.get("filtered_frames", [])
job_id = state.get("job_id")
result = run_field_segmentation(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
span.set_output({
"frames": len(frames),
"avg_coverage": sum(result["field_coverage"].values()) / max(len(result["field_coverage"]), 1),
})
_emit(state, "field_segmentation", "done")
return {
"field_masks": result["field_masks"],
"field_boundaries": result["field_boundaries"],
"field_coverage": result["field_coverage"],
}
def node_detect_edges(state: DetectState) -> dict:
_emit(state, "detect_edges", "running")
with trace_node(state, "detect_edges") as span:
profile = _load_profile(state)
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
frames = state.get("filtered_frames", [])
field_masks = state.get("field_masks", {})
job_id = state.get("job_id")
regions = detect_edge_regions(
frames, config, inference_url=INFERENCE_URL, job_id=job_id,
field_masks=field_masks,
)
total = sum(len(r) for r in regions.values())
span.set_output({"frames": len(frames), "edge_regions": total})
stats = state.get("stats", PipelineStats())
stats.cv_regions_detected = total
_emit(state, "detect_edges", "done")
return {"edge_regions_by_frame": regions, "stats": stats}
def node_detect_objects(state: DetectState) -> dict:
_emit(state, "detect_objects", "running")
with trace_node(state, "detect_objects") as span:
profile = _load_profile(state)
config = DetectionConfig(**get_stage_config(profile, "detect_objects"))
frames = state.get("filtered_frames", [])
job_id = state.get("job_id")
all_boxes = detect_objects(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
total_regions = sum(len(boxes) for boxes in all_boxes.values())
span.set_output({"frames": len(frames), "regions_detected": total_regions})
stats = state.get("stats", PipelineStats())
stats.regions_detected = total_regions
_emit(state, "detect_objects", "done")
return {"boxes_by_frame": all_boxes, "stats": stats}
def node_preprocess(state: DetectState) -> dict:
_emit(state, "preprocess", "running")
with trace_node(state, "preprocess") as span:
profile = _load_profile(state)
prep_config = get_stage_config(profile, "preprocess")
frames = state.get("filtered_frames", [])
boxes = state.get("boxes_by_frame", {})
job_id = state.get("job_id")
do_contrast = prep_config.get("contrast", True)
do_deskew = prep_config.get("deskew", False)
do_binarize = prep_config.get("binarize", False)
preprocessed = preprocess_regions(
frames, boxes,
do_contrast=do_contrast,
do_deskew=do_deskew,
do_binarize=do_binarize,
inference_url=INFERENCE_URL,
job_id=job_id,
)
span.set_output({"regions_preprocessed": len(preprocessed)})
_emit(state, "preprocess", "done")
return {"preprocessed_crops": preprocessed}
def node_run_ocr(state: DetectState) -> dict:
_emit(state, "run_ocr", "running")
with trace_node(state, "run_ocr") as span:
profile = _load_profile(state)
config = OCRConfig(**get_stage_config(profile, "run_ocr"))
frames = state.get("filtered_frames", [])
boxes = state.get("boxes_by_frame", {})
job_id = state.get("job_id")
candidates = run_ocr(frames, boxes, config, inference_url=INFERENCE_URL, job_id=job_id)
span.set_output({"regions_in": sum(len(b) for b in boxes.values()), "text_candidates": len(candidates)})
stats = state.get("stats", PipelineStats())
stats.regions_resolved_by_ocr = len(candidates)
_emit(state, "run_ocr", "done")
return {"text_candidates": candidates, "stats": stats}
def node_match_brands(state: DetectState) -> dict:
_emit(state, "match_brands", "running")
with trace_node(state, "match_brands") as span:
profile = _load_profile(state)
config = ResolverConfig(**get_stage_config(profile, "match_brands"))
candidates = state.get("text_candidates", [])
session_brands = state.get("session_brands", {})
job_id = state.get("job_id")
source_asset_id = state.get("source_asset_id")
matched, unresolved = resolve_brands(
candidates, config,
session_brands=session_brands,
source_asset_id=source_asset_id,
content_type=profile["name"], job_id=job_id,
)
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
_emit(state, "match_brands", "done")
return {"detections": matched, "unresolved_candidates": unresolved}
def node_escalate_vlm(state: DetectState) -> dict:
_emit(state, "escalate_vlm", "running")
with trace_node(state, "escalate_vlm") as span:
profile = _load_profile(state)
vlm_config = get_stage_config(profile, "escalate_vlm")
vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
candidates = state.get("unresolved_candidates", [])
job_id = state.get("job_id")
vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template)
vlm_matched, still_unresolved = escalate_vlm(
candidates,
vlm_prompt_fn=vlm_prompt_fn,
inference_url=INFERENCE_URL,
content_type=profile["name"],
source_asset_id=state.get("source_asset_id"),
job_id=job_id,
)
stats = state.get("stats", PipelineStats())
stats.regions_escalated_to_local_vlm = len(candidates)
span.set_output({"candidates": len(candidates), "matched": len(vlm_matched),
"still_unresolved": len(still_unresolved)})
existing = state.get("detections", [])
vlm_skipped = os.environ.get("SKIP_VLM", "").strip() == "1"
_emit(state, "escalate_vlm", "skipped" if vlm_skipped else "done")
return {
"detections": existing + vlm_matched,
"unresolved_candidates": still_unresolved,
"stats": stats,
}
def node_escalate_cloud(state: DetectState) -> dict:
_emit(state, "escalate_cloud", "running")
with trace_node(state, "escalate_cloud") as span:
profile = _load_profile(state)
vlm_config = get_stage_config(profile, "escalate_vlm")
vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
candidates = state.get("unresolved_candidates", [])
job_id = state.get("job_id")
stats = state.get("stats", PipelineStats())
vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template)
cloud_matched = escalate_cloud(
candidates,
vlm_prompt_fn=vlm_prompt_fn,
stats=stats,
content_type=profile["name"],
source_asset_id=state.get("source_asset_id"),
job_id=job_id,
)
span.set_output({"candidates": len(candidates), "matched": len(cloud_matched),
"cloud_calls": stats.cloud_llm_calls,
"cost_usd": stats.estimated_cloud_cost_usd})
existing = state.get("detections", [])
cloud_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1"
_emit(state, "escalate_cloud", "skipped" if cloud_skipped else "done")
return {"detections": existing + cloud_matched, "stats": stats}
def node_compile_report(state: DetectState) -> dict:
_emit(state, "compile_report", "running")
with trace_node(state, "compile_report") as span:
profile = _load_profile(state)
detections = state.get("detections", [])
stats = state.get("stats", PipelineStats())
job_id = state.get("job_id")
report = compile_report(
detections=detections,
stats=stats,
video_source=state.get("video_path", ""),
content_type=profile["name"],
job_id=job_id,
)
span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})
flush_traces()
_emit(state, "compile_report", "done")
return {"report": report}
NODE_FUNCTIONS = [
("extract_frames", node_extract_frames),
("filter_scenes", node_filter_scenes),
("field_segmentation", node_field_segmentation),
("detect_edges", node_detect_edges),
("detect_objects", node_detect_objects),
("preprocess", node_preprocess),
("run_ocr", node_run_ocr),
("match_brands", node_match_brands),
("escalate_vlm", node_escalate_vlm),
("escalate_cloud", node_escalate_cloud),
("compile_report", node_compile_report),
]

274
core/detect/graph/runner.py Normal file
View File

@@ -0,0 +1,274 @@
"""
Pipeline runner — executes stages sequentially with checkpointing,
cancellation, and pause/resume.
Reads PipelineConfig from the profile to determine what stages to run.
Flattens the graph into a linear sequence for now (serial execution).
Executor socket: all stages run via LocalExecutor (call function directly).
"""
from __future__ import annotations
import logging
import os
import threading
from core.detect.stages.models import PipelineConfig
from core.detect.state import DetectState
from .nodes import NODES, NODE_FUNCTIONS
logger = logging.getLogger(__name__)
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
class PipelineCancelled(Exception):
"""Raised when a pipeline run is cancelled."""
pass
class PipelinePaused(Exception):
"""Raised when a pipeline is paused (internally, for flow control)."""
pass
# ---------------------------------------------------------------------------
# Cancellation — checked before each node
# ---------------------------------------------------------------------------
_cancel_check: dict[str, callable] = {}
def set_cancel_check(job_id: str, fn):
_cancel_check[job_id] = fn
def clear_cancel_check(job_id: str):
_cancel_check.pop(job_id, None)
# ---------------------------------------------------------------------------
# Pause / Resume / Step — checked after each node completes
# ---------------------------------------------------------------------------
_pause_gate: dict[str, threading.Event] = {}
_pause_after_stage: dict[str, bool] = {}
def init_pause(job_id: str, pause_after_stage: bool = False):
"""Initialize pause state for a job. Called when pipeline starts."""
gate = threading.Event()
gate.set() # start unpaused
_pause_gate[job_id] = gate
_pause_after_stage[job_id] = pause_after_stage
def clear_pause(job_id: str):
"""Clean up pause state. Called when pipeline finishes."""
_pause_gate.pop(job_id, None)
_pause_after_stage.pop(job_id, None)
def pause_pipeline(job_id: str):
"""Pause a running pipeline. It will block after the current stage completes."""
gate = _pause_gate.get(job_id)
if gate:
gate.clear()
logger.info("Pipeline %s paused", job_id)
def resume_pipeline(job_id: str):
"""Resume a paused pipeline."""
gate = _pause_gate.get(job_id)
if gate:
gate.set()
logger.info("Pipeline %s resumed", job_id)
def step_pipeline(job_id: str):
"""Run one stage then pause again."""
_pause_after_stage[job_id] = True
gate = _pause_gate.get(job_id)
if gate:
gate.set()
logger.info("Pipeline %s stepping", job_id)
def set_pause_after_stage(job_id: str, enabled: bool):
"""Toggle pause-after-each-stage mode."""
_pause_after_stage[job_id] = enabled
if not enabled:
gate = _pause_gate.get(job_id)
if gate:
gate.set()
def is_paused(job_id: str) -> bool:
"""Check if a pipeline is currently paused."""
gate = _pause_gate.get(job_id)
return gate is not None and not gate.is_set()
def _wait_if_paused(job_id: str, node_name: str):
"""Block until resumed. Called after each node completes."""
gate = _pause_gate.get(job_id)
if gate is None:
return
if _pause_after_stage.get(job_id, False):
gate.clear()
from core.detect import emit
emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")
while not gate.wait(timeout=0.5):
check = _cancel_check.get(job_id)
if check and check():
raise PipelineCancelled(f"Cancelled while paused before next stage")
# ---------------------------------------------------------------------------
# Pipeline Runner
# ---------------------------------------------------------------------------
# Node function lookup — maps stage name to callable
_NODE_FN_MAP: dict[str, callable] = {name: fn for name, fn in NODE_FUNCTIONS}
def _flatten_config(config: PipelineConfig, start_from: str | None = None) -> list[str]:
"""
Flatten a PipelineConfig into a linear stage sequence.
For now: topological sort via edges. Falls back to stage order if no edges.
Respects start_from for replay (skip stages before it).
"""
if not config.edges:
# No edges defined — use stage order as-is
names = [s.name for s in config.stages]
else:
# Topological sort from edges
graph: dict[str, list[str]] = {}
in_degree: dict[str, int] = {}
stage_names = {s.name for s in config.stages}
for name in stage_names:
graph[name] = []
in_degree[name] = 0
for edge in config.edges:
if edge.source in stage_names and edge.target in stage_names:
graph[edge.source].append(edge.target)
in_degree[edge.target] = in_degree.get(edge.target, 0) + 1
# Kahn's algorithm
queue = [n for n in stage_names if in_degree.get(n, 0) == 0]
# Stable sort: prefer order from config.stages
stage_order = {s.name: i for i, s in enumerate(config.stages)}
queue.sort(key=lambda n: stage_order.get(n, 999))
names = []
while queue:
node = queue.pop(0)
names.append(node)
for neighbor in graph.get(node, []):
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
queue.append(neighbor)
queue.sort(key=lambda n: stage_order.get(n, 999))
if start_from:
try:
idx = names.index(start_from)
names = names[idx:]
except ValueError:
raise ValueError(f"Stage {start_from!r} not in pipeline config")
return names
class PipelineRunner:
"""
Executes a pipeline defined by PipelineConfig.
Runs stages sequentially (flattened). Each stage:
1. Check cancel
2. Run node function (via executor — local for now)
3. Merge result into state
4. Checkpoint (if enabled)
5. Check pause
Executor socket: currently calls node functions directly.
Future: dispatch to LocalExecutor / GrpcExecutor / LambdaExecutor
based on StageRef.execution_target.
"""
def __init__(
self,
config: PipelineConfig,
checkpoint: bool = False,
start_from: str | None = None,
):
self.config = config
self.do_checkpoint = checkpoint
self.stage_sequence = _flatten_config(config, start_from)
def invoke(self, state: DetectState) -> DetectState:
"""Run the pipeline on the given state. Returns final state."""
for stage_name in self.stage_sequence:
job_id = state.get("job_id", "")
# 1. Cancel check
check = _cancel_check.get(job_id)
if check and check():
raise PipelineCancelled(f"Cancelled before {stage_name}")
# 2. Run node function
node_fn = _NODE_FN_MAP.get(stage_name)
if node_fn is None:
logger.warning("No node function for stage %s, skipping", stage_name)
continue
result = node_fn(state)
# 3. Merge result into state
state.update(result)
# 4. Checkpoint
if self.do_checkpoint:
from core.detect.checkpoint import checkpoint_after_stage
checkpoint_after_stage(job_id, stage_name, state, result)
# 5. Pause check
_wait_if_paused(job_id, stage_name)
return state
# ---------------------------------------------------------------------------
# Public API — backwards compatible with old get_pipeline/build_graph
# ---------------------------------------------------------------------------
def get_pipeline(
checkpoint: bool | None = None,
profile_name: str = "soccer_broadcast",
start_from: str | None = None,
) -> PipelineRunner:
"""Return a PipelineRunner for the given profile."""
from core.detect.profile import get_profile, pipeline_config_from_dict
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
profile = get_profile(profile_name)
config = pipeline_config_from_dict(profile["pipeline"])
return PipelineRunner(
config=config,
checkpoint=do_checkpoint,
start_from=start_from,
)
def build_graph(checkpoint: bool | None = None, start_from: str | None = None):
"""Backwards-compatible wrapper. Returns a PipelineRunner."""
return get_pipeline(checkpoint=checkpoint, start_from=start_from)

View File

@@ -0,0 +1,4 @@
from .client import InferenceClient
from .types import DetectResult, OCRResult, VLMResult
__all__ = ["InferenceClient", "DetectResult", "OCRResult", "VLMResult"]

View File

@@ -0,0 +1,262 @@
"""
HTTP client for the inference server.
The pipeline stages call this instead of importing ML libraries directly.
The inference server runs on the GPU machine (or spot instance).
"""
from __future__ import annotations
import base64
import io
import logging
import os
import numpy as np
import requests
from PIL import Image
from .types import DetectResult, OCRResult, RegionDebugResult, RegionResult, ServerStatus, VLMResult
logger = logging.getLogger(__name__)
DEFAULT_URL = os.environ.get("INFERENCE_URL", "http://localhost:8000")
def _encode_image(image: np.ndarray) -> str:
"""Encode numpy array as base64 JPEG."""
img = Image.fromarray(image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
return base64.b64encode(buf.getvalue()).decode()
class InferenceClient:
"""HTTP client for the GPU inference server."""
def __init__(self, base_url: str | None = None, timeout: float = 60.0,
job_id: str = "", log_level: str = "INFO"):
self.base_url = (base_url or DEFAULT_URL).rstrip("/")
self.timeout = timeout
self.job_id = job_id
self.log_level = log_level
self.session = requests.Session()
if job_id:
self.session.headers["X-Job-Id"] = job_id
self.session.headers["X-Log-Level"] = log_level
def health(self) -> ServerStatus:
"""Check server health and loaded models."""
resp = self.session.get(f"{self.base_url}/health", timeout=self.timeout)
resp.raise_for_status()
data = resp.json()
return ServerStatus(
loaded_models=data.get("loaded_models", []),
vram_used_mb=data.get("vram_used_mb", 0),
vram_budget_mb=data.get("vram_budget_mb", 0),
strategy=data.get("strategy", "sequential"),
)
def detect(
self,
image: np.ndarray,
model: str = "yolov8n",
confidence: float = 0.3,
target_classes: list[str] | None = None,
) -> list[DetectResult]:
"""Run object detection on an image."""
payload = {
"image": _encode_image(image),
"model": model,
"confidence": confidence,
}
if target_classes:
payload["target_classes"] = target_classes
resp = self.session.post(
f"{self.base_url}/detect",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
results = []
for d in resp.json().get("detections", []):
result = DetectResult(
x=d["x"], y=d["y"], w=d["w"], h=d["h"],
confidence=d["confidence"], label=d["label"],
)
results.append(result)
return results
def ocr(
self,
image: np.ndarray,
languages: list[str] | None = None,
) -> list[OCRResult]:
"""Run OCR on an image region."""
payload = {
"image": _encode_image(image),
}
if languages:
payload["languages"] = languages
resp = self.session.post(
f"{self.base_url}/ocr",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
results = []
for d in resp.json().get("results", []):
result = OCRResult(
text=d["text"],
confidence=d["confidence"],
bbox=tuple(d["bbox"]),
)
results.append(result)
return results
def vlm(
self,
image: np.ndarray,
prompt: str,
model: str = "moondream2",
) -> VLMResult:
"""Query a visual language model with an image crop + prompt."""
payload = {
"image": _encode_image(image),
"prompt": prompt,
"model": model,
}
resp = self.session.post(
f"{self.base_url}/vlm",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
data = resp.json()
return VLMResult(
brand=data.get("brand", ""),
confidence=data.get("confidence", 0.0),
reasoning=data.get("reasoning", ""),
)
def detect_edges(
self,
image: np.ndarray,
edge_canny_low: int = 50,
edge_canny_high: int = 150,
edge_hough_threshold: int = 80,
edge_hough_min_length: int = 100,
edge_hough_max_gap: int = 10,
edge_pair_max_distance: int = 200,
edge_pair_min_distance: int = 15,
) -> list[RegionResult]:
"""Run edge detection on an image."""
payload = {
"image": _encode_image(image),
"edge_canny_low": edge_canny_low,
"edge_canny_high": edge_canny_high,
"edge_hough_threshold": edge_hough_threshold,
"edge_hough_min_length": edge_hough_min_length,
"edge_hough_max_gap": edge_hough_max_gap,
"edge_pair_max_distance": edge_pair_max_distance,
"edge_pair_min_distance": edge_pair_min_distance,
}
resp = self.session.post(
f"{self.base_url}/detect_edges",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
results = []
for r in resp.json().get("regions", []):
result = RegionResult(
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
confidence=r["confidence"], label=r["label"],
)
results.append(result)
return results
def detect_edges_debug(
self,
image: np.ndarray,
edge_canny_low: int = 50,
edge_canny_high: int = 150,
edge_hough_threshold: int = 80,
edge_hough_min_length: int = 100,
edge_hough_max_gap: int = 10,
edge_pair_max_distance: int = 200,
edge_pair_min_distance: int = 15,
) -> RegionDebugResult:
"""Run edge detection with debug overlays."""
payload = {
"image": _encode_image(image),
"edge_canny_low": edge_canny_low,
"edge_canny_high": edge_canny_high,
"edge_hough_threshold": edge_hough_threshold,
"edge_hough_min_length": edge_hough_min_length,
"edge_hough_max_gap": edge_hough_max_gap,
"edge_pair_max_distance": edge_pair_max_distance,
"edge_pair_min_distance": edge_pair_min_distance,
}
resp = self.session.post(
f"{self.base_url}/detect_edges/debug",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
data = resp.json()
regions = []
for r in data.get("regions", []):
region = RegionResult(
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
confidence=r["confidence"], label=r["label"],
)
regions.append(region)
return RegionDebugResult(
regions=regions,
edge_overlay_b64=data.get("edge_overlay_b64", ""),
lines_overlay_b64=data.get("lines_overlay_b64", ""),
horizontal_count=data.get("horizontal_count", 0),
pair_count=data.get("pair_count", 0),
)
def post(self, path: str, payload: dict) -> dict | None:
"""Generic POST to the inference server. Returns JSON response or None on error."""
try:
resp = self.session.post(
f"{self.base_url}{path}",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
return resp.json()
except Exception as e:
logger.warning("Inference POST %s failed: %s", path, e)
return None
def load_model(self, model: str, quantization: str = "fp16") -> None:
"""Request the server to load a model into VRAM."""
self.session.post(
f"{self.base_url}/models/load",
json={"model": model, "quantization": quantization},
timeout=self.timeout,
).raise_for_status()
def unload_model(self, model: str) -> None:
"""Request the server to unload a model from VRAM."""
self.session.post(
f"{self.base_url}/models/unload",
json={"model": model},
timeout=self.timeout,
).raise_for_status()

View File

@@ -0,0 +1,76 @@
"""
Inference response types.
These are the shapes returned by the inference server.
Kept separate from core.detect.models to avoid coupling the
inference protocol to pipeline internals.
"""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class DetectResult:
"""Single object detection from YOLO or similar."""
x: int
y: int
w: int
h: int
confidence: float
label: str
@dataclass
class OCRResult:
"""Text extracted from a region."""
text: str
confidence: float
bbox: tuple[int, int, int, int] # x, y, w, h
@dataclass
class VLMResult:
"""Visual language model response for a crop."""
brand: str
confidence: float
reasoning: str
@dataclass
class RegionResult:
"""A candidate region from CV analysis."""
x: int
y: int
w: int
h: int
confidence: float
label: str
@dataclass
class RegionDebugResult:
"""CV region analysis with debug overlays."""
regions: list[RegionResult] = field(default_factory=list)
edge_overlay_b64: str = ""
lines_overlay_b64: str = ""
horizontal_count: int = 0
pair_count: int = 0
@dataclass
class ModelInfo:
"""Info about a loaded model."""
name: str
vram_mb: float
quantization: str # fp32, fp16, int8, int4
@dataclass
class ServerStatus:
"""Inference server health response."""
loaded_models: list[ModelInfo] = field(default_factory=list)
vram_used_mb: float = 0.0
vram_budget_mb: float = 0.0
strategy: str = "sequential" # sequential, concurrent, auto

95
core/detect/models.py Normal file
View File

@@ -0,0 +1,95 @@
"""
Detection pipeline runtime models.
These are the data structures that flow between pipeline stages.
They contain runtime types (np.ndarray) so they live here, not in
core/schema/models/ (which is for modelgen source of truth).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Literal
import numpy as np
@dataclass
class Frame:
sequence: int
chunk_id: int
timestamp: float # position in video (seconds)
image: np.ndarray
perceptual_hash: str = ""
@dataclass
class BoundingBox:
x: int
y: int
w: int
h: int
confidence: float
label: str
@dataclass
class TextCandidate:
frame: Frame
bbox: BoundingBox
text: str
ocr_confidence: float
@dataclass
class BrandDetection:
brand: str
timestamp: float
duration: float
confidence: float
source: Literal["ocr", "local_vlm", "cloud_llm", "logo_match", "auxiliary"]
bbox: BoundingBox | None = None
frame_ref: int | None = None
content_type: str = ""
@dataclass
class BrandStats:
total_appearances: int = 0
total_screen_time: float = 0.0
avg_confidence: float = 0.0
first_seen: float = 0.0
last_seen: float = 0.0
@dataclass
class PipelineStats:
frames_extracted: int = 0
frames_after_scene_filter: int = 0
cv_regions_detected: int = 0
regions_detected: int = 0
regions_resolved_by_ocr: int = 0
regions_escalated_to_local_vlm: int = 0
regions_escalated_to_cloud_llm: int = 0
auxiliary_detections: int = 0
cloud_llm_calls: int = 0
processing_time_seconds: float = 0.0
estimated_cloud_cost_usd: float = 0.0
@dataclass
class DetectionReport:
video_source: str
content_type: str
duration_seconds: float
brands: dict[str, BrandStats] = field(default_factory=dict)
timeline: list[BrandDetection] = field(default_factory=list)
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
@dataclass
class CropContext:
"""Runtime type — holds image bytes for VLM prompts."""
image: bytes
surrounding_text: str = ""
position_hint: str = ""

107
core/detect/profile.py Normal file
View File

@@ -0,0 +1,107 @@
"""
Profile registry and helpers.
Loads profile data from Postgres.
A profile is a dict with keys: name, pipeline, configs.
"""
from __future__ import annotations
import logging
from typing import Any, Dict
from core.detect.stages.models import PipelineConfig, StageRef, Edge
from core.detect.models import (
BrandDetection,
BrandStats,
CropContext,
DetectionReport,
PipelineStats,
)
logger = logging.getLogger(__name__)
def get_profile(name: str) -> Dict[str, Any]:
"""Get a profile dict by name from the database."""
from core.db.connection import get_session
from core.db.models import Profile
with get_session() as session:
row = session.query(Profile).filter(Profile.name == name).first()
if row is None:
raise ValueError(f"Unknown profile: {name!r}")
return {
"name": row.name,
"pipeline": row.pipeline or {},
"configs": row.configs or {},
}
def list_profiles() -> list[str]:
"""List available profile names from the database."""
from core.db.connection import get_session
from core.db.models import Profile
with get_session() as session:
rows = session.query(Profile.name).all()
return [r[0] for r in rows]
def get_stage_config(profile: Dict[str, Any], stage_name: str) -> dict:
"""Get config values for a stage from a profile."""
return profile.get("configs", {}).get(stage_name, {})
def pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig:
"""Deserialize a PipelineConfig from a JSONB dict."""
stages = [StageRef(**s) for s in data.get("stages", [])]
edges = [Edge(**e) for e in data.get("edges", [])]
return PipelineConfig(
name=data.get("name", ""),
profile_name=data.get("profile_name", ""),
stages=stages,
edges=edges,
routing_rules=data.get("routing_rules", {}),
)
def build_vlm_prompt(crop_context: CropContext, template: str) -> str:
"""Build a VLM prompt from a template and crop context."""
hint = f" Position: {crop_context.position_hint}." if crop_context.position_hint else ""
text = f" Nearby text: '{crop_context.surrounding_text}'." if crop_context.surrounding_text else ""
return template.format(hint=hint, text=text)
def aggregate_detections(
detections: list[BrandDetection],
content_type: str,
) -> DetectionReport:
"""Group detections by brand into a report."""
brands: dict[str, BrandStats] = {}
for d in detections:
if d.brand not in brands:
brands[d.brand] = BrandStats()
s = brands[d.brand]
s.total_appearances += 1
s.total_screen_time += d.duration
s.avg_confidence = (
(s.avg_confidence * (s.total_appearances - 1) + d.confidence)
/ s.total_appearances
)
if s.first_seen == 0.0 or d.timestamp < s.first_seen:
s.first_seen = d.timestamp
if d.timestamp > s.last_seen:
s.last_seen = d.timestamp
return DetectionReport(
video_source="",
content_type=content_type,
duration_seconds=0.0,
brands=brands,
timeline=sorted(detections, key=lambda d: d.timestamp),
pipeline_stats=PipelineStats(),
)

View File

@@ -0,0 +1,58 @@
"""
Cloud LLM provider registry.
Select provider via CLOUD_LLM_PROVIDER env var.
Each provider reads its own env vars for auth/config.
CLOUD_LLM_PROVIDER=groq → GROQ_API_KEY, GROQ_MODEL, GROQ_BASE_URL
CLOUD_LLM_PROVIDER=gemini → GEMINI_API_KEY, GEMINI_MODEL
CLOUD_LLM_PROVIDER=openai → OPENAI_API_KEY, OPENAI_MODEL, OPENAI_BASE_URL
CLOUD_LLM_PROVIDER=claude → ANTHROPIC_API_KEY, CLAUDE_MODEL
"""
from __future__ import annotations
import os
from .base import CloudProvider, ProviderResponse
from .groq import GroqProvider
from .gemini import GeminiProvider
from .openai_compat import OpenAICompatProvider
from .claude import ClaudeProvider
PROVIDERS: dict[str, type] = {
"groq": GroqProvider,
"gemini": GeminiProvider,
"openai": OpenAICompatProvider,
"claude": ClaudeProvider,
}
_cached: CloudProvider | None = None
def get_provider() -> CloudProvider:
"""Get the configured cloud provider (cached after first call)."""
global _cached
if _cached is not None:
return _cached
name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
cls = PROVIDERS.get(name)
if cls is None:
raise ValueError(f"Unknown provider: {name!r}. Options: {list(PROVIDERS)}")
_cached = cls()
return _cached
def has_api_key() -> bool:
"""Check if the configured provider has an API key set."""
name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
key_map = {
"groq": "GROQ_API_KEY",
"gemini": "GEMINI_API_KEY",
"openai": "OPENAI_API_KEY",
"claude": "ANTHROPIC_API_KEY",
}
env_var = key_map.get(name, "")
return bool(os.environ.get(env_var, ""))

View File

@@ -0,0 +1,36 @@
"""Cloud LLM provider protocol and model metadata."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
@dataclass
class ModelInfo:
"""Metadata for a cloud LLM model."""
id: str
vision: bool = True
cost_per_input_token: float = 0.0
cost_per_output_token: float = 0.0
max_output_tokens: int = 4096
notes: str = ""
@dataclass
class ProviderResponse:
answer: str
total_tokens: int = 0
class CloudProvider(Protocol):
"""
Interface for cloud LLM providers.
Each provider handles its own auth, payload format, and response parsing.
The pipeline only calls call() and reads the response.
"""
name: str
models: dict[str, ModelInfo]
def call(self, image_b64: str, prompt: str) -> ProviderResponse: ...

View File

@@ -0,0 +1,73 @@
"""Anthropic Claude provider — uses the official SDK."""
from __future__ import annotations
import logging
import os
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# Claude-specific env vars
# ANTHROPIC_API_KEY is read by the SDK automatically
CLAUDE_MODEL = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-20250514")
MODELS = {
"claude-sonnet-4-20250514": ModelInfo(
id="claude-sonnet-4-20250514",
vision=True,
cost_per_input_token=0.000003,
cost_per_output_token=0.000015,
notes="Best balance of quality/cost with vision",
),
"claude-haiku-4-5-20251001": ModelInfo(
id="claude-haiku-4-5-20251001",
vision=True,
cost_per_input_token=0.0000008,
cost_per_output_token=0.000004,
notes="Fastest, cheapest, good for simple brand ID",
),
"claude-opus-4-6": ModelInfo(
id="claude-opus-4-6",
vision=True,
cost_per_input_token=0.000015,
cost_per_output_token=0.000075,
notes="Highest quality, use for ambiguous cases",
),
}
class ClaudeProvider:
name = "claude"
models = MODELS
def __init__(self):
from anthropic import Anthropic
self.client = Anthropic()
self.model = CLAUDE_MODEL
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
message = self.client.messages.create(
model=self.model,
max_tokens=150,
messages=[{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": image_b64,
},
},
{"type": "text", "text": prompt},
],
}],
)
answer = message.content[0].text.strip()
total_tokens = message.usage.input_tokens + message.usage.output_tokens
return ProviderResponse(answer=answer, total_tokens=total_tokens)

View File

@@ -0,0 +1,75 @@
"""Google Gemini provider — native REST API, not OpenAI-compatible."""
from __future__ import annotations
import logging
import os
import requests
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# Gemini-specific env vars
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash")
MODELS = {
"gemini-2.0-flash": ModelInfo(
id="gemini-2.0-flash",
vision=True,
cost_per_input_token=0.0000001,
cost_per_output_token=0.0000004,
notes="Fast, cheap, good vision",
),
"gemini-2.0-pro": ModelInfo(
id="gemini-2.0-pro",
vision=True,
cost_per_input_token=0.00000125,
cost_per_output_token=0.000005,
notes="Higher quality, slower",
),
"gemini-1.5-flash": ModelInfo(
id="gemini-1.5-flash",
vision=True,
cost_per_input_token=0.000000075,
cost_per_output_token=0.0000003,
notes="Cheapest option",
),
}
class GeminiProvider:
name = "gemini"
models = MODELS
def __init__(self):
self.api_key = GEMINI_API_KEY
self.model = GEMINI_MODEL
self.endpoint = (
f"https://generativelanguage.googleapis.com/v1beta/models/"
f"{self.model}:generateContent"
)
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
payload = {
"contents": [{
"parts": [
{"text": prompt},
{"inline_data": {"mime_type": "image/jpeg", "data": image_b64}},
],
}],
"generationConfig": {"maxOutputTokens": 150},
}
url = f"{self.endpoint}?key={self.api_key}"
resp = requests.post(url, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
answer = data["candidates"][0]["content"]["parts"][0]["text"].strip()
usage = data.get("usageMetadata", {})
total_tokens = usage.get("totalTokenCount", 0)
return ProviderResponse(answer=answer, total_tokens=total_tokens)

View File

@@ -0,0 +1,66 @@
"""Groq cloud provider — OpenAI-compatible API with vision."""
from __future__ import annotations
import logging
import os
import requests
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# Groq-specific env vars
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
GROQ_BASE_URL = os.environ.get("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
GROQ_MODEL = os.environ.get("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct")
MODELS = {
"meta-llama/llama-4-scout-17b-16e-instruct": ModelInfo(
id="meta-llama/llama-4-scout-17b-16e-instruct",
vision=True,
cost_per_input_token=0.0,
cost_per_output_token=0.0,
notes="Llama 4 Scout, only vision model on Groq free tier",
),
}
class GroqProvider:
name = "groq"
models = MODELS
def __init__(self):
self.api_key = GROQ_API_KEY
self.base_url = GROQ_BASE_URL
self.model = GROQ_MODEL
self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
payload = {
"model": self.model,
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {
"url": f"data:image/jpeg;base64,{image_b64}",
}},
],
}],
"max_tokens": 150,
}
resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
answer = data["choices"][0]["message"]["content"].strip()
total_tokens = data.get("usage", {}).get("total_tokens", 0)
return ProviderResponse(answer=answer, total_tokens=total_tokens)

View File

@@ -0,0 +1,73 @@
"""Generic OpenAI-compatible provider (OpenAI, Together, etc.)."""
from __future__ import annotations
import logging
import os
import requests
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# OpenAI-compat specific env vars
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
MODELS = {
"gpt-4o-mini": ModelInfo(
id="gpt-4o-mini",
vision=True,
cost_per_input_token=0.00000015,
cost_per_output_token=0.0000006,
notes="Cheap, fast, decent vision",
),
"gpt-4o": ModelInfo(
id="gpt-4o",
vision=True,
cost_per_input_token=0.0000025,
cost_per_output_token=0.00001,
notes="Best OpenAI vision model",
),
}
class OpenAICompatProvider:
name = "openai"
models = MODELS
def __init__(self):
self.api_key = OPENAI_API_KEY
self.base_url = OPENAI_BASE_URL
self.model = OPENAI_MODEL
self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
payload = {
"model": self.model,
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {
"url": f"data:image/jpeg;base64,{image_b64}",
}},
],
}],
"max_tokens": 150,
}
resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
answer = data["choices"][0]["message"]["content"].strip()
total_tokens = data.get("usage", {}).get("total_tokens", 0)
return ProviderResponse(answer=answer, total_tokens=total_tokens)

163
core/detect/sse.py Normal file
View File

@@ -0,0 +1,163 @@
"""
Pydantic Models - GENERATED FILE
Do not edit directly. Regenerate using modelgen.
"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
class GraphNode(BaseModel):
"""A pipeline stage node."""
id: str
status: str = "idle"
items_in: int = 0
items_out: int = 0
class GraphEdge(BaseModel):
"""An edge between pipeline stages."""
source: str
target: str
throughput: int = 0
class BoundingBoxEvent(BaseModel):
"""Bounding box in SSE event payloads."""
x: int
y: int
w: int
h: int
confidence: float
label: str
resolved_brand: Optional[str] = None
source: Optional[str] = None
stage: Optional[str] = None
class BrandSummary(BaseModel):
"""Per-brand stats in the final report."""
brand: str
total_appearances: int = 0
total_screen_time: float = 0.0
avg_confidence: float = 0.0
first_seen: float = 0.0
last_seen: float = 0.0
class GraphUpdate(BaseModel):
"""Pipeline node state transition. SSE event: graph_update"""
nodes: List[GraphNode] = Field(default_factory=list)
edges: List[GraphEdge] = Field(default_factory=list)
active_path: List[str] = Field(default_factory=list)
class StatsUpdate(BaseModel):
"""Funnel statistics snapshot. SSE event: stats_update"""
frames_extracted: int = 0
frames_after_scene_filter: int = 0
cv_regions_detected: int = 0
regions_detected: int = 0
regions_resolved_by_ocr: int = 0
regions_escalated_to_local_vlm: int = 0
regions_escalated_to_cloud_llm: int = 0
cloud_llm_calls: int = 0
processing_time_seconds: float = 0.0
estimated_cloud_cost_usd: float = 0.0
run_id: Optional[str] = None
parent_job_id: Optional[str] = None
run_type: str = "initial"
class FrameUpdate(BaseModel):
"""Current frame being processed. SSE event: frame_update"""
frame_ref: int
timestamp: float
jpeg_b64: str
boxes: List[BoundingBoxEvent] = Field(default_factory=list)
class Detection(BaseModel):
"""A confirmed brand detection. SSE event: detection"""
brand: str
timestamp: float
duration: float
confidence: float
source: str
content_type: str
bbox: Optional[BoundingBoxEvent] = None
frame_ref: Optional[int] = None
class LogEvent(BaseModel):
"""Pipeline log line. SSE event: log"""
level: str
stage: str
msg: str
ts: str
trace_id: Optional[str] = None
class DetectionReportSummary(BaseModel):
"""Final detection report summary."""
video_source: str
content_type: str
duration_seconds: float
total_detections: int = 0
brands: List[BrandSummary] = Field(default_factory=list)
stats: Optional[StatsUpdate] = None
class JobComplete(BaseModel):
"""Final report when pipeline finishes. SSE event: job_complete"""
job_id: str
report: Optional[DetectionReportSummary] = None
class RunContext(BaseModel):
"""Run context injected into all SSE events for grouping."""
run_id: str
parent_job_id: str
run_type: str = "initial"
class CheckpointInfo(BaseModel):
"""Available checkpoint for a stage."""
stage: str
is_scenario: bool = False
scenario_label: str = ""
class ReplayRequest(BaseModel):
"""Request to replay pipeline from a specific stage."""
job_id: str
start_stage: str
config_overrides: Optional[Dict[str, Any]] = None
class ReplayResponse(BaseModel):
"""Result of a replay invocation."""
status: str
job_id: str
start_stage: str
detections: int = 0
brands_found: int = 0
class RetryRequest(BaseModel):
"""Request to queue async retry with different config."""
job_id: str
config_overrides: Optional[Dict[str, Any]] = None
start_stage: str = "escalate_vlm"
schedule_seconds: Optional[float] = None
class RetryResponse(BaseModel):
"""Result of queueing a retry task."""
status: str
task_id: str
job_id: str
class RunRequest(BaseModel):
"""Request body for launching a detection pipeline run."""
video_path: str
profile_name: str = "soccer_broadcast"
source_asset_id: str = ""
checkpoint: bool = True
skip_vlm: bool = False
skip_cloud: bool = False
log_level: str = "INFO"
class RunResponse(BaseModel):
"""Response after starting a pipeline run."""
status: str
job_id: str
video_path: str

View File

@@ -0,0 +1,22 @@
"""
Pipeline stages.
Each stage is a file with a Stage subclass. Auto-discovered via
__init_subclass__ — importing the file registers the stage.
"""
from .base import (
Stage,
get_stage,
get_stage_instance,
list_stages,
list_stage_classes,
get_palette,
)
# Import all stage files to trigger auto-registration
from . import edge_detector # noqa: F401
from . import field_segmentation # noqa: F401
# Import registry for backward compat (other stages still use old pattern)
from . import registry # noqa: F401

View File

@@ -0,0 +1,116 @@
"""
Stage 8 — Report compilation
Groups all detections by brand, merges contiguous appearances,
and builds the final DetectionReport.
"""
from __future__ import annotations
import logging
from core.detect import emit
from core.detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
logger = logging.getLogger(__name__)
def _merge_contiguous(detections: list[BrandDetection], gap_threshold: float = 2.0) -> list[BrandDetection]:
"""
Merge detections of the same brand that are close in time.
If two detections of the same brand are within gap_threshold seconds,
they're merged into one detection spanning the full range.
"""
if not detections:
return []
sorted_dets = sorted(detections, key=lambda d: (d.brand, d.timestamp))
merged: list[BrandDetection] = []
current = sorted_dets[0]
for det in sorted_dets[1:]:
if (det.brand == current.brand
and det.timestamp <= current.timestamp + current.duration + gap_threshold):
end = max(current.timestamp + current.duration,
det.timestamp + det.duration)
current = BrandDetection(
brand=current.brand,
timestamp=current.timestamp,
duration=end - current.timestamp,
confidence=max(current.confidence, det.confidence),
source=current.source,
bbox=current.bbox,
frame_ref=current.frame_ref,
content_type=current.content_type,
)
else:
merged.append(current)
current = det
merged.append(current)
return merged
def compile_report(
detections: list[BrandDetection],
stats: PipelineStats,
video_source: str = "",
content_type: str = "",
duration_seconds: float = 0.0,
job_id: str | None = None,
) -> DetectionReport:
"""
Build the final detection report from all accumulated detections.
Merges contiguous detections, computes per-brand stats,
and emits the job_complete event.
"""
merged = _merge_contiguous(detections)
brands: dict[str, BrandStats] = {}
for d in merged:
if d.brand not in brands:
brands[d.brand] = BrandStats()
s = brands[d.brand]
s.total_appearances += 1
s.total_screen_time += d.duration
s.avg_confidence = (
(s.avg_confidence * (s.total_appearances - 1) + d.confidence)
/ s.total_appearances
)
if s.first_seen == 0.0 or d.timestamp < s.first_seen:
s.first_seen = d.timestamp
if d.timestamp > s.last_seen:
s.last_seen = d.timestamp
report = DetectionReport(
video_source=video_source,
content_type=content_type,
duration_seconds=duration_seconds,
brands=brands,
timeline=sorted(merged, key=lambda d: d.timestamp),
pipeline_stats=stats,
)
emit.log(job_id, "Aggregator", "INFO",
f"Report: {len(brands)} brands, {len(merged)} detections "
f"(merged from {len(detections)} raw)")
emit.job_complete(job_id, {
"video_source": report.video_source,
"content_type": report.content_type,
"duration_seconds": report.duration_seconds,
"brands": {
k: {
"total_appearances": v.total_appearances,
"total_screen_time": v.total_screen_time,
"avg_confidence": round(v.avg_confidence, 3),
"first_seen": v.first_seen,
"last_seen": v.last_seen,
}
for k, v in brands.items()
},
})
return report

151
core/detect/stages/base.py Normal file
View File

@@ -0,0 +1,151 @@
"""
Stage base class — common interface for all pipeline stages.
Each stage is a file that subclasses Stage. Auto-discovered via
__init_subclass__. No manual registration needed.
A stage:
- Has a StageDefinition (generated from schema) with name, config, IO
- Implements run(frames, config) → output
- Owns its output serialization (opaque blob)
- Optionally has a TypeScript port for browser-side execution
The checkpoint layer stores stage output as blobs without knowing
the format. The stage that wrote it is the only one that can read it.
"""
from __future__ import annotations
from typing import Any
import numpy as np
from core.detect.stages.models import (
StageConfigField,
StageIO,
StageDefinition,
)
# Legacy runtime extension — adds callable fields for old-style stages.
# New stages use Stage subclass with serialize()/deserialize() methods instead.
class LegacyStageDefinition:
"""Wraps a StageDefinition with callable serialize/deserialize functions."""
def __init__(self, definition: StageDefinition, fn=None, serialize_fn=None, deserialize_fn=None):
self._definition = definition
self.fn = fn
self.serialize_fn = serialize_fn
self.deserialize_fn = deserialize_fn
def __getattr__(self, name):
return getattr(self._definition, name)
# ---------------------------------------------------------------------------
# Registry — auto-populated by __init_subclass__ (new stages)
# + register_stage() (legacy stages during migration)
# ---------------------------------------------------------------------------
_REGISTRY: dict[str, type['Stage']] = {}
_LEGACY_REGISTRY: dict[str, LegacyStageDefinition] = {}
def register_stage(
definition: StageDefinition,
fn=None,
serialize_fn=None,
deserialize_fn=None,
):
"""Legacy registration for stages not yet converted to Stage subclass."""
legacy = LegacyStageDefinition(definition, fn=fn, serialize_fn=serialize_fn, deserialize_fn=deserialize_fn)
_LEGACY_REGISTRY[definition.name] = legacy
class Stage:
"""
Base class for all pipeline stages.
Subclass this in detect/stages/<name>.py. Define `definition` as a
class attribute. Implement `run()`. Optionally override `serialize()`
and `deserialize()` for custom blob formats (default is JSON).
"""
definition: StageDefinition # set by each subclass
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if hasattr(cls, 'definition') and cls.definition is not None:
_REGISTRY[cls.definition.name] = cls
def run(self, frames: list, config: dict) -> Any:
raise NotImplementedError
def serialize(self, output: Any) -> bytes:
"""Serialize stage output to bytes for checkpoint storage."""
import json
return json.dumps(output, default=str).encode()
def deserialize(self, data: bytes) -> Any:
"""Deserialize stage output from checkpoint blob."""
import json
return json.loads(data)
# ---------------------------------------------------------------------------
# Discovery API
# ---------------------------------------------------------------------------
def _all_definitions():
"""Merge new Stage subclass registry + legacy registry.
Returns StageDefinition for new-style stages,
LegacyStageDefinition for legacy stages (has serialize_fn etc).
"""
merged = {}
for name, legacy in _LEGACY_REGISTRY.items():
merged[name] = legacy
for name, cls in _REGISTRY.items():
merged[name] = cls.definition
return merged
def get_stage(name: str) -> StageDefinition:
"""Get a stage definition by name (works for both new and legacy)."""
all_defs = _all_definitions()
if name not in all_defs:
raise KeyError(f"Unknown stage: {name!r}. Registered: {list(all_defs)}")
return all_defs[name]
def get_stage_class(name: str) -> type[Stage] | None:
"""Get a Stage subclass by name. Returns None for legacy stages."""
return _REGISTRY.get(name)
def get_stage_instance(name: str) -> Stage:
"""Get an instantiated Stage by name. Only works for new-style stages."""
cls = _REGISTRY.get(name)
if cls is None:
raise KeyError(f"No Stage subclass for {name!r}. Legacy stages don't have instances.")
return cls()
def list_stages() -> list[StageDefinition]:
"""List all registered stage definitions (new + legacy)."""
return list(_all_definitions().values())
def list_stage_classes() -> list[type[Stage]]:
"""List all registered Stage subclasses (new-style only)."""
return list(_REGISTRY.values())
def get_palette() -> dict[str, list[StageDefinition]]:
"""Group stages by category for the editor palette."""
palette: dict[str, list[StageDefinition]] = {}
for defn in _all_definitions().values():
if defn.category not in palette:
palette[defn.category] = []
palette[defn.category].append(defn)
return palette

View File

@@ -0,0 +1,216 @@
"""
Stage 5 — Brand Resolver (discovery mode)
Discovery-first brand matching. No static dictionary — all brands live in the DB.
Flow:
1. Check session brands first (brands already seen in this run, in-memory)
2. Check global known brands (accumulated across all runs)
3. Unresolved candidates → escalate to VLM/cloud
4. Confirmed brands get added to DB for future runs
"""
from __future__ import annotations
import logging
from rapidfuzz import fuzz
from core.detect import emit
from core.detect.models import BrandDetection, TextCandidate
from core.detect.stages.models import ResolverConfig
logger = logging.getLogger(__name__)
def _normalize(text: str) -> str:
return text.strip().lower()
def _has_db() -> bool:
try:
from core.db import find_brand_by_text as _
return True
except (ImportError, Exception):
return False
def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
return session_brands.get(_normalize(text))
def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
"""Check against global known brands in DB. Returns (canonical_name, brand_id) or (None, None)."""
if not _has_db():
return None, None
from core.db import find_brand_by_text, list_brands
from core.db.connection import get_session
with get_session() as session:
brand = find_brand_by_text(session, text)
if brand:
return brand.canonical_name, str(brand.id)
all_brands = list_brands(session)
normalized = _normalize(text)
best_brand = None
best_score = 0
for known in all_brands:
names = [known.canonical_name] + (known.aliases or [])
for name in names:
score = fuzz.ratio(normalized, _normalize(name))
if score > best_score and score >= threshold:
best_score = score
best_brand = known
if best_brand:
return best_brand.canonical_name, str(best_brand.id)
return None, None
def _register_brand(canonical_name: str, source: str) -> str | None:
"""Register a newly discovered brand in the DB. Returns brand_id."""
if not _has_db():
return None
from core.db import get_or_create_brand
from core.db.connection import get_session
with get_session() as session:
brand, created = get_or_create_brand(session, canonical_name, source=source)
session.commit()
if created:
logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
return str(brand.id)
def _record_airing(timeline_id: str | None, brand_id: str,
frame_seq: int, confidence: float, source: str):
"""Record a brand airing on a timeline."""
if not _has_db() or not timeline_id:
return
from core.db import record_airing
from core.db.connection import get_session
from uuid import UUID
with get_session() as session:
record_airing(
session,
brand_id=UUID(brand_id),
timeline_id=UUID(timeline_id),
frame_start=frame_seq,
frame_end=frame_seq,
confidence=confidence,
source=source,
)
session.commit()
def build_session_dict(source_asset_id: str | None = None) -> dict[str, str]:
"""
Load known brands from DB as a session lookup dict.
Returns {normalized_name: canonical_name, ...} including aliases.
"""
if not _has_db():
return {}
from core.db import list_brands
from core.db.connection import get_session
with get_session() as session:
all_brands = list_brands(session)
session_dict = {}
for brand in all_brands:
session_dict[_normalize(brand.canonical_name)] = brand.canonical_name
for alias in (brand.aliases or []):
session_dict[_normalize(alias)] = brand.canonical_name
return session_dict
def resolve_brands(
candidates: list[TextCandidate],
config: ResolverConfig,
session_brands: dict[str, str] | None = None,
source_asset_id: str | None = None,
content_type: str = "",
job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
"""
Match text candidates against known brands (session → global → unresolved).
session_brands: pre-loaded session dict (from build_session_dict)
job_id: timeline_id — used to record airings
"""
if session_brands is None:
session_brands = {}
emit.log(job_id, "BrandResolver", "INFO",
f"Resolving {len(candidates)} candidates "
f"(session={len(session_brands)} brands, fuzzy={config.fuzzy_threshold})")
matched: list[BrandDetection] = []
unresolved: list[TextCandidate] = []
session_hits = 0
known_hits = 0
for candidate in candidates:
text = candidate.text
brand_name = None
brand_id = None
match_source = "ocr"
# 1. Check session (cheapest — in-memory dict)
brand_name = _match_session(text, session_brands)
if brand_name:
session_hits += 1
else:
# 2. Check global known brands (DB query + fuzzy)
brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
if brand_name:
known_hits += 1
session_brands[_normalize(brand_name)] = brand_name
if brand_name:
detection = BrandDetection(
brand=brand_name,
timestamp=candidate.frame.timestamp,
duration=0.5,
confidence=candidate.ocr_confidence,
source=match_source,
bbox=candidate.bbox,
frame_ref=candidate.frame.sequence,
content_type=content_type,
)
matched.append(detection)
if brand_id:
_record_airing(
job_id, brand_id,
candidate.frame.sequence, candidate.ocr_confidence, match_source,
)
emit.detection(
job_id,
brand=brand_name,
confidence=candidate.ocr_confidence,
source=match_source,
timestamp=candidate.frame.timestamp,
content_type=content_type,
frame_ref=candidate.frame.sequence,
)
else:
unresolved.append(candidate)
emit.log(job_id, "BrandResolver", "INFO",
f"Session: {session_hits}, Known: {known_hits}, "
f"Unresolved: {len(unresolved)} → escalating")
return matched, unresolved

View File

@@ -0,0 +1,288 @@
"""
Stage — Edge Detection
Canny + HoughLinesP to find horizontal line pairs that bound
advertising hoardings. Pure OpenCV, no ML models.
Two modes:
- Remote: calls GPU inference server over HTTP
- Local: imports cv2 directly (OpenCV on same machine)
"""
from __future__ import annotations
import base64
import io
import json
import logging
import os
import time
from typing import Any
from PIL import Image
from core.detect import emit
from core.detect.models import BoundingBox, Frame
from core.detect.stages.base import Stage
from core.detect.stages.models import StageDefinition, StageConfigField, StageIO
logger = logging.getLogger(__name__)
class EdgeDetectionStage(Stage):
definition = StageDefinition(
name="detect_edges",
label="Edge Detection",
description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
category="cv_analysis",
io=StageIO(
reads=["filtered_frames"],
writes=["edge_regions_by_frame"],
),
config_fields=[
StageConfigField(name="enabled", type="bool", default=True, description="Enable edge detection"),
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
],
tracks_element="edge_region",
)
def run(self, frames: list[Frame], config: dict) -> dict[int, list[BoundingBox]]:
"""
Run edge detection on all frames.
Config keys: enabled, edge_canny_low, edge_canny_high, edge_hough_threshold,
edge_hough_min_length, edge_hough_max_gap, edge_pair_max_distance, edge_pair_min_distance,
debug (bool), inference_url (str|None), job_id (str|None).
Returns dict mapping frame sequence → list of BoundingBox.
"""
enabled = config.get("enabled", True)
job_id = config.get("job_id")
inference_url = config.get("inference_url") or os.environ.get("INFERENCE_URL")
if not enabled:
emit.log(job_id, "EdgeDetection", "INFO", "Edge detection disabled, skipping")
return {}
mode = "remote" if inference_url else "local"
emit.log(job_id, "EdgeDetection", "INFO",
f"Detecting edges in {len(frames)} frames (mode={mode})")
all_boxes: dict[int, list[BoundingBox]] = {}
total_regions = 0
for frame in frames:
t0 = time.monotonic()
if inference_url:
boxes = self._run_remote(frame, config, inference_url, job_id or "")
else:
boxes = self._run_local(frame, config)
ms = (time.monotonic() - t0) * 1000
all_boxes[frame.sequence] = boxes
total_regions += len(boxes)
emit.log(job_id, "EdgeDetection", "DEBUG",
f"Frame {frame.sequence}: {len(boxes)} regions in {ms:.0f}ms"
+ (f" [{', '.join(b.label for b in boxes)}]" if boxes else ""))
if boxes and job_id:
box_dicts = [
{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
"confidence": b.confidence, "label": b.label,
"stage": "detect_edges"}
for b in boxes
]
emit.frame_update(
job_id,
frame_ref=frame.sequence,
timestamp=frame.timestamp,
jpeg_b64=_frame_to_b64(frame),
boxes=box_dicts,
)
emit.log(job_id, "EdgeDetection", "INFO",
f"Found {total_regions} edge regions across {len(frames)} frames")
emit.stats(job_id, cv_regions_detected=total_regions)
return all_boxes
def serialize(self, output: Any) -> bytes:
"""Serialize edge regions to JSON blob."""
serialized = {}
for seq, boxes in output.items():
serialized[str(seq)] = [
{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
"confidence": b.confidence, "label": b.label}
for b in boxes
]
return json.dumps(serialized).encode()
def deserialize(self, data: bytes) -> dict[int, list[BoundingBox]]:
"""Deserialize edge regions from JSON blob."""
raw = json.loads(data)
result = {}
for seq_str, box_dicts in raw.items():
boxes = [
BoundingBox(x=b["x"], y=b["y"], w=b["w"], h=b["h"],
confidence=b["confidence"], label=b["label"])
for b in box_dicts
]
result[int(seq_str)] = boxes
return result
# --- Private helpers ---
def _run_remote(self, frame: Frame, config: dict,
inference_url: str, job_id: str) -> list[BoundingBox]:
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(
base_url=inference_url, job_id=job_id, log_level=_run_log_level,
)
results = client.detect_edges(
image=frame.image,
edge_canny_low=config.get("edge_canny_low", 50),
edge_canny_high=config.get("edge_canny_high", 150),
edge_hough_threshold=config.get("edge_hough_threshold", 80),
edge_hough_min_length=config.get("edge_hough_min_length", 100),
edge_hough_max_gap=config.get("edge_hough_max_gap", 10),
edge_pair_max_distance=config.get("edge_pair_max_distance", 200),
edge_pair_min_distance=config.get("edge_pair_min_distance", 15),
)
boxes = []
for r in results:
box = BoundingBox(
x=r.x, y=r.y, w=r.w, h=r.h,
confidence=r.confidence, label=r.label,
)
boxes.append(box)
return boxes
def _run_local(self, frame: Frame, config: dict) -> list[BoundingBox]:
detect_edges_fn = _load_cv_edges().detect_edges
edge_results = detect_edges_fn(
frame.image,
canny_low=config.get("edge_canny_low", 50),
canny_high=config.get("edge_canny_high", 150),
hough_threshold=config.get("edge_hough_threshold", 80),
hough_min_length=config.get("edge_hough_min_length", 100),
hough_max_gap=config.get("edge_hough_max_gap", 10),
pair_max_distance=config.get("edge_pair_max_distance", 200),
pair_min_distance=config.get("edge_pair_min_distance", 15),
)
boxes = []
for r in edge_results:
box = BoundingBox(
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
confidence=r["confidence"], label=r["label"],
)
boxes.append(box)
return boxes
# --- Module-level helpers ---
def _frame_to_b64(frame: Frame) -> str:
img = Image.fromarray(frame.image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=70)
return base64.b64encode(buf.getvalue()).decode()
_cv_edges_mod = None
def _load_cv_edges():
global _cv_edges_mod
if _cv_edges_mod is None:
import importlib.util
from pathlib import Path
spec = importlib.util.spec_from_file_location("cv_edges", Path("core/gpu/models/cv/edges.py"))
_cv_edges_mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(_cv_edges_mod)
return _cv_edges_mod
# --- Backward compat: standalone function for graph.py ---
def _filter_by_field_mask(boxes, mask, margin_px=50):
"""
Keep only boxes that are near the pitch boundary (hoarding zone).
The field mask has 255=pitch, 0=not pitch. Hoardings sit just
outside the pitch boundary. We dilate the mask to create a
"boundary zone" and keep boxes whose center falls in the zone
between the dilated mask edge and the original mask.
"""
import cv2
import numpy as np
if mask is None or not boxes:
return boxes
# Dilate the pitch mask — the expansion zone is where hoardings are
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (margin_px * 2, margin_px * 2))
dilated = cv2.dilate(mask, kernel)
# Boundary zone = dilated but NOT original pitch
boundary_zone = cv2.bitwise_and(dilated, cv2.bitwise_not(mask))
kept = []
for box in boxes:
cx = box.x + box.w // 2
cy = box.y + box.h // 2
# Clamp to image bounds
cy = min(cy, boundary_zone.shape[0] - 1)
cx = min(cx, boundary_zone.shape[1] - 1)
if boundary_zone[cy, cx] > 0:
kept.append(box)
return kept
def detect_edge_regions(frames, config, inference_url=None, job_id=None, field_masks=None):
"""Convenience wrapper — calls EdgeDetectionStage.run(), optionally filters by field mask."""
stage = EdgeDetectionStage()
cfg = {
"enabled": config.enabled,
"edge_canny_low": config.edge_canny_low,
"edge_canny_high": config.edge_canny_high,
"edge_hough_threshold": config.edge_hough_threshold,
"edge_hough_min_length": config.edge_hough_min_length,
"edge_hough_max_gap": config.edge_hough_max_gap,
"edge_pair_max_distance": config.edge_pair_max_distance,
"edge_pair_min_distance": config.edge_pair_min_distance,
"inference_url": inference_url,
"job_id": job_id,
}
all_boxes = stage.run(frames, cfg)
# Filter by field segmentation mask if available
if field_masks:
filtered_total = 0
original_total = sum(len(b) for b in all_boxes.values())
for seq, boxes in all_boxes.items():
mask = field_masks.get(seq)
if mask is not None:
all_boxes[seq] = _filter_by_field_mask(boxes, mask)
filtered_total += len(all_boxes[seq])
else:
filtered_total += len(boxes)
if original_total != filtered_total:
from core.detect import emit
emit.log(job_id, "EdgeDetection", "INFO",
f"Field mask filter: {original_total}{filtered_total} regions")
return all_boxes

View File

@@ -0,0 +1,141 @@
"""
Stage — Field Segmentation
Calls the GPU inference server to detect pitch boundaries via
HSV green mask + morphology. The CV code lives in core/gpu/models/cv/.
Outputs a mask and boundary that downstream stages use as spatial priors.
"""
from __future__ import annotations
import base64
import io
import logging
import numpy as np
from PIL import Image
from core.detect import emit
from core.detect.models import Frame
from core.detect.stages.base import Stage
from core.detect.stages.models import (
FieldSegmentationConfig,
StageConfigField,
StageDefinition,
StageIO,
)
logger = logging.getLogger(__name__)
class FieldSegmentationStage(Stage):
definition = StageDefinition(
name="field_segmentation",
label="Field Segmentation",
description="HSV green mask — detect pitch boundaries for spatial priors",
category="cv_analysis",
io=StageIO(
reads=["filtered_frames"],
writes=["field_mask"],
),
config_fields=[
StageConfigField(name="enabled", type="bool", default=True, description="Enable field segmentation"),
StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180),
StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180),
StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255),
StageConfigField(name="sat_high", type="int", default=255, description="HSV saturation upper bound", min=0, max=255),
StageConfigField(name="val_low", type="int", default=30, description="HSV value lower bound", min=0, max=255),
StageConfigField(name="val_high", type="int", default=255, description="HSV value upper bound", min=0, max=255),
StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51),
StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area as fraction of frame", min=0.01, max=0.5),
],
)
def _frame_to_b64(frame: Frame) -> str:
"""Encode frame image as base64 JPEG."""
img = Image.fromarray(frame.image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
return base64.b64encode(buf.getvalue()).decode()
def _decode_mask_b64(mask_b64: str) -> np.ndarray:
"""Decode a base64 PNG mask back to numpy array."""
data = base64.b64decode(mask_b64)
img = Image.open(io.BytesIO(data)).convert("L")
return np.array(img)
def run_field_segmentation(
frames: list[Frame],
config: FieldSegmentationConfig,
inference_url: str | None = None,
job_id: str | None = None,
) -> dict:
"""
Run field segmentation on all frames via the inference server.
Returns dict with:
field_masks: {seq: np.ndarray}
field_boundaries: {seq: [(x,y), ...]}
field_coverage: {seq: float}
"""
if not config.enabled:
emit.log(job_id, "FieldSegmentation", "INFO", "Disabled, skipping")
return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}
import os
url = inference_url or os.environ.get("INFERENCE_URL")
if not url:
emit.log(job_id, "FieldSegmentation", "WARNING",
"No INFERENCE_URL, skipping field segmentation")
return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}
emit.log(job_id, "FieldSegmentation", "INFO",
f"Segmenting {len(frames)} frames (hue={config.hue_low}-{config.hue_high})")
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(base_url=url, job_id=job_id or "", log_level=_run_log_level)
field_masks = {}
field_boundaries = {}
field_coverage = {}
for frame in frames:
image_b64 = _frame_to_b64(frame)
resp = client.post("/segment_field", {
"image": image_b64,
"hue_low": config.hue_low,
"hue_high": config.hue_high,
"sat_low": config.sat_low,
"sat_high": config.sat_high,
"val_low": config.val_low,
"val_high": config.val_high,
"morph_kernel": config.morph_kernel,
"min_area_ratio": config.min_area_ratio,
})
if resp is None:
continue
mask_b64 = resp.get("mask_b64", "")
if mask_b64:
field_masks[frame.sequence] = _decode_mask_b64(mask_b64)
field_boundaries[frame.sequence] = resp.get("boundary", [])
field_coverage[frame.sequence] = resp.get("coverage", 0.0)
avg_coverage = sum(field_coverage.values()) / max(len(field_coverage), 1)
emit.log(job_id, "FieldSegmentation", "INFO",
f"Done: {len(frames)} frames, avg coverage {avg_coverage:.1%}")
return {
"field_masks": field_masks,
"field_boundaries": field_boundaries,
"field_coverage": field_coverage,
}

View File

@@ -0,0 +1,93 @@
"""
Stage 1 — Frame Extraction
Extracts frames from a video at a configurable FPS using the core ffmpeg module.
Emits log + stats_update SSE events as it works.
"""
from __future__ import annotations
import tempfile
import time
from pathlib import Path
import ffmpeg
import numpy as np
from PIL import Image
from core.ffmpeg.probe import probe_file
from core.detect import emit
from core.detect.models import Frame
from core.detect.stages.models import FrameExtractionConfig
def _load_frames(tmpdir: Path, fps: float) -> list[Frame]:
"""Load extracted JPEG files into Frame objects."""
frame_files = sorted(tmpdir.glob("frame_*.jpg"))
frames = []
for i, fpath in enumerate(frame_files):
img = Image.open(fpath)
frame = Frame(
sequence=i,
chunk_id=0,
timestamp=i / fps,
image=np.array(img),
)
frames.append(frame)
return frames
def extract_frames(
video_path: str,
config: FrameExtractionConfig,
job_id: str | None = None,
) -> list[Frame]:
"""
Extract frames from video at the configured FPS.
Uses ffmpeg-python to build the extraction pipeline,
outputs JPEG files to a temp dir, then loads as numpy arrays.
"""
probe = probe_file(video_path)
duration = probe.duration or 0.0
emit.log(job_id, "FrameExtractor", "INFO",
f"Starting extraction: {Path(video_path).name} "
f"({duration:.1f}s, {probe.width}x{probe.height}, fps={config.fps})")
emit.log(job_id, "FrameExtractor", "DEBUG",
f"Probe: codec={probe.video_codec}, bitrate={probe.video_bitrate}, max_frames={config.max_frames}")
with tempfile.TemporaryDirectory() as tmpdir:
pattern = str(Path(tmpdir) / "frame_%06d.jpg")
stream = (
ffmpeg
.input(video_path)
.filter("fps", fps=config.fps)
.output(pattern, qscale=2, frames=config.max_frames)
.overwrite_output()
)
t0 = time.monotonic()
try:
stream.run(capture_stdout=True, capture_stderr=True, quiet=True)
except ffmpeg.Error as e:
stderr = e.stderr.decode() if e.stderr else "unknown error"
emit.log(job_id, "FrameExtractor", "ERROR", f"FFmpeg failed: {stderr[:200]}")
raise RuntimeError(f"FFmpeg failed: {stderr}") from e
ffmpeg_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "FrameExtractor", "DEBUG", f"FFmpeg decode: {ffmpeg_ms:.0f}ms")
t0 = time.monotonic()
frames = _load_frames(Path(tmpdir), config.fps)
load_ms = (time.monotonic() - t0) * 1000
if frames:
h, w = frames[0].image.shape[:2]
mem_mb = sum(f.image.nbytes for f in frames) / (1024 * 1024)
emit.log(job_id, "FrameExtractor", "DEBUG",
f"Loaded {len(frames)} frames ({w}x{h}) in {load_ms:.0f}ms, {mem_mb:.1f}MB in memory")
emit.log(job_id, "FrameExtractor", "INFO", f"Extracted {len(frames)} frames")
emit.stats(job_id, frames_extracted=len(frames))
return frames

View File

@@ -0,0 +1,106 @@
"""
Pydantic Models - GENERATED FILE
Do not edit directly. Regenerate using modelgen.
"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
class StageConfigField(BaseModel):
"""A single tunable config parameter for the editor UI."""
name: str
type: str
default: Any
description: str = ""
min: Optional[float] = None
max: Optional[float] = None
options: Optional[List[str]] = None
class StageIO(BaseModel):
"""Declares what a stage reads and writes."""
reads: List[str] = Field(default_factory=list)
writes: List[str] = Field(default_factory=list)
optional_reads: List[str] = Field(default_factory=list)
class StageDefinition(BaseModel):
"""Complete metadata for a pipeline stage."""
name: str
label: str
description: str
category: str = "detection"
io: StageIO
config_fields: List[StageConfigField] = Field(default_factory=list)
tracks_element: Optional[str] = None
class FrameExtractionConfig(BaseModel):
"""FrameExtractionConfig(fps: float = 2.0, max_frames: int = 500)"""
fps: float = 2.0
max_frames: int = 500
class SceneFilterConfig(BaseModel):
"""SceneFilterConfig(hamming_threshold: int = 8, enabled: bool = True)"""
hamming_threshold: int = 8
enabled: bool = True
class DetectionConfig(BaseModel):
"""DetectionConfig(model_name: str = 'yolov8n.pt', confidence_threshold: float = 0.3, target_classes: List[str] = <factory>)"""
model_name: str = "yolov8n.pt"
confidence_threshold: float = 0.3
target_classes: List[str]
class OCRConfig(BaseModel):
"""OCRConfig(languages: List[str] = <factory>, min_confidence: float = 0.5)"""
languages: List[str]
min_confidence: float = 0.5
class ResolverConfig(BaseModel):
"""ResolverConfig(fuzzy_threshold: int = 75)"""
fuzzy_threshold: int = 75
class RegionAnalysisConfig(BaseModel):
"""RegionAnalysisConfig(enabled: bool = True, edge_canny_low: int = 50, edge_canny_high: int = 150, edge_hough_threshold: int = 80, edge_hough_min_length: int = 100, edge_hough_max_gap: int = 10, edge_pair_max_distance: int = 200, edge_pair_min_distance: int = 15)"""
enabled: bool = True
edge_canny_low: int = 50
edge_canny_high: int = 150
edge_hough_threshold: int = 80
edge_hough_min_length: int = 100
edge_hough_max_gap: int = 10
edge_pair_max_distance: int = 200
edge_pair_min_distance: int = 15
class FieldSegmentationConfig(BaseModel):
"""FieldSegmentationConfig(enabled: bool = True, hue_low: int = 30, hue_high: int = 85, sat_low: int = 30, sat_high: int = 255, val_low: int = 30, val_high: int = 255, morph_kernel: int = 15, min_area_ratio: float = 0.05)"""
enabled: bool = True
hue_low: int = 30
hue_high: int = 85
sat_low: int = 30
sat_high: int = 255
val_low: int = 30
val_high: int = 255
morph_kernel: int = 15
min_area_ratio: float = 0.05
class StageRef(BaseModel):
"""Reference to a stage in the pipeline graph."""
name: str
branch: str = "trunk"
execution_target: str = "local"
class Edge(BaseModel):
"""Connection between stages in the graph."""
source: str
target: str
condition: str = ""
class PipelineConfig(BaseModel):
"""Pipeline graph topology + routing rules."""
name: str
profile_name: str
stages: List[StageRef] = Field(default_factory=list)
edges: List[Edge] = Field(default_factory=list)
routing_rules: Dict[str, Any]

View File

@@ -0,0 +1,139 @@
"""
Stage 4 — OCR
Reads text from detected regions (YOLO bounding box crops).
Two modes:
- remote: calls inference server over HTTP (separate GPU box, or localhost)
- local: runs PaddleOCR in-process (single-box setup with enough VRAM)
The mode is selected by whether inference_url is provided.
Model instances are cached at module level so they survive across pipeline runs.
"""
from __future__ import annotations
import logging
import time
from typing import TYPE_CHECKING
import numpy as np
from core.detect import emit
from core.detect.models import BoundingBox, Frame, TextCandidate
from core.detect.stages.models import OCRConfig
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
# Module-level cache — avoids reloading the model for every crop or pipeline run
_local_ocr_cache: dict[str, object] = {}
def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray:
h, w = frame.image.shape[:2]
x1 = max(0, box.x)
y1 = max(0, box.y)
x2 = min(w, box.x + box.w)
y2 = min(h, box.y + box.h)
return frame.image[y1:y2, x1:x2]
def _get_local_model(lang: str):
if lang not in _local_ocr_cache:
from paddleocr import PaddleOCR
logger.info("Loading PaddleOCR locally (lang=%s)", lang)
_local_ocr_cache[lang] = PaddleOCR(lang=lang)
return _local_ocr_cache[lang]
def _parse_ocr_raw(raw, min_confidence: float) -> list[dict]:
"""Parse PaddleOCR 3.x result — handles dict-based and nested-list layouts."""
results = []
for page in (raw or []):
if not page:
continue
if isinstance(page, dict):
for text, confidence in zip(page.get("rec_texts", []), page.get("rec_scores", [])):
if float(confidence) >= min_confidence:
results.append({"text": text, "confidence": float(confidence)})
continue
for line in page:
if not line:
continue
rec = line[1]
if isinstance(rec, (list, tuple)) and len(rec) >= 2:
text, confidence = rec[0], rec[1]
if float(confidence) >= min_confidence:
results.append({"text": text, "confidence": float(confidence)})
return results
def run_ocr(
frames: list[Frame],
boxes_by_frame: dict[int, list[BoundingBox]],
config: OCRConfig,
inference_url: str | None = None,
job_id: str | None = None,
) -> list[TextCandidate]:
"""
Run OCR on cropped regions from YOLO detections.
inference_url=None → local in-process PaddleOCR (single-box)
inference_url=str → remote inference server (split or localhost)
"""
total_regions = sum(len(boxes) for boxes in boxes_by_frame.values())
mode = "remote" if inference_url else "local"
emit.log(job_id, "OCRStage", "INFO",
f"Running OCR on {total_regions} regions (mode={mode})")
# Build these once per pipeline run, not per crop
if inference_url:
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
else:
model = _get_local_model(config.languages[0])
frame_map = {f.sequence: f for f in frames}
candidates: list[TextCandidate] = []
for seq, boxes in boxes_by_frame.items():
frame = frame_map.get(seq)
if not frame:
continue
for box in boxes:
crop = _crop_region(frame, box)
if crop.size == 0:
continue
t0 = time.monotonic()
if inference_url:
raw_results = client.ocr(image=crop, languages=config.languages)
texts = [{"text": r.text, "confidence": r.confidence} for r in raw_results]
else:
raw = model.ocr(crop)
texts = _parse_ocr_raw(raw, config.min_confidence)
ocr_ms = (time.monotonic() - t0) * 1000
h, w = crop.shape[:2]
text_preview = ", ".join(t["text"][:30] for t in texts) if texts else "(none)"
emit.log(job_id, "OCRStage", "DEBUG",
f"Frame {seq} box {box.x},{box.y} ({w}x{h}): {ocr_ms:.0f}ms → {text_preview}")
for t in texts:
candidates.append(TextCandidate(
frame=frame,
bbox=box,
text=t["text"],
ocr_confidence=t["confidence"],
))
emit.log(job_id, "OCRStage", "INFO",
f"Extracted text from {len(candidates)} regions")
emit.stats(job_id, regions_resolved_by_ocr=len(candidates))
return candidates

View File

@@ -0,0 +1,128 @@
"""
Stage 3.5 — Preprocessing
Runs between YOLO detection and OCR. Applies configurable image
preprocessing to each detected region crop: contrast enhancement,
deskewing, binarization.
Operates on the crops derived from boxes_by_frame, produces
preprocessed_crops keyed by (frame_sequence, box_index).
"""
from __future__ import annotations
import logging
import numpy as np
from core.detect import emit
from core.detect.models import BoundingBox, Frame
logger = logging.getLogger(__name__)
def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray:
h, w = frame.image.shape[:2]
x1 = max(0, box.x)
y1 = max(0, box.y)
x2 = min(w, box.x + box.w)
y2 = min(h, box.y + box.h)
return frame.image[y1:y2, x1:x2]
def preprocess_regions(
frames: list[Frame],
boxes_by_frame: dict[int, list[BoundingBox]],
do_contrast: bool = True,
do_deskew: bool = False,
do_binarize: bool = False,
inference_url: str | None = None,
job_id: str | None = None,
) -> dict[str, np.ndarray]:
"""
Preprocess cropped regions from YOLO detections.
Returns dict keyed by "{frame_seq}_{box_idx}" → preprocessed crop.
These are passed to the OCR stage instead of raw crops.
"""
total_regions = sum(len(boxes) for boxes in boxes_by_frame.values())
any_active = do_contrast or do_deskew or do_binarize
if not any_active:
emit.log(job_id, "Preprocess", "INFO",
f"Preprocessing disabled, passing {total_regions} regions through")
return {}
mode = "remote" if inference_url else "local"
emit.log(job_id, "Preprocess", "INFO",
f"Preprocessing {total_regions} regions (mode={mode}, "
f"contrast={do_contrast}, deskew={do_deskew}, binarize={do_binarize})")
frame_map = {f.sequence: f for f in frames}
preprocessed: dict[str, np.ndarray] = {}
processed_count = 0
for seq, boxes in boxes_by_frame.items():
frame = frame_map.get(seq)
if not frame:
continue
for idx, box in enumerate(boxes):
crop = _crop_region(frame, box)
if crop.size == 0:
continue
key = f"{seq}_{idx}"
if inference_url:
result = _preprocess_remote(crop, inference_url,
do_contrast, do_deskew, do_binarize)
else:
result = _preprocess_local(crop, do_contrast, do_deskew, do_binarize)
preprocessed[key] = result
processed_count += 1
emit.log(job_id, "Preprocess", "INFO",
f"Preprocessed {processed_count} regions")
return preprocessed
def _preprocess_remote(crop: np.ndarray, inference_url: str,
do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
"""Call GPU server /preprocess endpoint."""
import base64
import io
import requests
from PIL import Image
img = Image.fromarray(crop)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
image_b64 = base64.b64encode(buf.getvalue()).decode()
resp = requests.post(
f"{inference_url.rstrip('/')}/preprocess",
json={
"image": image_b64,
"contrast": do_contrast,
"deskew": do_deskew,
"binarize": do_binarize,
},
timeout=30,
)
resp.raise_for_status()
data = resp.json()
result_bytes = base64.b64decode(data["image"])
result_img = Image.open(io.BytesIO(result_bytes)).convert("RGB")
return np.array(result_img)
def _preprocess_local(crop: np.ndarray,
do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
"""Run preprocessing in-process (requires opencv-python-headless)."""
from core.gpu.models.preprocess import preprocess
return preprocess(crop, do_binarize=do_binarize, do_deskew=do_deskew, do_contrast=do_contrast)

View File

@@ -0,0 +1,31 @@
"""
Stage registry — registers all built-in stages.
Split by category:
preprocessing.py — extract_frames, filter_scenes
cv_analysis.py — detect_edges (+ future: detect_contours, detect_color, merge_regions)
detection.py — detect_objects, run_ocr
resolution.py — match_brands
escalation.py — escalate_vlm, escalate_cloud
output.py — compile_report
_serializers.py — shared serialization helpers
"""
from . import preprocessing
from . import cv_analysis
from . import detection
from . import resolution
from . import escalation
from . import output
def register_all():
preprocessing.register()
cv_analysis.register()
detection.register()
resolution.register()
escalation.register()
output.register()
register_all()

View File

@@ -0,0 +1,25 @@
"""
Re-export serializers from core/schema/serializers/.
Stage registry modules import from here for convenience.
All serialization logic lives in core/schema/serializers/.
"""
from core.schema.serializers._common import (
safe_construct,
serialize_dataclass,
serialize_dataclass_list,
)
from core.schema.serializers.pipeline import (
serialize_frame_meta,
serialize_frames_with_upload as serialize_frames,
deserialize_frames_with_download as deserialize_frames,
serialize_text_candidate,
serialize_text_candidates,
deserialize_text_candidate,
deserialize_text_candidates,
deserialize_bounding_box,
deserialize_brand_detection,
deserialize_pipeline_stats,
deserialize_detection_report,
)

View File

@@ -0,0 +1,44 @@
"""Registration for CV analysis stages: edge detection."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import serialize_dataclass_list, deserialize_bounding_box
def _ser_regions(state: dict, job_id: str) -> dict:
regions = state.get("edge_regions_by_frame", {})
serialized = {
str(seq): serialize_dataclass_list(bl) for seq, bl in regions.items()
}
return {"edge_regions_by_frame": serialized}
def _deser_regions(data: dict, job_id: str) -> dict:
regions = {}
for seq_str, box_dicts in data.get("edge_regions_by_frame", {}).items():
regions[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
return {"edge_regions_by_frame": regions}
def register():
edge_detection = StageDefinition(
name="detect_edges",
label="Edge Detection",
description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
category="cv_analysis",
io=StageIO(
reads=["filtered_frames"],
writes=["edge_regions_by_frame"],
),
config_fields=[
StageConfigField(name="enabled", type="bool", default=True, description="Enable region analysis"),
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
],
)
register_stage(edge_detection, serialize_fn=_ser_regions, deserialize_fn=_deser_regions)

View File

@@ -0,0 +1,60 @@
"""Registration for detection stages: YOLO, OCR."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_bounding_box,
)
def _ser_detect(state: dict, job_id: str) -> dict:
boxes = state.get("boxes_by_frame", {})
serialized = {str(seq): serialize_dataclass_list(bl) for seq, bl in boxes.items()}
return {"boxes_by_frame": serialized}
def _deser_detect(data: dict, job_id: str) -> dict:
boxes = {}
for seq_str, box_dicts in data.get("boxes_by_frame", {}).items():
boxes[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
return {"boxes_by_frame": boxes}
def _ser_ocr(state: dict, job_id: str) -> dict:
candidates = state.get("text_candidates", [])
return {"text_candidates": serialize_text_candidates(candidates)}
def _deser_ocr(data: dict, job_id: str) -> dict:
return {"_text_candidates_raw": data["text_candidates"]}
def register():
yolo = StageDefinition(
name="detect_objects",
label="Object Detection",
description="YOLO object detection on filtered frames",
category="detection",
io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
config_fields=[
StageConfigField(name="model_name", type="str", default="yolov8n.pt", description="YOLO model file"),
StageConfigField(name="confidence_threshold", type="float", default=0.3, description="Min detection confidence", min=0.0, max=1.0),
StageConfigField(name="target_classes", type="list[str]", default=[], description="YOLO classes to detect (empty = all)"),
],
)
register_stage(yolo, serialize_fn=_ser_detect, deserialize_fn=_deser_detect)
ocr = StageDefinition(
name="run_ocr",
label="OCR",
description="Extract text from detected regions",
category="detection",
io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
config_fields=[
StageConfigField(name="languages", type="list[str]", default=["en"], description="OCR languages"),
StageConfigField(name="min_confidence", type="float", default=0.5, description="Min OCR confidence", min=0.0, max=1.0),
],
)
register_stage(ocr, serialize_fn=_ser_ocr, deserialize_fn=_deser_ocr)

View File

@@ -0,0 +1,60 @@
"""Registration for escalation stages: local VLM, cloud LLM."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_brand_detection,
)
def _ser_escalation(state: dict, job_id: str) -> dict:
matched = state.get("detections", [])
unresolved = state.get("unresolved_candidates", [])
return {
"detections": serialize_dataclass_list(matched),
"unresolved_candidates": serialize_text_candidates(unresolved),
}
def _deser_escalation(data: dict, job_id: str) -> dict:
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
return {
"detections": detections,
"_unresolved_raw": data.get("unresolved_candidates", []),
}
def register():
vlm = StageDefinition(
name="escalate_vlm",
label="Local VLM",
description="Process unresolved crops with moondream2",
category="escalation",
io=StageIO(
reads=["unresolved_candidates"],
writes=["detections", "unresolved_candidates"],
optional_reads=["source_asset_id"],
),
config_fields=[
StageConfigField(name="min_confidence", type="float", default=0.5, description="Min VLM confidence", min=0.0, max=1.0),
],
)
register_stage(vlm, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)
cloud = StageDefinition(
name="escalate_cloud",
label="Cloud LLM",
description="Escalate remaining crops to cloud provider",
category="escalation",
io=StageIO(
reads=["unresolved_candidates"],
writes=["detections"],
optional_reads=["source_asset_id"],
),
config_fields=[
StageConfigField(name="min_confidence", type="float", default=0.4, description="Min cloud confidence", min=0.0, max=1.0),
],
)
register_stage(cloud, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)

View File

@@ -0,0 +1,30 @@
"""Registration for output stages: report compilation."""
from core.detect.stages.base import StageDefinition, StageIO, register_stage
from ._serializers import serialize_dataclass, deserialize_detection_report
def _ser_report(state: dict, job_id: str) -> dict:
report = state.get("report")
if report is None:
return {"report": None}
return {"report": serialize_dataclass(report)}
def _deser_report(data: dict, job_id: str) -> dict:
raw = data.get("report")
if raw is None:
return {"report": None}
return {"report": deserialize_detection_report(raw)}
def register():
report = StageDefinition(
name="compile_report",
label="Report",
description="Merge detections and compile final report",
category="output",
io=StageIO(reads=["detections"], writes=["report"]),
config_fields=[],
)
register_stage(report, serialize_fn=_ser_report, deserialize_fn=_deser_report)

View File

@@ -0,0 +1,82 @@
"""Registration for preprocessing stages: frame extraction, scene filter, image preprocessing."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import serialize_frames, deserialize_frames
def _ser_extract(state: dict, job_id: str) -> dict:
frames = state.get("frames", [])
meta, manifest = serialize_frames(frames, job_id)
return {"frames_meta": meta, "frames_manifest": manifest}
def _deser_extract(data: dict, job_id: str) -> dict:
frames = deserialize_frames(data["frames_meta"], data["frames_manifest"], job_id)
return {"frames": frames}
def _ser_filter(state: dict, job_id: str) -> dict:
filtered = state.get("filtered_frames", [])
seqs = [f.sequence for f in filtered]
return {"filtered_frame_sequences": seqs}
def _deser_filter(data: dict, job_id: str) -> dict:
return {"_filtered_sequences": data["filtered_frame_sequences"]}
def _ser_preprocess(state: dict, job_id: str) -> dict:
# Preprocessed crops are numpy arrays — regenerable from frames + boxes + config
crops = state.get("preprocessed_crops", {})
return {"crop_keys": list(crops.keys()), "count": len(crops)}
def _deser_preprocess(data: dict, job_id: str) -> dict:
# Crops are regenerable — no need to restore from checkpoint
return {"preprocessed_crops": {}}
def register():
extract = StageDefinition(
name="extract_frames",
label="Frame Extraction",
description="Extract frames from video at configurable FPS",
category="preprocessing",
io=StageIO(reads=["video_path"], writes=["frames"]),
config_fields=[
StageConfigField(name="fps", type="float", default=2.0, description="Frames per second", min=0.1, max=30.0),
StageConfigField(name="max_frames", type="int", default=500, description="Maximum frames to extract", min=1, max=10000),
],
)
register_stage(extract, serialize_fn=_ser_extract, deserialize_fn=_deser_extract)
scene_filter = StageDefinition(
name="filter_scenes",
label="Scene Filter",
description="Deduplicate similar frames using perceptual hashing",
category="preprocessing",
io=StageIO(reads=["frames"], writes=["filtered_frames"]),
config_fields=[
StageConfigField(name="hamming_threshold", type="int", default=8, description="Hamming distance threshold", min=0, max=64),
StageConfigField(name="enabled", type="bool", default=True, description="Enable scene filtering"),
],
)
register_stage(scene_filter, serialize_fn=_ser_filter, deserialize_fn=_deser_filter)
preprocess = StageDefinition(
name="preprocess",
label="Preprocess",
description="Image preprocessing on detected regions before OCR",
category="preprocessing",
io=StageIO(
reads=["filtered_frames", "boxes_by_frame"],
writes=["preprocessed_crops"],
),
config_fields=[
StageConfigField(name="contrast", type="bool", default=True, description="CLAHE contrast enhancement"),
StageConfigField(name="deskew", type="bool", default=False, description="Correct slight rotation"),
StageConfigField(name="binarize", type="bool", default=False, description="Otsu binarization"),
],
)
register_stage(preprocess, serialize_fn=_ser_preprocess, deserialize_fn=_deser_preprocess)

View File

@@ -0,0 +1,44 @@
"""Registration for resolution stages: brand resolver."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_brand_detection,
)
def _ser_brands(state: dict, job_id: str) -> dict:
matched = state.get("detections", [])
unresolved = state.get("unresolved_candidates", [])
return {
"detections": serialize_dataclass_list(matched),
"unresolved_candidates": serialize_text_candidates(unresolved),
}
def _deser_brands(data: dict, job_id: str) -> dict:
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
return {
"detections": detections,
"_unresolved_raw": data.get("unresolved_candidates", []),
}
def register():
resolver = StageDefinition(
name="match_brands",
label="Brand Resolver",
description="Match OCR text against known brands (session + global DB)",
category="resolution",
io=StageIO(
reads=["text_candidates"],
writes=["detections", "unresolved_candidates"],
optional_reads=["session_brands", "source_asset_id"],
),
config_fields=[
StageConfigField(name="fuzzy_threshold", type="int", default=75, description="Fuzzy match threshold", min=0, max=100),
],
)
register_stage(resolver, serialize_fn=_ser_brands, deserialize_fn=_deser_brands)

View File

@@ -0,0 +1,86 @@
"""
Stage 2 — Scene Filter
Removes near-duplicate frames using perceptual hashing (pHash).
Frames with a hamming distance below the threshold are considered
duplicates and dropped. This dramatically reduces work for downstream
CV stages without losing unique visual content.
"""
from __future__ import annotations
import time
import imagehash
from PIL import Image
from core.detect import emit
from core.detect.models import Frame
from core.detect.stages.models import SceneFilterConfig
def _compute_hashes(frames: list[Frame]) -> list[imagehash.ImageHash]:
"""Compute perceptual hashes for all frames."""
hashes = []
for f in frames:
img = Image.fromarray(f.image)
h = imagehash.phash(img)
f.perceptual_hash = str(h)
hashes.append(h)
return hashes
def _dedup(frames: list[Frame], hashes: list[imagehash.ImageHash], threshold: int) -> list[Frame]:
"""Greedy dedup: keep a frame if it's sufficiently different from all kept frames."""
kept = [frames[0]]
kept_hashes = [hashes[0]]
for i in range(1, len(frames)):
is_duplicate = any(hashes[i] - kh < threshold for kh in kept_hashes)
if not is_duplicate:
kept.append(frames[i])
kept_hashes.append(hashes[i])
return kept
def scene_filter(
frames: list[Frame],
config: SceneFilterConfig,
job_id: str | None = None,
) -> list[Frame]:
"""
Filter near-duplicate frames based on perceptual hash distance.
Keeps the first frame in each group of similar frames.
Returns a new list — does not mutate the input.
"""
if not config.enabled:
emit.log(job_id, "SceneFilter", "INFO", "Scene filter disabled, passing all frames through")
return frames
if not frames:
return []
emit.log(job_id, "SceneFilter", "INFO",
f"Filtering {len(frames)} frames (hamming_threshold={config.hamming_threshold})")
t0 = time.monotonic()
hashes = _compute_hashes(frames)
hash_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "SceneFilter", "DEBUG",
f"Computed {len(hashes)} perceptual hashes in {hash_ms:.0f}ms ({hash_ms/max(len(hashes),1):.1f}ms/frame)")
t0 = time.monotonic()
kept = _dedup(frames, hashes, config.hamming_threshold)
dedup_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "SceneFilter", "DEBUG", f"Dedup pass: {dedup_ms:.0f}ms")
dropped = len(frames) - len(kept)
pct = (dropped / len(frames) * 100) if frames else 0
emit.log(job_id, "SceneFilter", "INFO",
f"Kept {len(kept)} frames, dropped {dropped} ({pct:.0f}% reduction)")
emit.stats(job_id, frames_extracted=len(frames), frames_after_scene_filter=len(kept))
return kept

View File

@@ -0,0 +1,201 @@
"""
Stage 7 — Cloud LLM escalation
Last resort for crops the local VLM couldn't resolve.
Provider-agnostic — switch via CLOUD_LLM_PROVIDER env var.
Each provider has its own file under detect/providers/.
Tracks token usage and cost.
"""
from __future__ import annotations
import base64
import io
import logging
import os
import time
import numpy as np
from PIL import Image
from core.detect import emit
from core.detect.models import BrandDetection, PipelineStats, TextCandidate
from core.detect.models import CropContext
from core.detect.providers import get_provider, has_api_key
logger = logging.getLogger(__name__)
ESTIMATED_TOKENS_PER_CROP = 500
def _register_discovered_brand(brand: str, source_asset_id: str | None,
timestamp: float, confidence: float):
"""Register a cloud-confirmed brand in the DB."""
try:
from core.detect.stages.brand_resolver import _register_brand, _record_sighting
brand_id = _register_brand(brand, "cloud_llm")
if brand_id and source_asset_id:
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, "cloud_llm")
except Exception as e:
logger.debug("Failed to register brand %s: %s", brand, e)
def _encode_crop(crop: np.ndarray) -> str:
img = Image.fromarray(crop)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
return base64.b64encode(buf.getvalue()).decode()
def _crop_image(candidate: TextCandidate) -> np.ndarray:
frame = candidate.frame
box = candidate.bbox
h, w = frame.image.shape[:2]
x1 = max(0, box.x)
y1 = max(0, box.y)
x2 = min(w, box.x + box.w)
y2 = min(h, box.y + box.h)
return frame.image[y1:y2, x1:x2]
def _parse_response(answer: str, total_tokens: int) -> dict:
"""Parse LLM free-text response into structured output."""
parts = [p.strip() for p in answer.split(",", 2)]
brand = parts[0] if parts else ""
confidence = 0.5
reasoning = answer
if len(parts) >= 2:
try:
confidence = float(parts[1])
confidence = max(0.0, min(1.0, confidence))
except ValueError:
pass
if len(parts) >= 3:
reasoning = parts[2]
return {
"brand": brand,
"confidence": confidence,
"reasoning": reasoning,
"tokens": total_tokens or ESTIMATED_TOKENS_PER_CROP,
}
def _call_cloud_api(image_b64: str, prompt: str) -> dict:
"""Route to the configured provider and parse the response."""
provider = get_provider()
result = provider.call(image_b64, prompt)
return _parse_response(result.answer, result.total_tokens)
def escalate_cloud(
candidates: list[TextCandidate],
vlm_prompt_fn,
stats: PipelineStats,
min_confidence: float = 0.4,
content_type: str = "",
source_asset_id: str | None = None,
job_id: str | None = None,
) -> list[BrandDetection]:
"""
Send remaining unresolved crops to cloud LLM.
Provider is selected via CLOUD_LLM_PROVIDER env var (groq, gemini, openai).
Updates stats with call count and cost.
"""
if not candidates:
return []
if os.environ.get("SKIP_CLOUD", "").strip() == "1":
emit.log(job_id, "CloudLLM", "INFO",
f"SKIP_CLOUD=1, skipping {len(candidates)} crops")
return []
if not has_api_key():
emit.log(job_id, "CloudLLM", "WARNING",
f"No API key set for cloud provider, skipping {len(candidates)} crops")
return []
provider = get_provider()
emit.log(job_id, "CloudLLM", "INFO",
f"Escalating {len(candidates)} crops to {provider.name}")
matched: list[BrandDetection] = []
total_cost = 0.0
for i, candidate in enumerate(candidates):
crop = _crop_image(candidate)
if crop.size == 0:
continue
crop_context = CropContext(
image=b"",
surrounding_text=candidate.text,
position_hint=f"frame {candidate.frame.sequence}",
)
prompt = vlm_prompt_fn(crop_context)
image_b64 = _encode_crop(crop)
t0 = time.monotonic()
try:
result = _call_cloud_api(image_b64, prompt)
except Exception as e:
call_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "CloudLLM", "DEBUG",
f"[{i+1}/{len(candidates)}] FAILED '{candidate.text[:30]}': {e} ({call_ms:.0f}ms)")
continue
call_ms = (time.monotonic() - t0) * 1000
stats.cloud_llm_calls += 1
model_info = provider.models.get(provider.model)
cost_per_token = model_info.cost_per_input_token if model_info else 0.00001
call_cost = result["tokens"] * cost_per_token
total_cost += call_cost
brand = result["brand"]
confidence = result["confidence"]
emit.log(job_id, "CloudLLM", "DEBUG",
f"[{i+1}/{len(candidates)}] '{candidate.text[:30]}'"
f"{'' + brand if brand else ''} "
f"(conf={confidence:.2f}, {result['tokens']}tok, ${call_cost:.4f}, {call_ms:.0f}ms)")
if brand and confidence >= min_confidence:
detection = BrandDetection(
brand=brand,
timestamp=candidate.frame.timestamp,
duration=0.5,
confidence=confidence,
source="cloud_llm",
bbox=candidate.bbox,
frame_ref=candidate.frame.sequence,
content_type=content_type,
)
matched.append(detection)
emit.detection(
job_id,
brand=brand,
confidence=confidence,
source="cloud_llm",
timestamp=candidate.frame.timestamp,
content_type=content_type,
frame_ref=candidate.frame.sequence,
)
# Register newly discovered brand in DB
_register_discovered_brand(brand, source_asset_id,
candidate.frame.timestamp, confidence)
stats.estimated_cloud_cost_usd += total_cost
stats.regions_escalated_to_cloud_llm = len(candidates)
emit.log(job_id, "CloudLLM", "INFO",
f"Cloud resolved {len(matched)}/{len(candidates)}"
f"cost ${total_cost:.4f} ({stats.cloud_llm_calls} calls total)")
return matched

View File

@@ -0,0 +1,157 @@
"""
Stage 6 — Local VLM escalation (moondream2)
Processes unresolved text candidates by sending crop images + prompt
to the local VLM on the inference server. Produces BrandDetection
objects for crops the VLM can identify.
"""
from __future__ import annotations
import logging
import os
import time
import numpy as np
from core.detect import emit
from core.detect.models import BrandDetection, TextCandidate
from core.detect.models import CropContext
logger = logging.getLogger(__name__)
def _register_discovered_brand(brand: str, source_asset_id: str | None,
timestamp: float, confidence: float, source: str):
"""Register a VLM-confirmed brand in the DB."""
try:
from core.detect.stages.brand_resolver import _register_brand, _record_sighting
brand_id = _register_brand(brand, source)
if brand_id and source_asset_id:
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, source)
except Exception as e:
logger.debug("Failed to register brand %s: %s", brand, e)
def _crop_image(candidate: TextCandidate) -> np.ndarray:
frame = candidate.frame
box = candidate.bbox
h, w = frame.image.shape[:2]
x1 = max(0, box.x)
y1 = max(0, box.y)
x2 = min(w, box.x + box.w)
y2 = min(h, box.y + box.h)
return frame.image[y1:y2, x1:x2]
def escalate_vlm(
candidates: list[TextCandidate],
vlm_prompt_fn,
inference_url: str | None = None,
min_confidence: float = 0.5,
content_type: str = "",
source_asset_id: str | None = None,
job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
"""
Send unresolved crops to local VLM for brand identification.
Returns:
- matched: BrandDetections the VLM confirmed
- still_unresolved: candidates the VLM couldn't resolve (→ cloud escalation)
"""
if not candidates:
return [], []
if os.environ.get("SKIP_VLM", "").strip() == "1":
emit.log(job_id, "VLMLocal", "INFO",
f"SKIP_VLM=1, skipping {len(candidates)} crops")
return [], candidates
emit.log(job_id, "VLMLocal", "INFO",
f"Processing {len(candidates)} unresolved crops with moondream2")
matched: list[BrandDetection] = []
still_unresolved: list[TextCandidate] = []
if inference_url:
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
for i, candidate in enumerate(candidates):
crop = _crop_image(candidate)
if crop.size == 0:
still_unresolved.append(candidate)
continue
crop_context = CropContext(
image=b"", # not used for prompt generation
surrounding_text=candidate.text,
position_hint=f"frame {candidate.frame.sequence}",
)
prompt = vlm_prompt_fn(crop_context)
t0 = time.monotonic()
try:
if inference_url:
result = client.vlm(image=crop, prompt=prompt)
brand = result.brand
confidence = result.confidence
reasoning = result.reasoning
else:
brand, confidence, reasoning = _vlm_local(crop, prompt)
except Exception as e:
vlm_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "VLMLocal", "DEBUG",
f"[{i+1}/{len(candidates)}] FAILED '{candidate.text[:30]}': {e} ({vlm_ms:.0f}ms)")
still_unresolved.append(candidate)
continue
vlm_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "VLMLocal", "DEBUG",
f"[{i+1}/{len(candidates)}] '{candidate.text[:30]}'"
f"{'' + brand if brand else '✗ unresolved'} "
f"(conf={confidence:.2f}, {vlm_ms:.0f}ms)")
if brand and confidence >= min_confidence:
detection = BrandDetection(
brand=brand,
timestamp=candidate.frame.timestamp,
duration=0.5,
confidence=confidence,
source="local_vlm",
bbox=candidate.bbox,
frame_ref=candidate.frame.sequence,
content_type=content_type,
)
matched.append(detection)
emit.detection(
job_id,
brand=brand,
confidence=confidence,
source="local_vlm",
timestamp=candidate.frame.timestamp,
content_type=content_type,
frame_ref=candidate.frame.sequence,
)
# Register newly discovered brand in DB
_register_discovered_brand(brand, source_asset_id,
candidate.frame.timestamp, confidence, "local_vlm")
logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning)
else:
still_unresolved.append(candidate)
emit.log(job_id, "VLMLocal", "INFO",
f"VLM resolved {len(matched)}, unresolved {len(still_unresolved)} → cloud")
return matched, still_unresolved
def _vlm_local(crop: np.ndarray, prompt: str) -> tuple[str, float, str]:
"""Run moondream2 in-process (single-box mode)."""
from core.gpu.models.vlm import query
result = query(crop, prompt)
return result["brand"], result["confidence"], result["reasoning"]

View File

@@ -0,0 +1,138 @@
"""
Stage 3 — YOLO Object Detection
Detects regions of interest (logos, text, banners) in frames.
Two modes:
- Remote: calls inference server over HTTP (GPU on another machine)
- Local: imports ultralytics directly (GPU on same machine)
Emits frame_update events with bounding boxes for the UI.
"""
from __future__ import annotations
import base64
import io
import logging
import time
from PIL import Image
from core.detect import emit
from core.detect.models import BoundingBox, Frame
from core.detect.stages.models import DetectionConfig
logger = logging.getLogger(__name__)
def _frame_to_b64(frame: Frame) -> str:
"""Encode frame as base64 JPEG for SSE frame_update events."""
img = Image.fromarray(frame.image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=70)
return base64.b64encode(buf.getvalue()).decode()
def _detect_remote(frame: Frame, config: DetectionConfig, inference_url: str,
job_id: str = "", log_level: str = "INFO") -> list[BoundingBox]:
"""Call the inference server over HTTP."""
from core.detect.inference import InferenceClient
client = InferenceClient(base_url=inference_url, job_id=job_id, log_level=log_level)
results = client.detect(
image=frame.image,
model=config.model_name,
confidence=config.confidence_threshold,
target_classes=config.target_classes,
)
boxes = []
for r in results:
box = BoundingBox(
x=r.x, y=r.y, w=r.w, h=r.h,
confidence=r.confidence, label=r.label,
)
boxes.append(box)
return boxes
def _detect_local(frame: Frame, config: DetectionConfig) -> list[BoundingBox]:
"""Run YOLO in-process (requires ultralytics installed)."""
from ultralytics import YOLO
model = YOLO(config.model_name)
results = model(frame.image, conf=config.confidence_threshold, verbose=False)
boxes = []
for r in results:
for det in r.boxes:
x1, y1, x2, y2 = det.xyxy[0].tolist()
label = r.names[int(det.cls[0])]
if config.target_classes and label not in config.target_classes:
continue
box = BoundingBox(
x=int(x1), y=int(y1),
w=int(x2 - x1), h=int(y2 - y1),
confidence=float(det.conf[0]),
label=label,
)
boxes.append(box)
return boxes
def detect_objects(
frames: list[Frame],
config: DetectionConfig,
inference_url: str | None = None,
job_id: str | None = None,
) -> dict[int, list[BoundingBox]]:
"""
Run object detection on all frames.
If inference_url is provided, calls the remote GPU server.
Otherwise, imports ultralytics and runs locally.
Returns a dict mapping frame sequence → list of bounding boxes.
"""
mode = "remote" if inference_url else "local"
emit.log(job_id, "YOLODetector", "INFO",
f"Detecting objects in {len(frames)} frames "
f"(model={config.model_name}, conf={config.confidence_threshold}, mode={mode})")
all_boxes: dict[int, list[BoundingBox]] = {}
total_regions = 0
for i, frame in enumerate(frames):
t0 = time.monotonic()
if inference_url:
from core.detect.emit import _run_log_level
boxes = _detect_remote(frame, config, inference_url,
job_id=job_id or "", log_level=_run_log_level)
else:
boxes = _detect_local(frame, config)
det_ms = (time.monotonic() - t0) * 1000
all_boxes[frame.sequence] = boxes
total_regions += len(boxes)
emit.log(job_id, "YOLODetector", "DEBUG",
f"Frame {frame.sequence}: {len(boxes)} regions in {det_ms:.0f}ms"
f" [{', '.join(b.label for b in boxes)}]" if boxes else
f"Frame {frame.sequence}: 0 regions in {det_ms:.0f}ms")
if boxes and job_id:
box_dicts = [{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
"confidence": b.confidence, "label": b.label}
for b in boxes]
emit.frame_update(
job_id,
frame_ref=frame.sequence,
timestamp=frame.timestamp,
jpeg_b64=_frame_to_b64(frame),
boxes=box_dicts,
)
emit.log(job_id, "YOLODetector", "INFO",
f"Detected {total_regions} regions across {len(frames)} frames")
emit.stats(job_id, regions_detected=total_regions)
return all_boxes

43
core/detect/state.py Normal file
View File

@@ -0,0 +1,43 @@
"""
LangGraph state definition for the detection pipeline.
This TypedDict flows through all graph nodes. Each node reads what
it needs and writes its outputs. LangGraph manages the state transitions.
"""
from __future__ import annotations
from typing import TypedDict
from core.detect.models import BoundingBox, BrandDetection, DetectionReport, Frame, PipelineStats, TextCandidate
class DetectState(TypedDict, total=False):
# Input
video_path: str
job_id: str
profile_name: str
source_asset_id: str # UUID of the source MediaAsset
# Stage outputs
frames: list[Frame]
filtered_frames: list[Frame]
field_masks: dict # {seq: np.ndarray} — pitch mask per frame
field_boundaries: dict # {seq: [(x,y), ...]} — pitch boundary per frame
field_coverage: dict # {seq: float} — pitch coverage ratio per frame
edge_regions_by_frame: dict[int, list[BoundingBox]]
boxes_by_frame: dict[int, list[BoundingBox]]
preprocessed_crops: dict # "{frame_seq}_{box_idx}" → np.ndarray
text_candidates: list[TextCandidate]
unresolved_candidates: list[TextCandidate]
detections: list[BrandDetection]
report: DetectionReport
# Session brands (accumulated during the run, persisted to DB)
session_brands: dict # {normalized_name: canonical_name}
# Running stats (updated by each stage)
stats: PipelineStats
# Config overrides for replay (merged into profile configs dict)
config_overrides: dict

131
core/detect/tracing.py Normal file
View File

@@ -0,0 +1,131 @@
"""
Langfuse tracing for the detection pipeline.
Provides span helpers that graph nodes use to record timing, frame counts,
and stage-level metadata. The Langfuse client is optional — if not configured
(no LANGFUSE_SECRET_KEY), tracing is a no-op.
Usage in graph nodes:
from core.detect.tracing import trace_node
def node_extract_frames(state):
with trace_node(state, "extract_frames") as span:
...
span.set_output({"frames": len(frames)})
return {...}
"""
from __future__ import annotations
import logging
import os
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
_client = None
_enabled: bool | None = None
def _get_client():
"""Lazy-init Langfuse client. Returns None if not configured."""
global _client, _enabled
if _enabled is False:
return None
if _client is not None:
return _client
secret = os.environ.get("LANGFUSE_SECRET_KEY", "")
if not secret:
_enabled = False
logger.info("Langfuse not configured (no LANGFUSE_SECRET_KEY), tracing disabled")
return None
try:
from langfuse import Langfuse
_client = Langfuse()
_enabled = True
logger.info("Langfuse tracing enabled")
return _client
except Exception as e:
_enabled = False
logger.warning("Langfuse init failed: %s — tracing disabled", e)
return None
@dataclass
class SpanContext:
"""Wraps a Langfuse span with convenience methods."""
_span: object | None = None
_start: float = field(default_factory=time.monotonic)
metadata: dict = field(default_factory=dict)
def set_output(self, output: dict) -> None:
self.metadata.update(output)
def set_error(self, error: str) -> None:
self.metadata["error"] = error
def _finish(self, status: str = "ok") -> None:
elapsed = time.monotonic() - self._start
self.metadata["duration_seconds"] = round(elapsed, 3)
self.metadata["status"] = status
if self._span is not None:
try:
self._span.update(
output=self.metadata,
level="ERROR" if status == "error" else "DEFAULT",
)
self._span.end()
except Exception as e:
logger.debug("Failed to end Langfuse span: %s", e)
@contextmanager
def trace_node(state: dict, node_name: str):
"""
Context manager that creates a Langfuse span for a pipeline node.
Usage:
with trace_node(state, "extract_frames") as span:
frames = do_work()
span.set_output({"frames": len(frames)})
"""
job_id = state.get("job_id", "unknown")
profile = state.get("profile_name", "")
client = _get_client()
span_obj = None
if client is not None:
try:
trace = client.trace(
name=f"detect:{job_id}",
session_id=job_id,
metadata={"profile": profile},
)
span_obj = trace.span(
name=node_name,
input={"job_id": job_id, "profile": profile},
)
except Exception as e:
logger.debug("Failed to create Langfuse span: %s", e)
ctx = SpanContext(_span=span_obj)
try:
yield ctx
ctx._finish("ok")
except Exception:
ctx._finish("error")
raise
def flush():
"""Flush pending Langfuse events. Call at pipeline end."""
if _client is not None:
try:
_client.flush()
except Exception as e:
logger.debug("Langfuse flush failed: %s", e)