phase 4
This commit is contained in:
0
core/detect/__init__.py
Normal file
0
core/detect/__init__.py
Normal file
19
core/detect/checkpoint/__init__.py
Normal file
19
core/detect/checkpoint/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Checkpoint system — Timeline + Checkpoint tree.
|
||||
|
||||
detect/checkpoint/
|
||||
frames.py — frame image S3 upload/download
|
||||
storage.py — Timeline + Checkpoint (Postgres + MinIO)
|
||||
replay.py — replay (TODO: migrate to new model)
|
||||
runner_bridge.py — checkpoint hook for PipelineRunner
|
||||
"""
|
||||
|
||||
from .storage import (
|
||||
create_timeline,
|
||||
get_timeline_frames,
|
||||
get_timeline_frames_b64,
|
||||
save_stage_output,
|
||||
load_stage_output,
|
||||
)
|
||||
from .frames import save_frames, load_frames
|
||||
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state, get_timeline_id
|
||||
111
core/detect/checkpoint/frames.py
Normal file
111
core/detect/checkpoint/frames.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Frame image storage — save/load to S3/MinIO as JPEGs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from core.detect.models import Frame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BUCKET = os.environ.get("S3_BUCKET", "mpr")
|
||||
CHECKPOINT_PREFIX = "checkpoints"
|
||||
|
||||
|
||||
def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]:
    """
    Persist frame images to S3 as JPEG files.

    Each frame is written under ``checkpoints/<job_id>/frames/<sequence>.jpg``.

    Returns:
        Manifest mapping frame sequence number -> S3 object key.
    """
    from core.storage.s3 import upload_file

    manifest: dict[int, str] = {}

    for frame in frames:
        object_key = f"{CHECKPOINT_PREFIX}/{job_id}/frames/{frame.sequence}.jpg"

        # Encode to a temporary JPEG on disk — upload_file wants a path.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            Image.fromarray(frame.image).save(tmp, format="JPEG", quality=85)
            local_path = tmp.name

        try:
            upload_file(local_path, BUCKET, object_key)
        finally:
            os.unlink(local_path)  # always remove the temp file, even on upload failure

        manifest[frame.sequence] = object_key

    logger.info("Saved %d frames to s3://%s/%s/%s/frames/",
                len(frames), BUCKET, CHECKPOINT_PREFIX, job_id)
    return manifest
|
||||
|
||||
|
||||
def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Frame]:
    """
    Download frame images from S3 and rebuild Frame objects.

    Args:
        manifest: mapping of frame sequence -> S3 object key.
        frame_metadata: dicts carrying sequence, chunk_id, timestamp,
            perceptual_hash. Missing entries fall back to zero/empty defaults.

    Returns:
        Frames sorted by sequence number.
    """
    from core.storage.s3 import download_to_temp

    meta_by_seq = {entry["sequence"]: entry for entry in frame_metadata}
    result: list[Frame] = []

    for seq, key in manifest.items():
        local_path = download_to_temp(BUCKET, key)
        try:
            pixels = np.array(Image.open(local_path).convert("RGB"))
        finally:
            os.unlink(local_path)  # temp file no longer needed once decoded

        meta = meta_by_seq.get(seq, {})
        result.append(
            Frame(
                sequence=seq,
                chunk_id=meta.get("chunk_id", 0),
                timestamp=meta.get("timestamp", 0.0),
                image=pixels,
                perceptual_hash=meta.get("perceptual_hash", ""),
            )
        )

    return sorted(result, key=lambda f: f.sequence)
|
||||
|
||||
|
||||
def load_frames_b64(manifest: dict[int, str], frame_metadata: list[dict]) -> list[dict]:
    """
    Download frames from S3 as base64-encoded JPEG — lightweight, no numpy.

    Returns:
        Dicts of the form {seq, timestamp, jpeg_b64}, sorted by seq.
    """
    import base64
    from core.storage.s3 import download_to_temp

    meta_by_seq = {entry["sequence"]: entry for entry in frame_metadata}
    out: list[dict] = []

    for seq, key in manifest.items():
        local_path = download_to_temp(BUCKET, key)
        try:
            with open(local_path, "rb") as fh:
                raw = fh.read()
        finally:
            os.unlink(local_path)

        meta = meta_by_seq.get(seq, {})
        out.append({
            "seq": seq,
            "timestamp": meta.get("timestamp", 0.0),
            "jpeg_b64": base64.b64encode(raw).decode(),
        })

    out.sort(key=lambda item: item["seq"])
    return out
|
||||
232
core/detect/checkpoint/replay.py
Normal file
232
core/detect/checkpoint/replay.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Pipeline replay — re-run from any stage with different config.
|
||||
|
||||
Loads a checkpoint, applies config overrides, builds a subgraph
|
||||
starting from the target stage, and invokes it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import uuid
|
||||
|
||||
from core.detect import emit
|
||||
# TODO: migrate to Timeline/Branch/Checkpoint model
|
||||
# These old functions no longer exist — replay needs rework
|
||||
def _not_migrated(*args, **kwargs):
    """Stand-in for checkpoint APIs that predate the Timeline/Branch/Checkpoint model."""
    raise NotImplementedError(
        "Replay not yet migrated to Timeline/Branch/Checkpoint model"
    )


# Old entry points point at the stand-in until replay is reworked.
load_checkpoint = _not_migrated
list_checkpoints = _not_migrated
|
||||
from core.detect.graph import NODES, build_graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
# OverrideProfile removed — config overrides are now handled by dict merging
|
||||
# in _load_profile() (nodes.py) and replay_single_stage (below).
|
||||
|
||||
|
||||
def replay_from(
    job_id: str,
    start_stage: str,
    config_overrides: dict | None = None,
    checkpoint: bool = True,
) -> dict:
    """
    Replay the pipeline from a specific stage.

    Loads the checkpoint of the stage immediately before ``start_stage``,
    applies config overrides, then runs the subgraph from ``start_stage``
    onward.

    Returns:
        The final state dict produced by the pipeline.

    Raises:
        ValueError: unknown stage, first stage, or missing checkpoint.
    """
    if start_stage not in NODES:
        raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")

    position = NODES.index(start_stage)
    if position == 0:
        raise ValueError("Cannot replay from the first stage — just run the full pipeline")

    prior = NODES[position - 1]
    existing = list_checkpoints(job_id)
    if prior not in existing:
        raise ValueError(
            f"No checkpoint for stage {prior!r} (job {job_id}). "
            f"Available: {existing}"
        )

    logger.info("Replaying job %s from %s (loading checkpoint: %s)",
                job_id, start_stage, prior)

    state = load_checkpoint(job_id, prior)
    if config_overrides:
        state["config_overrides"] = config_overrides

    # Tag every SSE event of this run with a fresh replay context.
    emit.set_run_context(
        run_id=str(uuid.uuid4())[:8],
        parent_job_id=job_id,
        run_type="replay",
    )

    # Build and run only the tail of the graph, starting at start_stage.
    pipeline = build_graph(checkpoint=checkpoint, start_from=start_stage).compile()
    try:
        return pipeline.invoke(state)
    finally:
        emit.clear_run_context()
|
||||
|
||||
|
||||
def replay_single_stage(
    job_id: str,
    stage: str,
    frame_refs: list[int] | None = None,
    config_overrides: dict | None = None,
    debug: bool = False,
) -> dict:
    """
    Replay a single stage on specific frames (or all frames from checkpoint).

    Fast path for interactive parameter tuning: runs only the target stage
    function, not the full pipeline tail, and returns its output directly.

    For detect_edges the result is {"edge_regions_by_frame": {seq: [box, ...]}};
    with debug=True it also carries {"debug": {seq: {edge_overlay_b64,
    lines_overlay_b64, ...}}} for visual feedback in the editor.

    Raises:
        ValueError: unknown stage, first stage, missing checkpoint, or a
            stage without single-stage support.
    """
    if stage not in NODES:
        raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}")

    position = NODES.index(stage)
    if position == 0:
        raise ValueError("Cannot replay the first stage — just run the full pipeline")

    prior = NODES[position - 1]
    existing = list_checkpoints(job_id)
    if prior not in existing:
        raise ValueError(
            f"No checkpoint for stage {prior!r} (job {job_id}). "
            f"Available: {existing}"
        )

    logger.info("Single-stage replay: job %s, stage %s (loading checkpoint: %s, debug=%s)",
                job_id, stage, prior, debug)

    state = load_checkpoint(job_id, prior)

    # Merge per-stage override dicts into a copy of the profile's configs.
    from core.detect.profile import get_profile
    profile = get_profile(state.get("profile_name", "soccer_broadcast"))
    if config_overrides:
        configs = dict(profile.get("configs", {}))
        for stage_name, overrides in config_overrides.items():
            if stage_name in configs:
                configs[stage_name] = {**configs[stage_name], **overrides}
            else:
                configs[stage_name] = overrides
        profile = {**profile, "configs": configs}

    # Dispatch to the per-stage fast path (only detect_edges so far).
    if stage == "detect_edges":
        return _replay_detect_edges(state, profile, frame_refs, job_id, debug)

    raise ValueError(
        f"Single-stage replay not yet implemented for {stage!r}. "
        f"Use replay_from() for full pipeline replay."
    )
|
||||
|
||||
|
||||
def _replay_detect_edges(
    state: dict,
    profile: dict,
    frame_refs: list[int] | None,
    job_id: str,
    debug: bool,
) -> dict:
    """
    Run edge detection on checkpoint frames, optionally with debug overlays.

    Args:
        state: checkpoint state; frames are read from state["filtered_frames"].
        profile: merged profile dict; the "detect_edges" stage config is
            extracted from it via get_stage_config.
        frame_refs: optional frame sequence numbers to restrict the run to.
        job_id: forwarded to the detector / inference client.
        debug: when True, also compute per-frame Canny/Hough overlay images.

    Returns:
        {"edge_regions_by_frame": ...} and, when debug=True and frames exist,
        an additional {"debug": {sequence: {edge_overlay_b64, lines_overlay_b64,
        horizontal_count, pair_count}}}.
    """
    import os
    from core.detect.stages.edge_detector import detect_edge_regions

    from core.detect.profile import get_stage_config
    from core.detect.stages.models import RegionAnalysisConfig
    config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
    frames = state.get("filtered_frames", [])

    # Optionally narrow to the requested frame sequences.
    if frame_refs:
        ref_set = set(frame_refs)
        frames = [f for f in frames if f.sequence in ref_set]

    # Remote inference when INFERENCE_URL is set; local CV module otherwise.
    inference_url = os.environ.get("INFERENCE_URL")

    # Normal run — always needed for the boxes
    result = detect_edge_regions(
        frames=frames,
        config=config,
        inference_url=inference_url,
        job_id=job_id,
    )
    output = {"edge_regions_by_frame": result}

    # Debug overlays — call debug endpoint (remote) or local debug function
    if debug and frames:
        debug_data = {}
        if inference_url:
            from core.detect.inference import InferenceClient
            client = InferenceClient(base_url=inference_url, job_id=job_id)
            for frame in frames:
                # Remote client returns an object with overlay attributes.
                dr = client.detect_edges_debug(
                    image=frame.image,
                    edge_canny_low=config.edge_canny_low,
                    edge_canny_high=config.edge_canny_high,
                    edge_hough_threshold=config.edge_hough_threshold,
                    edge_hough_min_length=config.edge_hough_min_length,
                    edge_hough_max_gap=config.edge_hough_max_gap,
                    edge_pair_max_distance=config.edge_pair_max_distance,
                    edge_pair_min_distance=config.edge_pair_min_distance,
                )
                debug_data[frame.sequence] = {
                    "edge_overlay_b64": dr.edge_overlay_b64,
                    "lines_overlay_b64": dr.lines_overlay_b64,
                    "horizontal_count": dr.horizontal_count,
                    "pair_count": dr.pair_count,
                }
        else:
            # Local mode — import GPU module directly
            from core.detect.stages.edge_detector import _load_cv_edges
            edges_mod = _load_cv_edges()
            for frame in frames:
                # NOTE(review): the local debug function takes bare kwarg names
                # (no "edge_" prefix) and returns a dict, while the remote client
                # takes prefixed kwargs and returns an object — confirm the two
                # stay in sync if config fields change.
                dr = edges_mod.detect_edges_debug(
                    frame.image,
                    canny_low=config.edge_canny_low,
                    canny_high=config.edge_canny_high,
                    hough_threshold=config.edge_hough_threshold,
                    hough_min_length=config.edge_hough_min_length,
                    hough_max_gap=config.edge_hough_max_gap,
                    pair_max_distance=config.edge_pair_max_distance,
                    pair_min_distance=config.edge_pair_min_distance,
                )
                debug_data[frame.sequence] = {
                    "edge_overlay_b64": dr["edge_overlay_b64"],
                    "lines_overlay_b64": dr["lines_overlay_b64"],
                    "horizontal_count": dr["horizontal_count"],
                    "pair_count": dr["pair_count"],
                }
        output["debug"] = debug_data

    return output
|
||||
97
core/detect/checkpoint/runner_bridge.py
Normal file
97
core/detect/checkpoint/runner_bridge.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
|
||||
|
||||
Owns the per-job state (timeline, frame manifest, checkpoint chain) that
|
||||
the runner shouldn't know about.
|
||||
|
||||
Timeline and Job are independent entities:
|
||||
- One Timeline can serve multiple Jobs (re-run with different params)
|
||||
- One Job operates on one Timeline (set after frame extraction)
|
||||
- Checkpoints belong to Timeline, tagged with the Job that created them
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-job state, keyed by job_id. Module-level because the runner only hands
# the hook below a job_id plus stage data.
_timeline_id: dict[str, str] = {}  # job_id → timeline id (set on extract_frames)
_frames_manifest: dict[str, dict[int, str]] = {}  # job_id → {sequence: s3_key}; only cleared in this module — verify it is written elsewhere
_latest_checkpoint: dict[str, str] = {}  # job_id → most recent checkpoint id (tail of the chain)
|
||||
|
||||
|
||||
def reset_checkpoint_state(job_id: str):
    """Drop all per-job checkpoint bookkeeping. Called when the pipeline finishes."""
    for table in (_timeline_id, _frames_manifest, _latest_checkpoint):
        table.pop(job_id, None)
|
||||
|
||||
|
||||
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
    """
    Save a checkpoint after a stage completes.

    Called by the runner. Handles:
    - Timeline creation (once, on extract_frames)
    - Frame upload (via create_timeline)
    - Stage output serialization (via stage registry)
    - Checkpoint chain (parent → child)

    No-op when job_id is empty, or when a non-extract stage finishes before
    a timeline exists for the job.
    """
    if not job_id:
        return

    # Imported lazily to avoid import cycles at module load.
    from .storage import create_timeline, save_stage_output
    from core.detect.stages.base import _REGISTRY

    # result's keys win over state's — the stage's fresh output takes priority.
    merged = {**state, **result}

    # On extract_frames: create Timeline + upload frames + root checkpoint
    if stage_name == "extract_frames" and job_id not in _timeline_id:
        frames = merged.get("frames", [])
        video_path = merged.get("video_path", "")
        profile_name = merged.get("profile_name", "")

        tid, cid = create_timeline(
            source_video=video_path,
            profile_name=profile_name,
            frames=frames,
        )
        _timeline_id[job_id] = tid
        _latest_checkpoint[job_id] = cid
        logger.info("Job %s → Timeline %s (root checkpoint %s)", job_id, tid, cid)

        # Emit timeline_id via SSE so the UI can use it for checkpoint loads
        from core.detect import emit
        emit.log(job_id, "Checkpoint", "INFO", f"timeline_id={tid}")
        return

    # For subsequent stages: save checkpoint on the timeline
    tid = _timeline_id.get(job_id)
    if not tid:
        logger.warning("No timeline for job %s, skipping checkpoint", job_id)
        return

    # Serialize stage output using the stage's serialize_fn if available.
    # NOTE(review): serializer.py reads serialize_fn directly off registry
    # values, while this goes through a .definition attribute — confirm the
    # actual shape of _REGISTRY values (classes vs. stage definitions).
    stage_cls = _REGISTRY.get(stage_name)
    serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
    if serialize_fn:
        output_json = serialize_fn(merged, job_id)
    else:
        output_json = {}

    # Chain the new checkpoint onto the previous one for this job.
    parent_id = _latest_checkpoint.get(job_id)
    new_checkpoint_id = save_stage_output(
        timeline_id=tid,
        parent_checkpoint_id=parent_id,
        stage_name=stage_name,
        output_json=output_json,
        job_id=job_id,
    )
    _latest_checkpoint[job_id] = new_checkpoint_id
|
||||
|
||||
|
||||
def get_timeline_id(job_id: str) -> str | None:
    """Return the timeline_id for a running job (None if no timeline yet).

    Used by the UI to load checkpoints while the job is in flight.
    """
    try:
        return _timeline_id[job_id]
    except KeyError:
        return None
|
||||
108
core/detect/checkpoint/serializer.py
Normal file
108
core/detect/checkpoint/serializer.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
State serialization — DetectState ↔ JSON-compatible dict.
|
||||
|
||||
Delegates to each stage's serialize_fn/deserialize_fn via the registry.
|
||||
This file has no model-specific knowledge — stages own their data format.
|
||||
|
||||
The only things serialized here are the "envelope" fields (job_id, video_path, etc.)
|
||||
that don't belong to any stage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from core.schema.serializers._common import serialize_dataclass
|
||||
from core.schema.serializers.pipeline import (
|
||||
deserialize_pipeline_stats,
|
||||
deserialize_text_candidates,
|
||||
)
|
||||
|
||||
|
||||
# Envelope fields — not owned by any stage, always present.
# "config_overrides" defaults to {}, the string fields to "" (see the
# serialize/deserialize loops in this file).
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "config_overrides"]
|
||||
|
||||
|
||||
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
    """
    Serialize DetectState to a JSON-compatible dict.

    Calls each registered stage's serialize_fn for stage-owned data.
    Envelope fields (job_id, etc.) are copied directly.

    Args:
        state: the live DetectState dict.
        frames_manifest: mapping of frame sequence -> S3 key, stored with
            string keys so the result is valid JSON.

    Returns:
        JSON-compatible checkpoint dict with envelope, frames_manifest,
        stats, and one "stage_<name>" entry per serializable stage.
    """
    from core.detect.stages.base import _REGISTRY

    checkpoint = {}

    # Envelope — fields not owned by any stage
    for key in ENVELOPE_KEYS:
        default = {} if key == "config_overrides" else ""
        checkpoint[key] = state.get(key, default)

    # Frames manifest (needed by frame-loading stages); JSON object keys must be strings
    checkpoint["frames_manifest"] = {str(k): v for k, v in frames_manifest.items()}

    # Stats (shared across stages, not owned by one)
    stats = state.get("stats")
    checkpoint["stats"] = serialize_dataclass(stats) if stats is not None else {}

    # Per-stage data — job_id is loop-invariant, read it once
    job_id = state.get("job_id", "")
    for name, stage_def in _REGISTRY.items():
        if stage_def.serialize_fn is None:
            continue
        checkpoint[f"stage_{name}"] = stage_def.serialize_fn(state, job_id)

    return checkpoint
|
||||
|
||||
|
||||
def deserialize_state(checkpoint: dict, frames: list) -> dict:
    """
    Reconstitute DetectState from a checkpoint dict + loaded frames.

    Calls each stage's deserialize_fn to restore stage-owned data.

    Args:
        checkpoint: dict produced by serialize_state().
        frames: Frame objects already loaded from storage; keyed by
            .sequence to reconnect stage data to frame objects.

    Returns:
        A DetectState dict with envelope fields, frames, stats, and
        per-stage data restored.
    """
    from core.detect.stages.base import _REGISTRY

    frame_map = {f.sequence: f for f in frames}

    state = {}

    # Envelope
    for key in ENVELOPE_KEYS:
        default = {} if key == "config_overrides" else ""
        state[key] = checkpoint.get(key, default)

    # Frames (always present, loaded externally)
    state["frames"] = frames

    # Stats
    state["stats"] = deserialize_pipeline_stats(checkpoint.get("stats", {}))

    # Per-stage data
    for name, stage_def in _REGISTRY.items():
        if stage_def.deserialize_fn is None:
            continue

        stage_key = f"stage_{name}"
        if stage_key not in checkpoint:
            continue

        job_id = state.get("job_id", "")
        stage_data = stage_def.deserialize_fn(checkpoint[stage_key], job_id)

        # Stage deserializers may return special keys that need post-processing
        # before they land in the state dict.
        for k, v in stage_data.items():
            if k == "_filtered_sequences":
                # Reconnect filtered frames from sequence list
                seq_set = set(v)
                state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
            elif k.endswith("_raw"):
                # Raw text candidates need frame reference reconnection.
                # Presumably keys look like "_<name>_raw" → "<name>" — verify
                # against the stage serializers that emit them.
                real_key = k.removeprefix("_").removesuffix("_raw")
                state[real_key] = deserialize_text_candidates(v, frame_map)
            else:
                state[k] = v

    return state
|
||||
177
core/detect/checkpoint/storage.py
Normal file
177
core/detect/checkpoint/storage.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
Checkpoint storage — Timeline + Checkpoint (tree of snapshots).
|
||||
|
||||
Timeline: frame sequence from source video (frames in MinIO)
|
||||
Checkpoint: snapshot of pipeline state (stage outputs as JSONB in Postgres)
|
||||
parent_id forms a tree — multiple children = different config tries
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from .frames import save_frames, load_frames, CHECKPOINT_PREFIX
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Timeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_timeline(
    source_video: str,
    profile_name: str,
    frames: list,
    fps: float = 2.0,
    source_asset_id: UUID | None = None,
) -> tuple[str, str]:
    """
    Create a timeline from frames. Uploads frame images to MinIO,
    creates Timeline + root Checkpoint in Postgres.

    Args:
        source_video: path/identifier of the source video.
        profile_name: detection profile the frames were extracted with.
        frames: Frame-like objects; .sequence and .timestamp are required,
            chunk_id and perceptual_hash are read defensively via getattr.
        fps: frame-extraction rate recorded on the timeline.
        source_asset_id: optional asset reference.

    Returns:
        (timeline_id, checkpoint_id) as strings.
    """
    from core.db.models import Timeline, Checkpoint
    from core.db.connection import get_session

    with get_session() as session:
        timeline = Timeline(
            source_video=source_video,
            profile_name=profile_name,
            source_asset_id=source_asset_id,
            fps=fps,
        )
        session.add(timeline)
        # flush() assigns timeline.id so it can name the S3 prefix below,
        # without committing yet.
        session.flush()
        tid = str(timeline.id)

        # Upload frames to MinIO.
        # NOTE(review): this network upload runs inside the open DB session,
        # so the transaction stays open for the duration — confirm that is
        # acceptable for large frame sets.
        manifest = save_frames(tid, frames)

        # Lightweight per-frame metadata kept in Postgres alongside the manifest.
        frames_meta = [
            {
                "sequence": f.sequence,
                "chunk_id": getattr(f, "chunk_id", 0),
                "timestamp": f.timestamp,
                "perceptual_hash": getattr(f, "perceptual_hash", ""),
            }
            for f in frames
        ]

        timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
        # JSONB keys must be strings
        timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
        timeline.frames_meta = frames_meta

        # Root checkpoint: no parent, no stage outputs yet.
        checkpoint = Checkpoint(
            timeline_id=timeline.id,
            parent_id=None,
            stage_outputs={},
            stats={"frames_extracted": len(frames)},
        )
        session.add(checkpoint)
        session.commit()
        # refresh() reloads DB-generated fields (the checkpoint id).
        session.refresh(checkpoint)
        cid = str(checkpoint.id)

    logger.info("Timeline created: %s (%d frames, root checkpoint %s)", tid, len(frames), cid)
    return tid, cid
|
||||
|
||||
|
||||
def get_timeline_frames(timeline_id: str) -> list:
    """Fetch a timeline's frames from MinIO and return them as Frame objects.

    Raises:
        ValueError: if no timeline exists with the given id.
    """
    from core.db.models import Timeline
    from core.db.connection import get_session

    with get_session() as session:
        timeline = session.get(Timeline, UUID(timeline_id))
        if not timeline:
            raise ValueError(f"Timeline not found: {timeline_id}")

        # Manifest keys come back from JSONB as strings; frames want ints.
        stored = timeline.frames_manifest or {}
        manifest = {int(seq): key for seq, key in stored.items()}
        return load_frames(manifest, timeline.frames_meta or [])
|
||||
|
||||
|
||||
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
    """Fetch a timeline's frames as base64 JPEG dicts (lightweight, no numpy).

    Raises:
        ValueError: if no timeline exists with the given id.
    """
    from core.db.models import Timeline
    from core.db.connection import get_session
    from .frames import load_frames_b64

    with get_session() as session:
        timeline = session.get(Timeline, UUID(timeline_id))
        if not timeline:
            raise ValueError(f"Timeline not found: {timeline_id}")

        # JSONB stores manifest keys as strings; convert back to ints.
        stored = timeline.frames_manifest or {}
        manifest = {int(seq): key for seq, key in stored.items()}
        return load_frames_b64(manifest, timeline.frames_meta or [])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Checkpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def save_stage_output(
    timeline_id: str,
    parent_checkpoint_id: str | None,
    stage_name: str,
    output_json: dict,
    config_overrides: dict | None = None,
    stats: dict | None = None,
    is_scenario: bool = False,
    scenario_label: str = "",
    job_id: str | None = None,
) -> str:
    """
    Save a stage's output as a new checkpoint (child of parent).

    Carries forward stage outputs from parent + adds the new one, so each
    checkpoint holds the cumulative outputs of the whole chain. Parent
    stats and config_overrides are merged the same way, with the new
    values winning on key collisions.

    Args:
        timeline_id: owning timeline (UUID string).
        parent_checkpoint_id: checkpoint to chain from; None for a root-level save.
        stage_name: key under which output_json is stored in stage_outputs.
        output_json: JSON-compatible output of the stage.
        config_overrides / stats: merged over the parent's values.
        is_scenario / scenario_label: mark the checkpoint as a what-if branch.
        job_id: job that produced this checkpoint, if any.

    Returns:
        The new checkpoint ID as a string.
    """
    from core.db.models import Checkpoint
    from core.db.connection import get_session

    with get_session() as session:
        parent_outputs = {}
        parent_stats = {}
        parent_config = {}
        if parent_checkpoint_id:
            parent = session.get(Checkpoint, UUID(parent_checkpoint_id))
            # A missing parent is tolerated: the new checkpoint simply starts empty.
            if parent:
                parent_outputs = dict(parent.stage_outputs or {})
                parent_stats = dict(parent.stats or {})
                parent_config = dict(parent.config_overrides or {})

        checkpoint = Checkpoint(
            timeline_id=UUID(timeline_id),
            job_id=UUID(job_id) if job_id else None,
            parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
            stage_outputs={**parent_outputs, stage_name: output_json},
            config_overrides={**parent_config, **(config_overrides or {})},
            stats={**parent_stats, **(stats or {})},
            is_scenario=is_scenario,
            scenario_label=scenario_label,
        )
        session.add(checkpoint)
        session.commit()
        # refresh() reloads DB-generated fields (the checkpoint id).
        session.refresh(checkpoint)
        cid = str(checkpoint.id)

    logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
                cid, timeline_id, stage_name, parent_checkpoint_id)
    return cid
|
||||
|
||||
|
||||
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
    """Fetch one stage's stored output from a checkpoint.

    Returns None when the checkpoint does not exist or has no output
    recorded for the stage.
    """
    from core.db.models import Checkpoint
    from core.db.connection import get_session

    with get_session() as session:
        checkpoint = session.get(Checkpoint, UUID(checkpoint_id))
        if not checkpoint:
            return None
        outputs = checkpoint.stage_outputs or {}
        return outputs.get(stage_name)
|
||||
159
core/detect/emit.py
Normal file
159
core/detect/emit.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Event emission helpers for detection pipeline stages.
|
||||
|
||||
Single place that knows how to build event payloads.
|
||||
Stages call these instead of constructing dicts or dataclasses directly.
|
||||
|
||||
Run context (run_id, parent_job_id) is set once at pipeline start via
|
||||
set_run_context() and automatically injected into all events.
|
||||
|
||||
Log level is set per-run with optional per-stage overrides.
|
||||
DEBUG events are only pushed when the run (or stage) log level allows it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from core.detect.events import push_detect_event
|
||||
from core.detect.models import PipelineStats
|
||||
|
||||
# Log level ordering for comparison (higher number = more severe)
_LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARN": 2, "ERROR": 3}

# Module-level run context — set once per pipeline invocation
_run_context: dict = {}  # run_id / parent_job_id / run_type, injected into every event
_run_log_level: str = "INFO"  # run-wide emission threshold
_stage_log_levels: dict[str, str] = {}  # stage_name → level override
|
||||
|
||||
|
||||
def set_run_context(
    run_id: str = "",
    parent_job_id: str = "",
    run_type: str = "initial",
    log_level: str = "INFO",
):
    """Install the run context injected into all subsequent events of this invocation.

    Also resets the run log level (normalized to upper case) and clears any
    per-stage overrides left over from a previous run.
    """
    global _run_context, _run_log_level
    _run_context = dict(
        run_id=run_id,
        parent_job_id=parent_job_id,
        run_type=run_type,
    )
    _run_log_level = log_level.upper()
    _stage_log_levels.clear()
|
||||
|
||||
|
||||
def set_stage_log_level(stage: str, level: str):
    """Install a per-stage log-level override (normalized to upper case)."""
    _stage_log_levels[stage] = level.upper()
|
||||
|
||||
|
||||
def clear_stage_log_level(stage: str):
    """Drop the per-stage log-level override, if one is set."""
    _stage_log_levels.pop(stage, None)
|
||||
|
||||
|
||||
def clear_run_context():
    """Reset run context and log-level filtering back to their defaults."""
    global _run_context, _run_log_level
    _stage_log_levels.clear()
    _run_log_level = "INFO"
    _run_context = {}
|
||||
|
||||
|
||||
def _should_emit(level: str, stage: str) -> bool:
    """Return True when `level` clears the effective threshold for `stage`.

    A per-stage override beats the run-wide level; unknown level names
    rank as INFO on both sides of the comparison.
    """
    threshold = _stage_log_levels.get(stage, _run_log_level)
    rank = _LEVEL_ORDER.get(level.upper(), 1)
    return rank >= _LEVEL_ORDER.get(threshold, 1)
|
||||
|
||||
|
||||
def _inject_context(payload: dict) -> dict:
    """Merge the active run-context fields into `payload` (mutates and returns it)."""
    if _run_context:
        payload |= _run_context
    return payload
|
||||
|
||||
|
||||
def log(job_id: str | None, stage: str, level: str, msg: str) -> None:
    """Push a log event for `stage`, subject to run/stage level filtering."""
    if not job_id or not _should_emit(level, stage):
        return
    payload = _inject_context({
        "level": level,
        "stage": stage,
        "msg": msg,
        "ts": datetime.now(timezone.utc).isoformat(),
    })
    push_detect_event(job_id, "log", payload)
|
||||
|
||||
|
||||
def stats(job_id: str | None, **kwargs) -> None:
    """Push a stats_update event built from PipelineStats fields."""
    if not job_id:
        return
    payload = _inject_context(dataclasses.asdict(PipelineStats(**kwargs)))
    push_detect_event(job_id, "stats_update", payload)
|
||||
|
||||
|
||||
def frame_update(
    job_id: str | None,
    frame_ref: int,
    timestamp: float,
    jpeg_b64: str,
    boxes: list[dict],
) -> None:
    """Push a frame_update event carrying a JPEG preview and its boxes."""
    if not job_id:
        return
    payload = _inject_context({
        "frame_ref": frame_ref,
        "timestamp": timestamp,
        "jpeg_b64": jpeg_b64,
        "boxes": boxes,
    })
    push_detect_event(job_id, "frame_update", payload)
|
||||
|
||||
|
||||
def graph_update(job_id: str | None, nodes: list[dict]) -> None:
    """Push a graph_update event with the current pipeline node list."""
    if not job_id:
        return
    push_detect_event(job_id, "graph_update", _inject_context({"nodes": nodes}))
|
||||
|
||||
|
||||
def detection(
    job_id: str | None,
    brand: str,
    confidence: float,
    source: str,
    timestamp: float,
    duration: float = 0.0,
    content_type: str = "",
    frame_ref: int | None = None,
) -> None:
    """Push a detection event describing a single brand sighting."""
    if not job_id:
        return
    payload = {
        "brand": brand,
        "confidence": confidence,
        "source": source,
        "timestamp": timestamp,
        "duration": duration,
        "content_type": content_type,
        "frame_ref": frame_ref,
    }
    push_detect_event(job_id, "detection", _inject_context(payload))
|
||||
|
||||
|
||||
def job_complete(job_id: str | None, report: dict) -> None:
    """Push the terminal job_complete event with the final report."""
    if not job_id:
        return
    payload = _inject_context({"job_id": job_id, "report": report})
    push_detect_event(job_id, "job_complete", payload)
|
||||
42
core/detect/events.py
Normal file
42
core/detect/events.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
Detection pipeline event helpers.
|
||||
|
||||
Non-generated runtime code for pushing SSE events.
|
||||
The event payload types are in sse_contract.py (generated by modelgen).
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.events import push_event
|
||||
|
||||
# Key prefix used when pushing detection events via push_event.
DETECT_EVENTS_PREFIX = "detect_events"

# SSE event type names
EVENT_GRAPH_UPDATE = "graph_update"
EVENT_STATS_UPDATE = "stats_update"
EVENT_FRAME_UPDATE = "frame_update"
EVENT_DETECTION = "detection"
EVENT_LOG = "log"
EVENT_JOB_COMPLETE = "job_complete"

# Every event type a consumer may receive on the detect stream.
ALL_EVENT_TYPES = [
    EVENT_GRAPH_UPDATE,
    EVENT_STATS_UPDATE,
    EVENT_FRAME_UPDATE,
    EVENT_DETECTION,
    EVENT_LOG,
    EVENT_JOB_COMPLETE,
]

# Event types after which no further events are expected for a job.
TERMINAL_EVENTS = [EVENT_JOB_COMPLETE]
|
||||
|
||||
|
||||
def push_detect_event(job_id: str, event_type: str, data: BaseModel | dict) -> None:
    """Push a detection event to Redis. Accepts Pydantic models or plain dicts."""
    if isinstance(data, BaseModel):
        payload = data.model_dump(mode="json")
    else:
        payload = data

    push_event(
        job_id=job_id,
        event_type=event_type,
        data=payload,
        prefix=DETECT_EVENTS_PREFIX,
    )
|
||||
45
core/detect/graph/__init__.py
Normal file
45
core/detect/graph/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Detection pipeline graph.
|
||||
|
||||
detect/graph/
|
||||
nodes.py — node functions (one per stage)
|
||||
events.py — graph_update SSE emission
|
||||
runner.py — PipelineRunner (config-driven, checkpoint, cancel, pause)
|
||||
"""
|
||||
|
||||
from .nodes import NODES, NODE_FUNCTIONS
|
||||
from .runner import (
|
||||
PipelineCancelled,
|
||||
PipelineRunner,
|
||||
build_graph,
|
||||
clear_cancel_check,
|
||||
clear_pause,
|
||||
get_pipeline,
|
||||
init_pause,
|
||||
is_paused,
|
||||
pause_pipeline,
|
||||
resume_pipeline,
|
||||
set_cancel_check,
|
||||
set_pause_after_stage,
|
||||
step_pipeline,
|
||||
)
|
||||
from .events import _node_states
|
||||
|
||||
__all__ = [
|
||||
"NODES",
|
||||
"NODE_FUNCTIONS",
|
||||
"PipelineCancelled",
|
||||
"PipelineRunner",
|
||||
"build_graph",
|
||||
"get_pipeline",
|
||||
"set_cancel_check",
|
||||
"clear_cancel_check",
|
||||
"init_pause",
|
||||
"clear_pause",
|
||||
"pause_pipeline",
|
||||
"resume_pipeline",
|
||||
"step_pipeline",
|
||||
"set_pause_after_stage",
|
||||
"is_paused",
|
||||
"_node_states",
|
||||
]
|
||||
27
core/detect/graph/events.py
Normal file
27
core/detect/graph/events.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
Graph event emission — node state tracking + SSE graph_update events.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.state import DetectState
|
||||
|
||||
|
||||
# Track node states across pipeline runs
|
||||
_node_states: dict[str, dict[str, str]] = {}
|
||||
|
||||
|
||||
def emit_transition(state: DetectState, node: str, status: str, node_list: list[str]):
    """Record *node*'s new *status* for this job and broadcast a full graph snapshot."""
    job_id = state.get("job_id")
    if not job_id:
        return

    # First transition for a job seeds every node as "pending".
    statuses = _node_states.setdefault(job_id, {n: "pending" for n in node_list})
    statuses[node] = status

    snapshot = [{"id": n, "status": statuses[n]} for n in node_list]
    emit.graph_update(job_id, snapshot)
|
||||
366
core/detect/graph/nodes.py
Normal file
366
core/detect/graph/nodes.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
Pipeline node functions — one per stage.
|
||||
|
||||
Each node: reads state, gets config from profile dict, runs stage logic,
|
||||
emits transitions, returns output dict.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import CropContext, PipelineStats
|
||||
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections
|
||||
from core.detect.stages.models import (
|
||||
DetectionConfig,
|
||||
FieldSegmentationConfig,
|
||||
FrameExtractionConfig,
|
||||
OCRConfig,
|
||||
RegionAnalysisConfig,
|
||||
ResolverConfig,
|
||||
SceneFilterConfig,
|
||||
)
|
||||
from core.detect.state import DetectState
|
||||
from core.detect.stages.frame_extractor import extract_frames
|
||||
from core.detect.stages.scene_filter import scene_filter
|
||||
from core.detect.stages.field_segmentation import run_field_segmentation
|
||||
from core.detect.stages.edge_detector import detect_edge_regions
|
||||
from core.detect.stages.yolo_detector import detect_objects
|
||||
from core.detect.stages.preprocess import preprocess_regions
|
||||
from core.detect.stages.ocr_stage import run_ocr
|
||||
from core.detect.stages.brand_resolver import resolve_brands
|
||||
from core.detect.stages.vlm_local import escalate_vlm
|
||||
from core.detect.stages.vlm_cloud import escalate_cloud
|
||||
from core.detect.stages.aggregator import compile_report
|
||||
from core.detect.tracing import trace_node, flush as flush_traces
|
||||
|
||||
from .events import emit_transition
|
||||
|
||||
INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
|
||||
|
||||
# Canonical stage order for the pipeline — used for graph_update snapshots
# and mirrored (name-for-name, in order) by NODE_FUNCTIONS below.
NODES = [
    "extract_frames",
    "filter_scenes",
    "field_segmentation",
    "detect_edges",
    "detect_objects",
    "preprocess",
    "run_ocr",
    "match_brands",
    "escalate_vlm",
    "escalate_cloud",
    "compile_report",
]
|
||||
|
||||
|
||||
def _load_profile(state: DetectState) -> dict:
    """Return the active profile dict with any per-stage config overrides merged in.

    The cached profile is never mutated; a shallow copy with a rebuilt
    ``configs`` mapping is returned when overrides are present.
    """
    profile = get_profile(state.get("profile_name", "soccer_broadcast"))

    overrides = state.get("config_overrides")
    if not overrides:
        return profile

    configs = dict(profile.get("configs", {}))
    for stage_name, stage_overrides in overrides.items():
        if stage_name in configs:
            configs[stage_name] = {**configs[stage_name], **stage_overrides}
        else:
            configs[stage_name] = stage_overrides

    return {**profile, "configs": configs}
|
||||
|
||||
|
||||
def _emit(state, node, status):
    # Shorthand: emit a graph transition against the full canonical NODES list.
    emit_transition(state, node, status, NODES)
|
||||
|
||||
|
||||
# --- Node functions ---
|
||||
|
||||
def node_extract_frames(state: DetectState) -> dict:
    """Decode frames from the source video.

    As the first stage it also establishes the emit run-context and, when a
    source asset is known, seeds the session brand dictionary into state.
    """
    job_id = state.get("job_id", "")
    if job_id and not emit._run_context:
        emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")

    asset_id = state.get("source_asset_id")
    if asset_id and not state.get("session_brands"):
        from core.detect.stages.brand_resolver import build_session_dict
        state["session_brands"] = build_session_dict(asset_id)

    _emit(state, "extract_frames", "running")

    with trace_node(state, "extract_frames") as span:
        cfg = FrameExtractionConfig(**get_stage_config(_load_profile(state), "extract_frames"))
        decoded = extract_frames(state["video_path"], cfg, job_id=job_id)
        span.set_output({"frames_extracted": len(decoded)})

    _emit(state, "extract_frames", "done")
    return {"frames": decoded, "stats": PipelineStats(frames_extracted=len(decoded))}
|
||||
|
||||
|
||||
def node_filter_scenes(state: DetectState) -> dict:
    """Drop redundant frames via the scene filter; records the surviving count."""
    _emit(state, "filter_scenes", "running")

    with trace_node(state, "filter_scenes") as span:
        cfg = SceneFilterConfig(**get_stage_config(_load_profile(state), "filter_scenes"))
        incoming = state.get("frames", [])
        surviving = scene_filter(incoming, cfg, job_id=state.get("job_id"))
        span.set_output({"frames_in": len(incoming), "frames_kept": len(surviving)})

    stats = state.get("stats", PipelineStats())
    stats.frames_after_scene_filter = len(surviving)

    _emit(state, "filter_scenes", "done")
    return {"filtered_frames": surviving, "stats": stats}
|
||||
|
||||
|
||||
def node_field_segmentation(state: DetectState) -> dict:
    """Segment the playing field per frame; returns masks, boundaries, coverage."""
    _emit(state, "field_segmentation", "running")

    with trace_node(state, "field_segmentation") as span:
        cfg = FieldSegmentationConfig(**get_stage_config(_load_profile(state), "field_segmentation"))
        frame_batch = state.get("filtered_frames", [])
        seg = run_field_segmentation(
            frame_batch, cfg, inference_url=INFERENCE_URL, job_id=state.get("job_id")
        )
        coverage = seg["field_coverage"]
        span.set_output({
            "frames": len(frame_batch),
            # max(..., 1) guards division when no frames were processed.
            "avg_coverage": sum(coverage.values()) / max(len(coverage), 1),
        })

    _emit(state, "field_segmentation", "done")
    return {
        "field_masks": seg["field_masks"],
        "field_boundaries": seg["field_boundaries"],
        "field_coverage": coverage,
    }
|
||||
|
||||
|
||||
def node_detect_edges(state: DetectState) -> dict:
    """Find candidate regions via edge detection, constrained by field masks."""
    _emit(state, "detect_edges", "running")

    with trace_node(state, "detect_edges") as span:
        cfg = RegionAnalysisConfig(**get_stage_config(_load_profile(state), "detect_edges"))
        frame_batch = state.get("filtered_frames", [])
        regions_by_frame = detect_edge_regions(
            frame_batch, cfg, inference_url=INFERENCE_URL, job_id=state.get("job_id"),
            field_masks=state.get("field_masks", {}),
        )
        region_count = sum(len(r) for r in regions_by_frame.values())
        span.set_output({"frames": len(frame_batch), "edge_regions": region_count})

    stats = state.get("stats", PipelineStats())
    stats.cv_regions_detected = region_count

    _emit(state, "detect_edges", "done")
    return {"edge_regions_by_frame": regions_by_frame, "stats": stats}
|
||||
|
||||
|
||||
def node_detect_objects(state: DetectState) -> dict:
    """Run object detection over the filtered frames; records region count."""
    _emit(state, "detect_objects", "running")

    with trace_node(state, "detect_objects") as span:
        cfg = DetectionConfig(**get_stage_config(_load_profile(state), "detect_objects"))
        frame_batch = state.get("filtered_frames", [])
        boxes_by_frame = detect_objects(
            frame_batch, cfg, inference_url=INFERENCE_URL, job_id=state.get("job_id")
        )
        region_count = sum(len(b) for b in boxes_by_frame.values())
        span.set_output({"frames": len(frame_batch), "regions_detected": region_count})

    stats = state.get("stats", PipelineStats())
    stats.regions_detected = region_count

    _emit(state, "detect_objects", "done")
    return {"boxes_by_frame": boxes_by_frame, "stats": stats}
|
||||
|
||||
|
||||
def node_preprocess(state: DetectState) -> dict:
    """Apply configured preprocessing (contrast/deskew/binarize) to region crops."""
    _emit(state, "preprocess", "running")

    with trace_node(state, "preprocess") as span:
        prep_cfg = get_stage_config(_load_profile(state), "preprocess")
        crops = preprocess_regions(
            state.get("filtered_frames", []),
            state.get("boxes_by_frame", {}),
            # Defaults mirror the profile schema: contrast on, the rest off.
            do_contrast=prep_cfg.get("contrast", True),
            do_deskew=prep_cfg.get("deskew", False),
            do_binarize=prep_cfg.get("binarize", False),
            inference_url=INFERENCE_URL,
            job_id=state.get("job_id"),
        )
        span.set_output({"regions_preprocessed": len(crops)})

    _emit(state, "preprocess", "done")
    return {"preprocessed_crops": crops}
|
||||
|
||||
|
||||
def node_run_ocr(state: DetectState) -> dict:
    """OCR each detected region into text candidates."""
    _emit(state, "run_ocr", "running")

    with trace_node(state, "run_ocr") as span:
        cfg = OCRConfig(**get_stage_config(_load_profile(state), "run_ocr"))
        boxes_by_frame = state.get("boxes_by_frame", {})
        text_candidates = run_ocr(
            state.get("filtered_frames", []), boxes_by_frame, cfg,
            inference_url=INFERENCE_URL, job_id=state.get("job_id"),
        )
        span.set_output({
            "regions_in": sum(len(b) for b in boxes_by_frame.values()),
            "text_candidates": len(text_candidates),
        })

    stats = state.get("stats", PipelineStats())
    stats.regions_resolved_by_ocr = len(text_candidates)

    _emit(state, "run_ocr", "done")
    return {"text_candidates": text_candidates, "stats": stats}
|
||||
|
||||
|
||||
def node_match_brands(state: DetectState) -> dict:
    """Resolve OCR text candidates against known brands.

    Returns matched detections plus the candidates still unresolved (those
    flow on to the VLM escalation stages).
    """
    _emit(state, "match_brands", "running")

    with trace_node(state, "match_brands") as span:
        prof = _load_profile(state)
        cfg = ResolverConfig(**get_stage_config(prof, "match_brands"))
        matched, unresolved = resolve_brands(
            state.get("text_candidates", []), cfg,
            session_brands=state.get("session_brands", {}),
            source_asset_id=state.get("source_asset_id"),
            content_type=prof["name"], job_id=state.get("job_id"),
        )
        span.set_output({"matched": len(matched), "unresolved": len(unresolved)})

    _emit(state, "match_brands", "done")
    return {"detections": matched, "unresolved_candidates": unresolved}
|
||||
|
||||
|
||||
def node_escalate_vlm(state: DetectState) -> dict:
    """Escalate unresolved candidates to the local VLM.

    Returns the merged detection list, candidates still unresolved after the
    VLM pass, and updated stats.
    """
    _emit(state, "escalate_vlm", "running")

    with trace_node(state, "escalate_vlm") as span:
        profile = _load_profile(state)
        vlm_config = get_stage_config(profile, "escalate_vlm")
        vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
        candidates = state.get("unresolved_candidates", [])
        job_id = state.get("job_id")

        # Named def instead of an assigned lambda (PEP 8, E731).
        def vlm_prompt_fn(ctx):
            return build_vlm_prompt(ctx, vlm_template)

        vlm_matched, still_unresolved = escalate_vlm(
            candidates,
            vlm_prompt_fn=vlm_prompt_fn,
            inference_url=INFERENCE_URL,
            content_type=profile["name"],
            source_asset_id=state.get("source_asset_id"),
            job_id=job_id,
        )

        stats = state.get("stats", PipelineStats())
        stats.regions_escalated_to_local_vlm = len(candidates)
        span.set_output({"candidates": len(candidates), "matched": len(vlm_matched),
                         "still_unresolved": len(still_unresolved)})

    existing = state.get("detections", [])

    # NOTE(review): SKIP_VLM only changes the reported node status here;
    # presumably the actual skip happens inside escalate_vlm — confirm.
    vlm_skipped = os.environ.get("SKIP_VLM", "").strip() == "1"
    _emit(state, "escalate_vlm", "skipped" if vlm_skipped else "done")
    return {
        "detections": existing + vlm_matched,
        "unresolved_candidates": still_unresolved,
        "stats": stats,
    }
|
||||
|
||||
|
||||
def node_escalate_cloud(state: DetectState) -> dict:
    """Escalate remaining unresolved candidates to the cloud VLM.

    Merges cloud matches into the detection list. ``stats`` is passed into
    escalate_cloud and its call/cost counters are read back afterwards.
    """
    _emit(state, "escalate_cloud", "running")

    with trace_node(state, "escalate_cloud") as span:
        profile = _load_profile(state)
        vlm_config = get_stage_config(profile, "escalate_vlm")
        vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
        candidates = state.get("unresolved_candidates", [])
        job_id = state.get("job_id")
        stats = state.get("stats", PipelineStats())

        # Named def instead of an assigned lambda (PEP 8, E731).
        def vlm_prompt_fn(ctx):
            return build_vlm_prompt(ctx, vlm_template)

        cloud_matched = escalate_cloud(
            candidates,
            vlm_prompt_fn=vlm_prompt_fn,
            stats=stats,
            content_type=profile["name"],
            source_asset_id=state.get("source_asset_id"),
            job_id=job_id,
        )

        span.set_output({"candidates": len(candidates), "matched": len(cloud_matched),
                         "cloud_calls": stats.cloud_llm_calls,
                         "cost_usd": stats.estimated_cloud_cost_usd})

    existing = state.get("detections", [])

    # NOTE(review): SKIP_CLOUD only changes the reported node status here;
    # presumably the actual skip happens inside escalate_cloud — confirm.
    cloud_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1"
    _emit(state, "escalate_cloud", "skipped" if cloud_skipped else "done")
    return {"detections": existing + cloud_matched, "stats": stats}
|
||||
|
||||
|
||||
def node_compile_report(state: DetectState) -> dict:
    """Aggregate detections and stats into the final report; flushes traces."""
    _emit(state, "compile_report", "running")

    with trace_node(state, "compile_report") as span:
        prof = _load_profile(state)
        stats = state.get("stats", PipelineStats())
        report = compile_report(
            detections=state.get("detections", []),
            stats=stats,
            video_source=state.get("video_path", ""),
            content_type=prof["name"],
            job_id=state.get("job_id"),
        )
        span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})

    # Last stage of the run — push any buffered trace spans out now.
    flush_traces()
    _emit(state, "compile_report", "done")
    return {"report": report}
|
||||
|
||||
|
||||
# Ordered (stage_name, node_function) pairs — names and order match NODES.
# Consumed by the runner to build its name→callable dispatch map.
NODE_FUNCTIONS = [
    ("extract_frames", node_extract_frames),
    ("filter_scenes", node_filter_scenes),
    ("field_segmentation", node_field_segmentation),
    ("detect_edges", node_detect_edges),
    ("detect_objects", node_detect_objects),
    ("preprocess", node_preprocess),
    ("run_ocr", node_run_ocr),
    ("match_brands", node_match_brands),
    ("escalate_vlm", node_escalate_vlm),
    ("escalate_cloud", node_escalate_cloud),
    ("compile_report", node_compile_report),
]
|
||||
274
core/detect/graph/runner.py
Normal file
274
core/detect/graph/runner.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
Pipeline runner — executes stages sequentially with checkpointing,
|
||||
cancellation, and pause/resume.
|
||||
|
||||
Reads PipelineConfig from the profile to determine what stages to run.
|
||||
Flattens the graph into a linear sequence for now (serial execution).
|
||||
Executor socket: all stages run via LocalExecutor (call function directly).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
from core.detect.stages.models import PipelineConfig
|
||||
from core.detect.state import DetectState
|
||||
from .nodes import NODES, NODE_FUNCTIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
|
||||
|
||||
|
||||
class PipelineCancelled(Exception):
    """Raised when a pipeline run is cancelled via the registered cancel-check."""
    pass
|
||||
|
||||
|
||||
class PipelinePaused(Exception):
    """Raised when a pipeline is paused (internally, for flow control)."""
    # NOTE(review): not raised anywhere in this module — pausing is implemented
    # via the threading.Event gates below; confirm whether this is still needed.
    pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cancellation — checked before each node
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_cancel_check: dict[str, callable] = {}
|
||||
|
||||
|
||||
def set_cancel_check(job_id: str, fn) -> None:
    """Register a zero-arg callable for *job_id*; a truthy result cancels the run."""
    _cancel_check[job_id] = fn
|
||||
|
||||
|
||||
def clear_cancel_check(job_id: str) -> None:
    """Remove a job's cancel callback (no-op if none was registered)."""
    _cancel_check.pop(job_id, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pause / Resume / Step — checked after each node completes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_pause_gate: dict[str, threading.Event] = {}
|
||||
_pause_after_stage: dict[str, bool] = {}
|
||||
|
||||
|
||||
def init_pause(job_id: str, pause_after_stage: bool = False):
    """Register pause bookkeeping for *job_id*; the pipeline starts unpaused."""
    unpause_gate = threading.Event()
    unpause_gate.set()  # set = running; cleared = paused
    _pause_gate[job_id] = unpause_gate
    _pause_after_stage[job_id] = pause_after_stage
|
||||
|
||||
|
||||
def clear_pause(job_id: str):
    """Drop pause bookkeeping once a pipeline run ends."""
    for table in (_pause_gate, _pause_after_stage):
        table.pop(job_id, None)
|
||||
|
||||
|
||||
def pause_pipeline(job_id: str):
    """Request a pause; the run blocks after its current stage completes."""
    gate = _pause_gate.get(job_id)
    if gate is None:
        return
    gate.clear()
    logger.info("Pipeline %s paused", job_id)
|
||||
|
||||
|
||||
def resume_pipeline(job_id: str):
    """Release a paused pipeline so it continues with the next stage."""
    gate = _pause_gate.get(job_id)
    if gate is None:
        return
    gate.set()
    logger.info("Pipeline %s resumed", job_id)
|
||||
|
||||
|
||||
def step_pipeline(job_id: str):
    """Advance a paused pipeline by exactly one stage, then pause again."""
    _pause_after_stage[job_id] = True
    gate = _pause_gate.get(job_id)
    if gate is None:
        return
    gate.set()
    logger.info("Pipeline %s stepping", job_id)
|
||||
|
||||
|
||||
def set_pause_after_stage(job_id: str, enabled: bool):
    """Toggle pause-after-each-stage mode; disabling releases a held gate."""
    _pause_after_stage[job_id] = enabled
    if enabled:
        return
    gate = _pause_gate.get(job_id)
    if gate is not None:
        gate.set()
|
||||
|
||||
|
||||
def is_paused(job_id: str) -> bool:
    """True when *job_id* has a registered gate that is currently cleared."""
    gate = _pause_gate.get(job_id)
    if gate is None:
        return False
    return not gate.is_set()
|
||||
|
||||
|
||||
def _wait_if_paused(job_id: str, node_name: str):
    """Block after *node_name* while the job's pause gate is cleared.

    In pause-after-stage (step) mode the gate is re-cleared first, so the
    pipeline halts after every stage. While blocked, the cancel callback is
    polled every 0.5 s so a cancel request is not starved by a long pause.

    Raises:
        PipelineCancelled: if the job is cancelled while paused.
    """
    gate = _pause_gate.get(job_id)
    if gate is None:
        return

    if _pause_after_stage.get(job_id, False):
        gate.clear()
        from core.detect import emit
        emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")

    while not gate.wait(timeout=0.5):
        check = _cancel_check.get(job_id)
        if check and check():
            # Plain literal (was an f-string with no placeholders).
            raise PipelineCancelled("Cancelled while paused before next stage")
|
||||
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pipeline Runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Node function lookup — maps stage name to callable
|
||||
_NODE_FN_MAP: dict[str, callable] = {name: fn for name, fn in NODE_FUNCTIONS}
|
||||
|
||||
|
||||
def _flatten_config(config: PipelineConfig, start_from: str | None = None) -> list[str]:
|
||||
"""
|
||||
Flatten a PipelineConfig into a linear stage sequence.
|
||||
|
||||
For now: topological sort via edges. Falls back to stage order if no edges.
|
||||
Respects start_from for replay (skip stages before it).
|
||||
"""
|
||||
if not config.edges:
|
||||
# No edges defined — use stage order as-is
|
||||
names = [s.name for s in config.stages]
|
||||
else:
|
||||
# Topological sort from edges
|
||||
graph: dict[str, list[str]] = {}
|
||||
in_degree: dict[str, int] = {}
|
||||
stage_names = {s.name for s in config.stages}
|
||||
|
||||
for name in stage_names:
|
||||
graph[name] = []
|
||||
in_degree[name] = 0
|
||||
|
||||
for edge in config.edges:
|
||||
if edge.source in stage_names and edge.target in stage_names:
|
||||
graph[edge.source].append(edge.target)
|
||||
in_degree[edge.target] = in_degree.get(edge.target, 0) + 1
|
||||
|
||||
# Kahn's algorithm
|
||||
queue = [n for n in stage_names if in_degree.get(n, 0) == 0]
|
||||
# Stable sort: prefer order from config.stages
|
||||
stage_order = {s.name: i for i, s in enumerate(config.stages)}
|
||||
queue.sort(key=lambda n: stage_order.get(n, 999))
|
||||
|
||||
names = []
|
||||
while queue:
|
||||
node = queue.pop(0)
|
||||
names.append(node)
|
||||
for neighbor in graph.get(node, []):
|
||||
in_degree[neighbor] -= 1
|
||||
if in_degree[neighbor] == 0:
|
||||
queue.append(neighbor)
|
||||
queue.sort(key=lambda n: stage_order.get(n, 999))
|
||||
|
||||
if start_from:
|
||||
try:
|
||||
idx = names.index(start_from)
|
||||
names = names[idx:]
|
||||
except ValueError:
|
||||
raise ValueError(f"Stage {start_from!r} not in pipeline config")
|
||||
|
||||
return names
|
||||
|
||||
|
||||
class PipelineRunner:
    """
    Executes a pipeline defined by PipelineConfig.

    Stages run serially in flattened order. Around each stage the runner:
    checks cancellation, invokes the node function, merges the node's output
    into the shared state, optionally checkpoints, and honours pause/step
    requests.

    Executor socket: node functions are currently invoked in-process.
    Future: dispatch to LocalExecutor / GrpcExecutor / LambdaExecutor
    based on StageRef.execution_target.
    """

    def __init__(
        self,
        config: PipelineConfig,
        checkpoint: bool = False,
        start_from: str | None = None,
    ):
        self.config = config
        self.do_checkpoint = checkpoint
        self.stage_sequence = _flatten_config(config, start_from)

    def invoke(self, state: DetectState) -> DetectState:
        """Run the pipeline on the given state. Returns final state."""
        for name in self.stage_sequence:
            job_id = state.get("job_id", "")

            # Honour cancellation before starting the stage.
            cancelled = _cancel_check.get(job_id)
            if cancelled is not None and cancelled():
                raise PipelineCancelled(f"Cancelled before {name}")

            fn = _NODE_FN_MAP.get(name)
            if fn is None:
                logger.warning("No node function for stage %s, skipping", name)
                continue

            # Run the stage and fold its outputs into the shared state.
            output = fn(state)
            state.update(output)

            if self.do_checkpoint:
                from core.detect.checkpoint import checkpoint_after_stage
                checkpoint_after_stage(job_id, name, state, output)

            # Block here if a pause/step was requested.
            _wait_if_paused(job_id, name)

        return state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — backwards compatible with old get_pipeline/build_graph
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_pipeline(
    checkpoint: bool | None = None,
    profile_name: str = "soccer_broadcast",
    start_from: str | None = None,
) -> PipelineRunner:
    """Return a PipelineRunner for the given profile.

    ``checkpoint=None`` defers to the MPR_CHECKPOINT environment default.
    """
    from core.detect.profile import get_profile, pipeline_config_from_dict

    if checkpoint is None:
        checkpoint = _CHECKPOINT_ENABLED
    pipeline_cfg = pipeline_config_from_dict(get_profile(profile_name)["pipeline"])
    return PipelineRunner(
        config=pipeline_cfg,
        checkpoint=checkpoint,
        start_from=start_from,
    )
|
||||
|
||||
|
||||
def build_graph(checkpoint: bool | None = None, start_from: str | None = None):
    """Backwards-compatible wrapper. Returns a PipelineRunner.

    Kept for callers that predate get_pipeline(); always uses the default
    profile.
    """
    return get_pipeline(checkpoint=checkpoint, start_from=start_from)
|
||||
4
core/detect/inference/__init__.py
Normal file
4
core/detect/inference/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .client import InferenceClient
|
||||
from .types import DetectResult, OCRResult, VLMResult
|
||||
|
||||
__all__ = ["InferenceClient", "DetectResult", "OCRResult", "VLMResult"]
|
||||
262
core/detect/inference/client.py
Normal file
262
core/detect/inference/client.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
HTTP client for the inference server.
|
||||
|
||||
The pipeline stages call this instead of importing ML libraries directly.
|
||||
The inference server runs on the GPU machine (or spot instance).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from .types import DetectResult, OCRResult, RegionDebugResult, RegionResult, ServerStatus, VLMResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_URL = os.environ.get("INFERENCE_URL", "http://localhost:8000")
|
||||
|
||||
|
||||
def _encode_image(image: np.ndarray) -> str:
    """Serialize an image array into a base64-encoded JPEG string."""
    buffer = io.BytesIO()
    Image.fromarray(image).save(buffer, format="JPEG", quality=85)
    return base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
|
||||
class InferenceClient:
|
||||
"""HTTP client for the GPU inference server."""
|
||||
|
||||
def __init__(self, base_url: str | None = None, timeout: float = 60.0,
             job_id: str = "", log_level: str = "INFO"):
    """Set up the HTTP session; job headers are attached only when job_id is set."""
    self.base_url = (base_url or DEFAULT_URL).rstrip("/")
    self.timeout = timeout
    self.job_id = job_id
    self.log_level = log_level
    self.session = requests.Session()
    if not job_id:
        return
    self.session.headers["X-Job-Id"] = job_id
    self.session.headers["X-Log-Level"] = log_level
|
||||
|
||||
def health(self) -> ServerStatus:
    """Check server health and loaded models."""
    resp = self.session.get(f"{self.base_url}/health", timeout=self.timeout)
    resp.raise_for_status()
    body = resp.json()
    return ServerStatus(
        loaded_models=body.get("loaded_models", []),
        vram_used_mb=body.get("vram_used_mb", 0),
        vram_budget_mb=body.get("vram_budget_mb", 0),
        strategy=body.get("strategy", "sequential"),
    )
|
||||
|
||||
def detect(
    self,
    image: np.ndarray,
    model: str = "yolov8n",
    confidence: float = 0.3,
    target_classes: list[str] | None = None,
) -> list[DetectResult]:
    """Run object detection on an image."""
    request_body = {
        "image": _encode_image(image),
        "model": model,
        "confidence": confidence,
    }
    if target_classes:
        request_body["target_classes"] = target_classes

    resp = self.session.post(
        f"{self.base_url}/detect",
        json=request_body,
        timeout=self.timeout,
    )
    resp.raise_for_status()

    return [
        DetectResult(
            x=d["x"], y=d["y"], w=d["w"], h=d["h"],
            confidence=d["confidence"], label=d["label"],
        )
        for d in resp.json().get("detections", [])
    ]
|
||||
|
||||
def ocr(
    self,
    image: np.ndarray,
    languages: list[str] | None = None,
) -> list[OCRResult]:
    """Run OCR on an image region."""
    request_body = {"image": _encode_image(image)}
    if languages:
        request_body["languages"] = languages

    resp = self.session.post(
        f"{self.base_url}/ocr",
        json=request_body,
        timeout=self.timeout,
    )
    resp.raise_for_status()

    return [
        OCRResult(
            text=d["text"],
            confidence=d["confidence"],
            bbox=tuple(d["bbox"]),
        )
        for d in resp.json().get("results", [])
    ]
|
||||
|
||||
def vlm(
    self,
    image: np.ndarray,
    prompt: str,
    model: str = "moondream2",
) -> VLMResult:
    """Query a visual language model with an image crop + prompt."""
    resp = self.session.post(
        f"{self.base_url}/vlm",
        json={
            "image": _encode_image(image),
            "prompt": prompt,
            "model": model,
        },
        timeout=self.timeout,
    )
    resp.raise_for_status()

    body = resp.json()
    return VLMResult(
        brand=body.get("brand", ""),
        confidence=body.get("confidence", 0.0),
        reasoning=body.get("reasoning", ""),
    )
|
||||
|
||||
def detect_edges(
    self,
    image: np.ndarray,
    edge_canny_low: int = 50,
    edge_canny_high: int = 150,
    edge_hough_threshold: int = 80,
    edge_hough_min_length: int = 100,
    edge_hough_max_gap: int = 10,
    edge_pair_max_distance: int = 200,
    edge_pair_min_distance: int = 15,
) -> list[RegionResult]:
    """Run edge detection on an image.

    All edge_* parameters are forwarded verbatim to the server's
    /detect_edges endpoint (Canny thresholds, Hough line parameters,
    and min/max spacing for pairing parallel lines into regions).

    Returns:
        One RegionResult per entry in the server's "regions" list.
    """
    payload = {
        "image": _encode_image(image),
        "edge_canny_low": edge_canny_low,
        "edge_canny_high": edge_canny_high,
        "edge_hough_threshold": edge_hough_threshold,
        "edge_hough_min_length": edge_hough_min_length,
        "edge_hough_max_gap": edge_hough_max_gap,
        "edge_pair_max_distance": edge_pair_max_distance,
        "edge_pair_min_distance": edge_pair_min_distance,
    }

    response = self.session.post(
        f"{self.base_url}/detect_edges",
        json=payload,
        timeout=self.timeout,
    )
    response.raise_for_status()

    return [
        RegionResult(
            x=region["x"],
            y=region["y"],
            w=region["w"],
            h=region["h"],
            confidence=region["confidence"],
            label=region["label"],
        )
        for region in response.json().get("regions", [])
    ]
|
||||
|
||||
def detect_edges_debug(
    self,
    image: np.ndarray,
    edge_canny_low: int = 50,
    edge_canny_high: int = 150,
    edge_hough_threshold: int = 80,
    edge_hough_min_length: int = 100,
    edge_hough_max_gap: int = 10,
    edge_pair_max_distance: int = 200,
    edge_pair_min_distance: int = 15,
) -> RegionDebugResult:
    """Run edge detection with debug overlays.

    Same parameters as detect_edges(), but hits the /detect_edges/debug
    endpoint, which additionally returns base64 overlay images and
    intermediate line/pair counts for tuning.
    """
    payload = {
        "image": _encode_image(image),
        "edge_canny_low": edge_canny_low,
        "edge_canny_high": edge_canny_high,
        "edge_hough_threshold": edge_hough_threshold,
        "edge_hough_min_length": edge_hough_min_length,
        "edge_hough_max_gap": edge_hough_max_gap,
        "edge_pair_max_distance": edge_pair_max_distance,
        "edge_pair_min_distance": edge_pair_min_distance,
    }

    response = self.session.post(
        f"{self.base_url}/detect_edges/debug",
        json=payload,
        timeout=self.timeout,
    )
    response.raise_for_status()
    body = response.json()

    parsed_regions = [
        RegionResult(
            x=entry["x"],
            y=entry["y"],
            w=entry["w"],
            h=entry["h"],
            confidence=entry["confidence"],
            label=entry["label"],
        )
        for entry in body.get("regions", [])
    ]

    return RegionDebugResult(
        regions=parsed_regions,
        edge_overlay_b64=body.get("edge_overlay_b64", ""),
        lines_overlay_b64=body.get("lines_overlay_b64", ""),
        horizontal_count=body.get("horizontal_count", 0),
        pair_count=body.get("pair_count", 0),
    )
|
||||
|
||||
def post(self, path: str, payload: dict) -> dict | None:
    """Generic POST to the inference server. Returns JSON response or None on error.

    Best-effort by design: any failure (connection, HTTP status, bad JSON)
    is logged at WARNING level and collapsed to None rather than raised.
    """
    url = f"{self.base_url}{path}"
    try:
        response = self.session.post(url, json=payload, timeout=self.timeout)
        response.raise_for_status()
        return response.json()
    except Exception as exc:  # boundary: swallow, log, report None
        logger.warning("Inference POST %s failed: %s", path, exc)
        return None
|
||||
|
||||
def load_model(self, model: str, quantization: str = "fp16") -> None:
    """Request the server to load a model into VRAM.

    Raises:
        requests.HTTPError: if the server rejects the load request.
    """
    response = self.session.post(
        f"{self.base_url}/models/load",
        json={"model": model, "quantization": quantization},
        timeout=self.timeout,
    )
    response.raise_for_status()
|
||||
|
||||
def unload_model(self, model: str) -> None:
    """Request the server to unload a model from VRAM.

    Raises:
        requests.HTTPError: if the server rejects the unload request.
    """
    response = self.session.post(
        f"{self.base_url}/models/unload",
        json={"model": model},
        timeout=self.timeout,
    )
    response.raise_for_status()
|
||||
76
core/detect/inference/types.py
Normal file
76
core/detect/inference/types.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
Inference response types.
|
||||
|
||||
These are the shapes returned by the inference server.
|
||||
Kept separate from core.detect.models to avoid coupling the
|
||||
inference protocol to pipeline internals.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectResult:
|
||||
"""Single object detection from YOLO or similar."""
|
||||
x: int
|
||||
y: int
|
||||
w: int
|
||||
h: int
|
||||
confidence: float
|
||||
label: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResult:
|
||||
"""Text extracted from a region."""
|
||||
text: str
|
||||
confidence: float
|
||||
bbox: tuple[int, int, int, int] # x, y, w, h
|
||||
|
||||
|
||||
@dataclass
|
||||
class VLMResult:
|
||||
"""Visual language model response for a crop."""
|
||||
brand: str
|
||||
confidence: float
|
||||
reasoning: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegionResult:
|
||||
"""A candidate region from CV analysis."""
|
||||
x: int
|
||||
y: int
|
||||
w: int
|
||||
h: int
|
||||
confidence: float
|
||||
label: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegionDebugResult:
|
||||
"""CV region analysis with debug overlays."""
|
||||
regions: list[RegionResult] = field(default_factory=list)
|
||||
edge_overlay_b64: str = ""
|
||||
lines_overlay_b64: str = ""
|
||||
horizontal_count: int = 0
|
||||
pair_count: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""Info about a loaded model."""
|
||||
name: str
|
||||
vram_mb: float
|
||||
quantization: str # fp32, fp16, int8, int4
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerStatus:
|
||||
"""Inference server health response."""
|
||||
loaded_models: list[ModelInfo] = field(default_factory=list)
|
||||
vram_used_mb: float = 0.0
|
||||
vram_budget_mb: float = 0.0
|
||||
strategy: str = "sequential" # sequential, concurrent, auto
|
||||
95
core/detect/models.py
Normal file
95
core/detect/models.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
Detection pipeline runtime models.
|
||||
|
||||
These are the data structures that flow between pipeline stages.
|
||||
They contain runtime types (np.ndarray) so they live here, not in
|
||||
core/schema/models/ (which is for modelgen source of truth).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
class Frame:
    """A single extracted video frame flowing between pipeline stages."""
    sequence: int  # frame index within the job
    chunk_id: int  # which extraction chunk this frame came from -- TODO confirm against extractor
    timestamp: float  # position in video (seconds)
    image: np.ndarray  # decoded pixel data; shape/dtype not fixed here (presumably HxWx3 uint8 -- verify at extraction site)
    perceptual_hash: str = ""  # empty until computed; NOTE(review): presumably filled by the scene filter -- confirm
|
||||
|
||||
|
||||
@dataclass
|
||||
class BoundingBox:
|
||||
x: int
|
||||
y: int
|
||||
w: int
|
||||
h: int
|
||||
confidence: float
|
||||
label: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextCandidate:
|
||||
frame: Frame
|
||||
bbox: BoundingBox
|
||||
text: str
|
||||
ocr_confidence: float
|
||||
|
||||
|
||||
@dataclass
class BrandDetection:
    """One confirmed brand sighting on the timeline."""
    brand: str  # detected brand name
    timestamp: float  # start of the appearance (seconds into the video)
    duration: float  # length of the appearance (seconds)
    confidence: float  # detection confidence
    # Which pipeline path produced this detection.
    source: Literal["ocr", "local_vlm", "cloud_llm", "logo_match", "auxiliary"]
    bbox: BoundingBox | None = None  # location in the frame, when known
    frame_ref: int | None = None  # sequence number of the originating frame, when known
    content_type: str = ""  # e.g. broadcast type; empty when unclassified -- TODO confirm semantics
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrandStats:
|
||||
total_appearances: int = 0
|
||||
total_screen_time: float = 0.0
|
||||
avg_confidence: float = 0.0
|
||||
first_seen: float = 0.0
|
||||
last_seen: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineStats:
|
||||
frames_extracted: int = 0
|
||||
frames_after_scene_filter: int = 0
|
||||
cv_regions_detected: int = 0
|
||||
regions_detected: int = 0
|
||||
regions_resolved_by_ocr: int = 0
|
||||
regions_escalated_to_local_vlm: int = 0
|
||||
regions_escalated_to_cloud_llm: int = 0
|
||||
auxiliary_detections: int = 0
|
||||
cloud_llm_calls: int = 0
|
||||
processing_time_seconds: float = 0.0
|
||||
estimated_cloud_cost_usd: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectionReport:
|
||||
video_source: str
|
||||
content_type: str
|
||||
duration_seconds: float
|
||||
brands: dict[str, BrandStats] = field(default_factory=dict)
|
||||
timeline: list[BrandDetection] = field(default_factory=list)
|
||||
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CropContext:
|
||||
"""Runtime type — holds image bytes for VLM prompts."""
|
||||
image: bytes
|
||||
surrounding_text: str = ""
|
||||
position_hint: str = ""
|
||||
107
core/detect/profile.py
Normal file
107
core/detect/profile.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Profile registry and helpers.
|
||||
|
||||
Loads profile data from Postgres.
|
||||
A profile is a dict with keys: name, pipeline, configs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from core.detect.stages.models import PipelineConfig, StageRef, Edge
|
||||
from core.detect.models import (
|
||||
BrandDetection,
|
||||
BrandStats,
|
||||
CropContext,
|
||||
DetectionReport,
|
||||
PipelineStats,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_profile(name: str) -> Dict[str, Any]:
    """Get a profile dict by name from the database.

    Returns:
        Dict with keys "name", "pipeline", "configs"; the JSONB columns
        default to {} when NULL in the row.

    Raises:
        ValueError: if no profile row matches *name*.
    """
    from core.db.connection import get_session
    from core.db.models import Profile

    with get_session() as session:
        record = session.query(Profile).filter(Profile.name == name).first()

        if record is None:
            raise ValueError(f"Unknown profile: {name!r}")

        return {
            "name": record.name,
            "pipeline": record.pipeline or {},
            "configs": record.configs or {},
        }
|
||||
|
||||
|
||||
def list_profiles() -> list[str]:
    """List available profile names from the database."""
    from core.db.connection import get_session
    from core.db.models import Profile

    with get_session() as session:
        # Query only the name column; each row comes back as a 1-tuple.
        return [name for (name,) in session.query(Profile.name).all()]
|
||||
|
||||
|
||||
def get_stage_config(profile: Dict[str, Any], stage_name: str) -> dict:
    """Get config values for a stage from a profile.

    Missing "configs" key or missing stage both yield an empty dict.
    """
    stage_configs = profile.get("configs", {})
    return stage_configs.get(stage_name, {})
|
||||
|
||||
|
||||
def pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig:
    """Deserialize a PipelineConfig from a JSONB dict.

    Each entry of "stages"/"edges" is splatted into StageRef/Edge, so
    the stored dicts must match those dataclasses' field names exactly.
    """
    return PipelineConfig(
        name=data.get("name", ""),
        profile_name=data.get("profile_name", ""),
        stages=[StageRef(**entry) for entry in data.get("stages", [])],
        edges=[Edge(**entry) for entry in data.get("edges", [])],
        routing_rules=data.get("routing_rules", {}),
    )
|
||||
|
||||
|
||||
def build_vlm_prompt(crop_context: CropContext, template: str) -> str:
    """Build a VLM prompt from a template and crop context.

    The template's {hint} and {text} placeholders are filled with
    position and nearby-text sentences; either collapses to "" when
    the corresponding context field is empty.
    """
    position = crop_context.position_hint
    nearby = crop_context.surrounding_text
    hint = f" Position: {position}." if position else ""
    text = f" Nearby text: '{nearby}'." if nearby else ""
    return template.format(hint=hint, text=text)
|
||||
|
||||
|
||||
def aggregate_detections(
    detections: list[BrandDetection],
    content_type: str,
) -> DetectionReport:
    """Group detections by brand into a report.

    Computes per-brand appearance counts, total screen time, a running
    mean confidence, and first/last seen timestamps; the timeline is
    the input sorted by timestamp. video_source/duration are left for
    the caller to fill in.
    """
    stats_by_brand: dict[str, BrandStats] = {}
    for det in detections:
        stats = stats_by_brand.setdefault(det.brand, BrandStats())
        prev_count = stats.total_appearances
        stats.total_appearances = prev_count + 1
        stats.total_screen_time += det.duration
        # Incremental mean: fold the new confidence into the running average.
        stats.avg_confidence = (
            stats.avg_confidence * prev_count + det.confidence
        ) / stats.total_appearances
        if stats.first_seen == 0.0 or det.timestamp < stats.first_seen:
            stats.first_seen = det.timestamp
        stats.last_seen = max(stats.last_seen, det.timestamp)

    return DetectionReport(
        video_source="",
        content_type=content_type,
        duration_seconds=0.0,
        brands=stats_by_brand,
        timeline=sorted(detections, key=lambda d: d.timestamp),
        pipeline_stats=PipelineStats(),
    )
|
||||
58
core/detect/providers/__init__.py
Normal file
58
core/detect/providers/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Cloud LLM provider registry.
|
||||
|
||||
Select provider via CLOUD_LLM_PROVIDER env var.
|
||||
Each provider reads its own env vars for auth/config.
|
||||
|
||||
CLOUD_LLM_PROVIDER=groq → GROQ_API_KEY, GROQ_MODEL, GROQ_BASE_URL
|
||||
CLOUD_LLM_PROVIDER=gemini → GEMINI_API_KEY, GEMINI_MODEL
|
||||
CLOUD_LLM_PROVIDER=openai → OPENAI_API_KEY, OPENAI_MODEL, OPENAI_BASE_URL
|
||||
CLOUD_LLM_PROVIDER=claude → ANTHROPIC_API_KEY, CLAUDE_MODEL
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from .base import CloudProvider, ProviderResponse
|
||||
from .groq import GroqProvider
|
||||
from .gemini import GeminiProvider
|
||||
from .openai_compat import OpenAICompatProvider
|
||||
from .claude import ClaudeProvider
|
||||
|
||||
PROVIDERS: dict[str, type] = {
|
||||
"groq": GroqProvider,
|
||||
"gemini": GeminiProvider,
|
||||
"openai": OpenAICompatProvider,
|
||||
"claude": ClaudeProvider,
|
||||
}
|
||||
|
||||
_cached: CloudProvider | None = None
|
||||
|
||||
|
||||
def get_provider() -> CloudProvider:
    """Get the configured cloud provider (cached after first call).

    Reads CLOUD_LLM_PROVIDER (default "groq"), instantiates the matching
    class from PROVIDERS once, and reuses that instance afterwards.

    Raises:
        ValueError: if CLOUD_LLM_PROVIDER names an unknown provider.
    """
    global _cached
    if _cached is None:
        name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
        provider_cls = PROVIDERS.get(name)
        if provider_cls is None:
            raise ValueError(f"Unknown provider: {name!r}. Options: {list(PROVIDERS)}")
        _cached = provider_cls()
    return _cached
|
||||
|
||||
|
||||
def has_api_key() -> bool:
    """Check if the configured provider has an API key set.

    Returns False for unknown provider names and for empty/missing keys.
    """
    env_by_provider = {
        "groq": "GROQ_API_KEY",
        "gemini": "GEMINI_API_KEY",
        "openai": "OPENAI_API_KEY",
        "claude": "ANTHROPIC_API_KEY",
    }
    provider = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
    key_name = env_by_provider.get(provider, "")
    return bool(os.environ.get(key_name, ""))
|
||||
36
core/detect/providers/base.py
Normal file
36
core/detect/providers/base.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Cloud LLM provider protocol and model metadata."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""Metadata for a cloud LLM model."""
|
||||
id: str
|
||||
vision: bool = True
|
||||
cost_per_input_token: float = 0.0
|
||||
cost_per_output_token: float = 0.0
|
||||
max_output_tokens: int = 4096
|
||||
notes: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderResponse:
|
||||
answer: str
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class CloudProvider(Protocol):
|
||||
"""
|
||||
Interface for cloud LLM providers.
|
||||
|
||||
Each provider handles its own auth, payload format, and response parsing.
|
||||
The pipeline only calls call() and reads the response.
|
||||
"""
|
||||
name: str
|
||||
models: dict[str, ModelInfo]
|
||||
|
||||
def call(self, image_b64: str, prompt: str) -> ProviderResponse: ...
|
||||
73
core/detect/providers/claude.py
Normal file
73
core/detect/providers/claude.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Anthropic Claude provider — uses the official SDK."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from .base import ModelInfo, ProviderResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Claude-specific env vars
|
||||
# ANTHROPIC_API_KEY is read by the SDK automatically
|
||||
CLAUDE_MODEL = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-20250514")
|
||||
|
||||
MODELS = {
|
||||
"claude-sonnet-4-20250514": ModelInfo(
|
||||
id="claude-sonnet-4-20250514",
|
||||
vision=True,
|
||||
cost_per_input_token=0.000003,
|
||||
cost_per_output_token=0.000015,
|
||||
notes="Best balance of quality/cost with vision",
|
||||
),
|
||||
"claude-haiku-4-5-20251001": ModelInfo(
|
||||
id="claude-haiku-4-5-20251001",
|
||||
vision=True,
|
||||
cost_per_input_token=0.0000008,
|
||||
cost_per_output_token=0.000004,
|
||||
notes="Fastest, cheapest, good for simple brand ID",
|
||||
),
|
||||
"claude-opus-4-6": ModelInfo(
|
||||
id="claude-opus-4-6",
|
||||
vision=True,
|
||||
cost_per_input_token=0.000015,
|
||||
cost_per_output_token=0.000075,
|
||||
notes="Highest quality, use for ambiguous cases",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class ClaudeProvider:
    """Cloud provider backed by the official Anthropic SDK."""

    name = "claude"
    models = MODELS

    def __init__(self):
        # Imported lazily so the SDK is only required when this provider
        # is actually selected.
        from anthropic import Anthropic
        self.client = Anthropic()  # SDK reads ANTHROPIC_API_KEY itself
        self.model = CLAUDE_MODEL

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        """Send one base64 JPEG + prompt; return the model's text answer."""
        image_part = {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/jpeg",
                "data": image_b64,
            },
        }
        text_part = {"type": "text", "text": prompt}

        message = self.client.messages.create(
            model=self.model,
            max_tokens=150,
            messages=[{"role": "user", "content": [image_part, text_part]}],
        )

        usage = message.usage
        return ProviderResponse(
            answer=message.content[0].text.strip(),
            total_tokens=usage.input_tokens + usage.output_tokens,
        )
|
||||
75
core/detect/providers/gemini.py
Normal file
75
core/detect/providers/gemini.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Google Gemini provider — native REST API, not OpenAI-compatible."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
from .base import ModelInfo, ProviderResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Gemini-specific env vars
|
||||
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
|
||||
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash")
|
||||
|
||||
MODELS = {
|
||||
"gemini-2.0-flash": ModelInfo(
|
||||
id="gemini-2.0-flash",
|
||||
vision=True,
|
||||
cost_per_input_token=0.0000001,
|
||||
cost_per_output_token=0.0000004,
|
||||
notes="Fast, cheap, good vision",
|
||||
),
|
||||
"gemini-2.0-pro": ModelInfo(
|
||||
id="gemini-2.0-pro",
|
||||
vision=True,
|
||||
cost_per_input_token=0.00000125,
|
||||
cost_per_output_token=0.000005,
|
||||
notes="Higher quality, slower",
|
||||
),
|
||||
"gemini-1.5-flash": ModelInfo(
|
||||
id="gemini-1.5-flash",
|
||||
vision=True,
|
||||
cost_per_input_token=0.000000075,
|
||||
cost_per_output_token=0.0000003,
|
||||
notes="Cheapest option",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class GeminiProvider:
    """Google Gemini provider using the native generateContent REST API.

    Not OpenAI-compatible: request/response shapes are Gemini-specific.
    """

    name = "gemini"
    models = MODELS

    def __init__(self):
        self.api_key = GEMINI_API_KEY
        self.model = GEMINI_MODEL
        self.endpoint = (
            f"https://generativelanguage.googleapis.com/v1beta/models/"
            f"{self.model}:generateContent"
        )

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        """Send one base64 JPEG + prompt; return the first candidate's text.

        Raises:
            requests.HTTPError: on non-2xx responses.
            KeyError/IndexError: if the response has no candidates.
        """
        payload = {
            "contents": [{
                "parts": [
                    {"text": prompt},
                    {"inline_data": {"mime_type": "image/jpeg", "data": image_b64}},
                ],
            }],
            "generationConfig": {"maxOutputTokens": 150},
        }

        # FIX: authenticate via the x-goog-api-key header instead of a
        # `?key=` query parameter, so the secret never appears in URLs
        # (proxy/server logs, tracebacks, HTTP client debug output).
        resp = requests.post(
            self.endpoint,
            headers={"x-goog-api-key": self.api_key},
            json=payload,
            timeout=30,
        )
        resp.raise_for_status()
        data = resp.json()

        answer = data["candidates"][0]["content"]["parts"][0]["text"].strip()
        usage = data.get("usageMetadata", {})
        total_tokens = usage.get("totalTokenCount", 0)

        return ProviderResponse(answer=answer, total_tokens=total_tokens)
|
||||
66
core/detect/providers/groq.py
Normal file
66
core/detect/providers/groq.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Groq cloud provider — OpenAI-compatible API with vision."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
from .base import ModelInfo, ProviderResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Groq-specific env vars
|
||||
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
|
||||
GROQ_BASE_URL = os.environ.get("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
|
||||
GROQ_MODEL = os.environ.get("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct")
|
||||
|
||||
MODELS = {
|
||||
"meta-llama/llama-4-scout-17b-16e-instruct": ModelInfo(
|
||||
id="meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
vision=True,
|
||||
cost_per_input_token=0.0,
|
||||
cost_per_output_token=0.0,
|
||||
notes="Llama 4 Scout, only vision model on Groq free tier",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class GroqProvider:
    """Groq cloud provider — OpenAI-compatible chat API with vision."""

    name = "groq"
    models = MODELS

    def __init__(self):
        self.api_key = GROQ_API_KEY
        self.base_url = GROQ_BASE_URL
        self.model = GROQ_MODEL
        self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        """Send one base64 JPEG + prompt; return the first choice's text."""
        user_content = [
            {"type": "text", "text": prompt},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
            },
        ]
        request_body = {
            "model": self.model,
            "messages": [{"role": "user", "content": user_content}],
            "max_tokens": 150,
        }

        resp = requests.post(
            self.endpoint, headers=self.headers, json=request_body, timeout=30
        )
        resp.raise_for_status()
        body = resp.json()

        return ProviderResponse(
            answer=body["choices"][0]["message"]["content"].strip(),
            total_tokens=body.get("usage", {}).get("total_tokens", 0),
        )
|
||||
73
core/detect/providers/openai_compat.py
Normal file
73
core/detect/providers/openai_compat.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Generic OpenAI-compatible provider (OpenAI, Together, etc.)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
from .base import ModelInfo, ProviderResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenAI-compat specific env vars
|
||||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
|
||||
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
|
||||
|
||||
MODELS = {
|
||||
"gpt-4o-mini": ModelInfo(
|
||||
id="gpt-4o-mini",
|
||||
vision=True,
|
||||
cost_per_input_token=0.00000015,
|
||||
cost_per_output_token=0.0000006,
|
||||
notes="Cheap, fast, decent vision",
|
||||
),
|
||||
"gpt-4o": ModelInfo(
|
||||
id="gpt-4o",
|
||||
vision=True,
|
||||
cost_per_input_token=0.0000025,
|
||||
cost_per_output_token=0.00001,
|
||||
notes="Best OpenAI vision model",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class OpenAICompatProvider:
    """Generic OpenAI-compatible chat provider (OpenAI, Together, etc.)."""

    name = "openai"
    models = MODELS

    def __init__(self):
        self.api_key = OPENAI_API_KEY
        self.base_url = OPENAI_BASE_URL
        self.model = OPENAI_MODEL
        self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        """Send one base64 JPEG + prompt; return the first choice's text."""
        user_content = [
            {"type": "text", "text": prompt},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
            },
        ]
        request_body = {
            "model": self.model,
            "messages": [{"role": "user", "content": user_content}],
            "max_tokens": 150,
        }

        resp = requests.post(
            self.endpoint, headers=self.headers, json=request_body, timeout=30
        )
        resp.raise_for_status()
        body = resp.json()

        return ProviderResponse(
            answer=body["choices"][0]["message"]["content"].strip(),
            total_tokens=body.get("usage", {}).get("total_tokens", 0),
        )
|
||||
163
core/detect/sse.py
Normal file
163
core/detect/sse.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
Pydantic Models - GENERATED FILE
|
||||
|
||||
Do not edit directly. Regenerate using modelgen.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class GraphNode(BaseModel):
|
||||
"""A pipeline stage node."""
|
||||
id: str
|
||||
status: str = "idle"
|
||||
items_in: int = 0
|
||||
items_out: int = 0
|
||||
|
||||
class GraphEdge(BaseModel):
|
||||
"""An edge between pipeline stages."""
|
||||
source: str
|
||||
target: str
|
||||
throughput: int = 0
|
||||
|
||||
class BoundingBoxEvent(BaseModel):
|
||||
"""Bounding box in SSE event payloads."""
|
||||
x: int
|
||||
y: int
|
||||
w: int
|
||||
h: int
|
||||
confidence: float
|
||||
label: str
|
||||
resolved_brand: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
stage: Optional[str] = None
|
||||
|
||||
class BrandSummary(BaseModel):
|
||||
"""Per-brand stats in the final report."""
|
||||
brand: str
|
||||
total_appearances: int = 0
|
||||
total_screen_time: float = 0.0
|
||||
avg_confidence: float = 0.0
|
||||
first_seen: float = 0.0
|
||||
last_seen: float = 0.0
|
||||
|
||||
class GraphUpdate(BaseModel):
|
||||
"""Pipeline node state transition. SSE event: graph_update"""
|
||||
nodes: List[GraphNode] = Field(default_factory=list)
|
||||
edges: List[GraphEdge] = Field(default_factory=list)
|
||||
active_path: List[str] = Field(default_factory=list)
|
||||
|
||||
class StatsUpdate(BaseModel):
|
||||
"""Funnel statistics snapshot. SSE event: stats_update"""
|
||||
frames_extracted: int = 0
|
||||
frames_after_scene_filter: int = 0
|
||||
cv_regions_detected: int = 0
|
||||
regions_detected: int = 0
|
||||
regions_resolved_by_ocr: int = 0
|
||||
regions_escalated_to_local_vlm: int = 0
|
||||
regions_escalated_to_cloud_llm: int = 0
|
||||
cloud_llm_calls: int = 0
|
||||
processing_time_seconds: float = 0.0
|
||||
estimated_cloud_cost_usd: float = 0.0
|
||||
run_id: Optional[str] = None
|
||||
parent_job_id: Optional[str] = None
|
||||
run_type: str = "initial"
|
||||
|
||||
class FrameUpdate(BaseModel):
|
||||
"""Current frame being processed. SSE event: frame_update"""
|
||||
frame_ref: int
|
||||
timestamp: float
|
||||
jpeg_b64: str
|
||||
boxes: List[BoundingBoxEvent] = Field(default_factory=list)
|
||||
|
||||
class Detection(BaseModel):
|
||||
"""A confirmed brand detection. SSE event: detection"""
|
||||
brand: str
|
||||
timestamp: float
|
||||
duration: float
|
||||
confidence: float
|
||||
source: str
|
||||
content_type: str
|
||||
bbox: Optional[BoundingBoxEvent] = None
|
||||
frame_ref: Optional[int] = None
|
||||
|
||||
class LogEvent(BaseModel):
|
||||
"""Pipeline log line. SSE event: log"""
|
||||
level: str
|
||||
stage: str
|
||||
msg: str
|
||||
ts: str
|
||||
trace_id: Optional[str] = None
|
||||
|
||||
class DetectionReportSummary(BaseModel):
|
||||
"""Final detection report summary."""
|
||||
video_source: str
|
||||
content_type: str
|
||||
duration_seconds: float
|
||||
total_detections: int = 0
|
||||
brands: List[BrandSummary] = Field(default_factory=list)
|
||||
stats: Optional[StatsUpdate] = None
|
||||
|
||||
class JobComplete(BaseModel):
|
||||
"""Final report when pipeline finishes. SSE event: job_complete"""
|
||||
job_id: str
|
||||
report: Optional[DetectionReportSummary] = None
|
||||
|
||||
class RunContext(BaseModel):
|
||||
"""Run context injected into all SSE events for grouping."""
|
||||
run_id: str
|
||||
parent_job_id: str
|
||||
run_type: str = "initial"
|
||||
|
||||
class CheckpointInfo(BaseModel):
|
||||
"""Available checkpoint for a stage."""
|
||||
stage: str
|
||||
is_scenario: bool = False
|
||||
scenario_label: str = ""
|
||||
|
||||
class ReplayRequest(BaseModel):
|
||||
"""Request to replay pipeline from a specific stage."""
|
||||
job_id: str
|
||||
start_stage: str
|
||||
config_overrides: Optional[Dict[str, Any]] = None
|
||||
|
||||
class ReplayResponse(BaseModel):
|
||||
"""Result of a replay invocation."""
|
||||
status: str
|
||||
job_id: str
|
||||
start_stage: str
|
||||
detections: int = 0
|
||||
brands_found: int = 0
|
||||
|
||||
class RetryRequest(BaseModel):
|
||||
"""Request to queue async retry with different config."""
|
||||
job_id: str
|
||||
config_overrides: Optional[Dict[str, Any]] = None
|
||||
start_stage: str = "escalate_vlm"
|
||||
schedule_seconds: Optional[float] = None
|
||||
|
||||
class RetryResponse(BaseModel):
|
||||
"""Result of queueing a retry task."""
|
||||
status: str
|
||||
task_id: str
|
||||
job_id: str
|
||||
|
||||
class RunRequest(BaseModel):
|
||||
"""Request body for launching a detection pipeline run."""
|
||||
video_path: str
|
||||
profile_name: str = "soccer_broadcast"
|
||||
source_asset_id: str = ""
|
||||
checkpoint: bool = True
|
||||
skip_vlm: bool = False
|
||||
skip_cloud: bool = False
|
||||
log_level: str = "INFO"
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
"""Response after starting a pipeline run."""
|
||||
status: str
|
||||
job_id: str
|
||||
video_path: str
|
||||
22
core/detect/stages/__init__.py
Normal file
22
core/detect/stages/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Pipeline stages.
|
||||
|
||||
Each stage is a file with a Stage subclass. Auto-discovered via
|
||||
__init_subclass__ — importing the file registers the stage.
|
||||
"""
|
||||
|
||||
from .base import (
|
||||
Stage,
|
||||
get_stage,
|
||||
get_stage_instance,
|
||||
list_stages,
|
||||
list_stage_classes,
|
||||
get_palette,
|
||||
)
|
||||
|
||||
# Import all stage files to trigger auto-registration
|
||||
from . import edge_detector # noqa: F401
|
||||
from . import field_segmentation # noqa: F401
|
||||
|
||||
# Import registry for backward compat (other stages still use old pattern)
|
||||
from . import registry # noqa: F401
|
||||
116
core/detect/stages/aggregator.py
Normal file
116
core/detect/stages/aggregator.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Stage 8 — Report compilation
|
||||
|
||||
Groups all detections by brand, merges contiguous appearances,
|
||||
and builds the final DetectionReport.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _merge_contiguous(detections: list[BrandDetection], gap_threshold: float = 2.0) -> list[BrandDetection]:
    """
    Collapse same-brand detections that sit close together in time.

    Two detections of one brand whose gap is at most ``gap_threshold``
    seconds become a single detection spanning both. Metadata (source,
    bbox, frame_ref, content_type) is inherited from the earliest hit;
    confidence is the maximum over the merged run.
    """
    if not detections:
        return []

    ordered = sorted(detections, key=lambda d: (d.brand, d.timestamp))
    merged: list[BrandDetection] = []
    active = ordered[0]

    for det in ordered[1:]:
        window_end = active.timestamp + active.duration + gap_threshold
        if det.brand == active.brand and det.timestamp <= window_end:
            # Extend the active run to cover this detection as well.
            span_end = max(active.timestamp + active.duration,
                           det.timestamp + det.duration)
            active = BrandDetection(
                brand=active.brand,
                timestamp=active.timestamp,
                duration=span_end - active.timestamp,
                confidence=max(active.confidence, det.confidence),
                source=active.source,
                bbox=active.bbox,
                frame_ref=active.frame_ref,
                content_type=active.content_type,
            )
        else:
            merged.append(active)
            active = det

    merged.append(active)
    return merged
|
||||
|
||||
|
||||
def compile_report(
    detections: list[BrandDetection],
    stats: PipelineStats,
    video_source: str = "",
    content_type: str = "",
    duration_seconds: float = 0.0,
    job_id: str | None = None,
) -> DetectionReport:
    """
    Build the final detection report from all accumulated detections.

    Merges contiguous detections, computes per-brand stats,
    and emits the job_complete event.

    Args:
        detections: raw per-frame brand detections from all stages.
        stats: pipeline-wide counters, embedded verbatim in the report.
        video_source: identifier of the processed video (report metadata).
        content_type: content classification (report metadata).
        duration_seconds: total video duration (report metadata).
        job_id: event-channel id passed through to the emit.* calls.
    """
    merged = _merge_contiguous(detections)

    # Accumulate per-brand statistics over the merged timeline.
    brands: dict[str, BrandStats] = {}
    for d in merged:
        if d.brand not in brands:
            brands[d.brand] = BrandStats()
        s = brands[d.brand]
        s.total_appearances += 1
        s.total_screen_time += d.duration
        # Running mean: fold the new confidence into the previous average.
        s.avg_confidence = (
            (s.avg_confidence * (s.total_appearances - 1) + d.confidence)
            / s.total_appearances
        )
        # NOTE(review): 0.0 doubles as the "unset" sentinel for first_seen,
        # so a genuine detection at t=0.0 is indistinguishable from unset.
        if s.first_seen == 0.0 or d.timestamp < s.first_seen:
            s.first_seen = d.timestamp
        if d.timestamp > s.last_seen:
            s.last_seen = d.timestamp

    report = DetectionReport(
        video_source=video_source,
        content_type=content_type,
        duration_seconds=duration_seconds,
        brands=brands,
        timeline=sorted(merged, key=lambda d: d.timestamp),
        pipeline_stats=stats,
    )

    emit.log(job_id, "Aggregator", "INFO",
             f"Report: {len(brands)} brands, {len(merged)} detections "
             f"(merged from {len(detections)} raw)")

    # Push a JSON-friendly summary of the report over the event channel.
    emit.job_complete(job_id, {
        "video_source": report.video_source,
        "content_type": report.content_type,
        "duration_seconds": report.duration_seconds,
        "brands": {
            k: {
                "total_appearances": v.total_appearances,
                "total_screen_time": v.total_screen_time,
                "avg_confidence": round(v.avg_confidence, 3),
                "first_seen": v.first_seen,
                "last_seen": v.last_seen,
            }
            for k, v in brands.items()
        },
    })

    return report
|
||||
151
core/detect/stages/base.py
Normal file
151
core/detect/stages/base.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
Stage base class — common interface for all pipeline stages.
|
||||
|
||||
Each stage is a file that subclasses Stage. Auto-discovered via
|
||||
__init_subclass__. No manual registration needed.
|
||||
|
||||
A stage:
|
||||
- Has a StageDefinition (generated from schema) with name, config, IO
|
||||
- Implements run(frames, config) → output
|
||||
- Owns its output serialization (opaque blob)
|
||||
- Optionally has a TypeScript port for browser-side execution
|
||||
|
||||
The checkpoint layer stores stage output as blobs without knowing
|
||||
the format. The stage that wrote it is the only one that can read it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from core.detect.stages.models import (
|
||||
StageConfigField,
|
||||
StageIO,
|
||||
StageDefinition,
|
||||
)
|
||||
|
||||
|
||||
# Legacy runtime extension — adds callable fields for old-style stages.
# New stages use Stage subclass with serialize()/deserialize() methods instead.
class LegacyStageDefinition:
    """Wraps a StageDefinition with callable serialize/deserialize functions."""

    def __init__(self, definition: StageDefinition, fn=None, serialize_fn=None, deserialize_fn=None):
        # The wrapped definition stays private; reads fall through to it
        # via __getattr__ below.
        self._definition = definition
        # Callable hooks used by the legacy execution path.
        self.fn = fn
        self.serialize_fn = serialize_fn
        self.deserialize_fn = deserialize_fn

    def __getattr__(self, name):
        # Invoked only when normal attribute lookup fails on the wrapper,
        # so fn/serialize_fn/deserialize_fn never reach this path.
        return getattr(self._definition, name)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry — auto-populated by __init_subclass__ (new stages)
|
||||
# + register_stage() (legacy stages during migration)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# New-style stages, keyed by definition.name. Populated as a side effect
# of Stage.__init_subclass__ when a subclass module is imported.
_REGISTRY: dict[str, type['Stage']] = {}
# Legacy stages registered explicitly via register_stage().
_LEGACY_REGISTRY: dict[str, LegacyStageDefinition] = {}


def register_stage(
    definition: StageDefinition,
    fn=None,
    serialize_fn=None,
    deserialize_fn=None,
):
    """Legacy registration for stages not yet converted to Stage subclass.

    Wraps the definition with its callables and stores it under
    definition.name. Re-registering the same name overwrites silently.
    """
    legacy = LegacyStageDefinition(definition, fn=fn, serialize_fn=serialize_fn, deserialize_fn=deserialize_fn)
    _LEGACY_REGISTRY[definition.name] = legacy
|
||||
|
||||
|
||||
class Stage:
    """
    Base class for all pipeline stages.

    Subclass in detect/stages/<name>.py with a `definition` class
    attribute and a `run()` implementation. Subclasses are registered
    automatically at class-creation time. Override `serialize()` /
    `deserialize()` to use a blob format other than the default JSON.
    """

    definition: StageDefinition  # set by each subclass

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Auto-register: any subclass carrying a definition joins the registry.
        defn = getattr(cls, 'definition', None)
        if defn is not None:
            _REGISTRY[defn.name] = cls

    def run(self, frames: list, config: dict) -> Any:
        # Subclasses must implement the actual stage work.
        raise NotImplementedError

    def serialize(self, output: Any) -> bytes:
        """Serialize stage output to bytes for checkpoint storage."""
        import json
        # default=str stringifies anything JSON can't encode natively.
        encoded = json.dumps(output, default=str)
        return encoded.encode()

    def deserialize(self, data: bytes) -> Any:
        """Deserialize stage output from checkpoint blob."""
        import json
        return json.loads(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discovery API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _all_definitions():
    """Merge new Stage subclass registry + legacy registry.

    Returns StageDefinition for new-style stages,
    LegacyStageDefinition for legacy stages (has serialize_fn etc).
    New-style entries shadow legacy ones with the same name.
    """
    combined = dict(_LEGACY_REGISTRY)
    combined.update({name: cls.definition for name, cls in _REGISTRY.items()})
    return combined
|
||||
|
||||
|
||||
def get_stage(name: str) -> StageDefinition:
    """Look up a stage definition by name (covers new-style and legacy)."""
    defs = _all_definitions()
    if name in defs:
        return defs[name]
    raise KeyError(f"Unknown stage: {name!r}. Registered: {list(defs)}")
|
||||
|
||||
|
||||
def get_stage_class(name: str) -> type[Stage] | None:
    """Return the Stage subclass registered under *name*, or None (legacy/unknown)."""
    if name in _REGISTRY:
        return _REGISTRY[name]
    return None
|
||||
|
||||
|
||||
def get_stage_instance(name: str) -> Stage:
    """Instantiate the new-style Stage registered under *name*.

    Raises KeyError for legacy stages, which have no class to instantiate.
    """
    cls = _REGISTRY.get(name)
    if cls is not None:
        return cls()
    raise KeyError(f"No Stage subclass for {name!r}. Legacy stages don't have instances.")
|
||||
|
||||
|
||||
def list_stages() -> list[StageDefinition]:
    """All registered stage definitions, new-style and legacy alike."""
    return [*_all_definitions().values()]
|
||||
|
||||
|
||||
def list_stage_classes() -> list[type[Stage]]:
    """All registered new-style Stage subclasses (legacy stages excluded)."""
    return [*_REGISTRY.values()]
|
||||
|
||||
|
||||
def get_palette() -> dict[str, list[StageDefinition]]:
    """Group stage definitions by category for the editor palette."""
    grouped: dict[str, list[StageDefinition]] = {}
    for defn in _all_definitions().values():
        grouped.setdefault(defn.category, []).append(defn)
    return grouped
|
||||
216
core/detect/stages/brand_resolver.py
Normal file
216
core/detect/stages/brand_resolver.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
Stage 5 — Brand Resolver (discovery mode)
|
||||
|
||||
Discovery-first brand matching. No static dictionary — all brands live in the DB.
|
||||
|
||||
Flow:
|
||||
1. Check session brands first (brands already seen in this run, in-memory)
|
||||
2. Check global known brands (accumulated across all runs)
|
||||
3. Unresolved candidates → escalate to VLM/cloud
|
||||
4. Confirmed brands get added to DB for future runs
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from rapidfuzz import fuzz
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import BrandDetection, TextCandidate
|
||||
from core.detect.stages.models import ResolverConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize(text: str) -> str:
|
||||
return text.strip().lower()
|
||||
|
||||
|
||||
def _has_db() -> bool:
|
||||
try:
|
||||
from core.db import find_brand_by_text as _
|
||||
return True
|
||||
except (ImportError, Exception):
|
||||
return False
|
||||
|
||||
|
||||
def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
|
||||
return session_brands.get(_normalize(text))
|
||||
|
||||
|
||||
def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
    """Check against global known brands in DB. Returns (canonical_name, brand_id) or (None, None).

    Two-step lookup: an exact/alias match delegated to the DB layer,
    then a fuzzy pass over every canonical name and alias with
    rapidfuzz.fuzz.ratio, accepting the best score >= *threshold*.
    Returns (None, None) when the DB layer is unavailable.
    """
    if not _has_db():
        return None, None

    from core.db import find_brand_by_text, list_brands
    from core.db.connection import get_session

    with get_session() as session:
        # Fast path: exact/alias lookup implemented by the DB layer.
        brand = find_brand_by_text(session, text)
        if brand:
            return brand.canonical_name, str(brand.id)

        # Slow path: fuzzy-match against all brands (loads the full list —
        # fine while the brand table is small).
        all_brands = list_brands(session)

        normalized = _normalize(text)
        best_brand = None
        best_score = 0

        for known in all_brands:
            names = [known.canonical_name] + (known.aliases or [])
            for name in names:
                score = fuzz.ratio(normalized, _normalize(name))
                if score > best_score and score >= threshold:
                    best_score = score
                    best_brand = known

        if best_brand:
            return best_brand.canonical_name, str(best_brand.id)

    return None, None
|
||||
|
||||
|
||||
def _register_brand(canonical_name: str, source: str) -> str | None:
    """Register a newly discovered brand in the DB. Returns brand_id.

    Idempotent via get_or_create_brand; commits immediately so the brand
    is visible to later lookups in this run. Returns None when the DB
    layer is unavailable.
    """
    if not _has_db():
        return None

    from core.db import get_or_create_brand
    from core.db.connection import get_session

    with get_session() as session:
        brand, created = get_or_create_brand(session, canonical_name, source=source)
        session.commit()
        if created:
            logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
        return str(brand.id)
||||
|
||||
|
||||
def _record_airing(timeline_id: str | None, brand_id: str,
                   frame_seq: int, confidence: float, source: str):
    """Record a brand airing on a timeline.

    The airing spans a single frame (frame_start == frame_end == frame_seq).
    Silently a no-op when the DB layer is unavailable or no timeline id
    was provided. Raises ValueError if either id is not a valid UUID.
    """
    if not _has_db() or not timeline_id:
        return

    from core.db import record_airing
    from core.db.connection import get_session
    from uuid import UUID

    with get_session() as session:
        record_airing(
            session,
            brand_id=UUID(brand_id),
            timeline_id=UUID(timeline_id),
            frame_start=frame_seq,
            frame_end=frame_seq,
            confidence=confidence,
            source=source,
        )
        session.commit()
|
||||
|
||||
|
||||
def build_session_dict(source_asset_id: str | None = None) -> dict[str, str]:
    """
    Load known brands from DB as a session lookup dict.

    Returns {normalized_name: canonical_name, ...} including aliases,
    or an empty dict when the DB layer is unavailable.
    Note: *source_asset_id* is currently unused — presumably reserved
    for future per-asset scoping; confirm before removing.
    """
    if not _has_db():
        return {}

    from core.db import list_brands
    from core.db.connection import get_session

    with get_session() as session:
        known = list_brands(session)

    lookup: dict[str, str] = {}
    for brand in known:
        canonical = brand.canonical_name
        lookup[_normalize(canonical)] = canonical
        # Every alias resolves to the same canonical name.
        for alias in brand.aliases or []:
            lookup[_normalize(alias)] = canonical

    return lookup
|
||||
|
||||
|
||||
def resolve_brands(
    candidates: list[TextCandidate],
    config: ResolverConfig,
    session_brands: dict[str, str] | None = None,
    source_asset_id: str | None = None,
    content_type: str = "",
    job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
    """
    Match text candidates against known brands (session → global → unresolved).

    session_brands: pre-loaded session dict (from build_session_dict);
        mutated in place as new matches are cached for this run.
    source_asset_id: currently unused in this function.
    job_id: timeline_id — used to record airings

    Returns (matched detections, unresolved candidates to escalate).
    """
    if session_brands is None:
        session_brands = {}

    emit.log(job_id, "BrandResolver", "INFO",
             f"Resolving {len(candidates)} candidates "
             f"(session={len(session_brands)} brands, fuzzy={config.fuzzy_threshold})")

    matched: list[BrandDetection] = []
    unresolved: list[TextCandidate] = []
    session_hits = 0
    known_hits = 0

    for candidate in candidates:
        text = candidate.text
        brand_name = None
        brand_id = None  # only set on a DB hit; session hits carry no id
        match_source = "ocr"

        # 1. Check session (cheapest — in-memory dict)
        brand_name = _match_session(text, session_brands)
        if brand_name:
            session_hits += 1
        else:
            # 2. Check global known brands (DB query + fuzzy)
            brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
            if brand_name:
                known_hits += 1
                # NOTE(review): caches the *canonical* name, not the raw OCR
                # text — a fuzzy OCR variant of the same brand will re-query
                # the DB on every occurrence. Confirm this is intended.
                session_brands[_normalize(brand_name)] = brand_name

        if brand_name:
            detection = BrandDetection(
                brand=brand_name,
                timestamp=candidate.frame.timestamp,
                # 0.5s per hit — presumably a nominal dwell merged downstream;
                # TODO confirm against the aggregator's gap_threshold.
                duration=0.5,
                confidence=candidate.ocr_confidence,
                source=match_source,
                bbox=candidate.bbox,
                frame_ref=candidate.frame.sequence,
                content_type=content_type,
            )
            matched.append(detection)

            # Persist an airing only when we know the DB brand id.
            if brand_id:
                _record_airing(
                    job_id, brand_id,
                    candidate.frame.sequence, candidate.ocr_confidence, match_source,
                )

            emit.detection(
                job_id,
                brand=brand_name,
                confidence=candidate.ocr_confidence,
                source=match_source,
                timestamp=candidate.frame.timestamp,
                content_type=content_type,
                frame_ref=candidate.frame.sequence,
            )
        else:
            unresolved.append(candidate)

    emit.log(job_id, "BrandResolver", "INFO",
             f"Session: {session_hits}, Known: {known_hits}, "
             f"Unresolved: {len(unresolved)} → escalating")

    return matched, unresolved
|
||||
288
core/detect/stages/edge_detector.py
Normal file
288
core/detect/stages/edge_detector.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Stage — Edge Detection
|
||||
|
||||
Canny + HoughLinesP to find horizontal line pairs that bound
|
||||
advertising hoardings. Pure OpenCV, no ML models.
|
||||
|
||||
Two modes:
|
||||
- Remote: calls GPU inference server over HTTP
|
||||
- Local: imports cv2 directly (OpenCV on same machine)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import BoundingBox, Frame
|
||||
from core.detect.stages.base import Stage
|
||||
from core.detect.stages.models import StageDefinition, StageConfigField, StageIO
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EdgeDetectionStage(Stage):
    """Canny + HoughLinesP stage — finds horizontal line pairs that
    bound advertising hoardings. Runs remotely against a GPU inference
    server when a URL is configured, otherwise locally via OpenCV."""

    definition = StageDefinition(
        name="detect_edges",
        label="Edge Detection",
        description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
        category="cv_analysis",
        io=StageIO(
            reads=["filtered_frames"],
            writes=["edge_regions_by_frame"],
        ),
        config_fields=[
            StageConfigField(name="enabled", type="bool", default=True, description="Enable edge detection"),
            StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
            StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
            StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
            StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
            StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
            StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
            StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
        ],
        tracks_element="edge_region",
    )

    def run(self, frames: list[Frame], config: dict) -> dict[int, list[BoundingBox]]:
        """
        Run edge detection on all frames.

        Config keys: enabled, edge_canny_low, edge_canny_high, edge_hough_threshold,
        edge_hough_min_length, edge_hough_max_gap, edge_pair_max_distance, edge_pair_min_distance,
        debug (bool), inference_url (str|None), job_id (str|None).

        Returns dict mapping frame sequence → list of BoundingBox.
        """
        enabled = config.get("enabled", True)
        job_id = config.get("job_id")
        # Env var fallback lets deployments point at a GPU server globally.
        inference_url = config.get("inference_url") or os.environ.get("INFERENCE_URL")

        if not enabled:
            emit.log(job_id, "EdgeDetection", "INFO", "Edge detection disabled, skipping")
            return {}

        mode = "remote" if inference_url else "local"
        emit.log(job_id, "EdgeDetection", "INFO",
                 f"Detecting edges in {len(frames)} frames (mode={mode})")

        all_boxes: dict[int, list[BoundingBox]] = {}
        total_regions = 0

        for frame in frames:
            # Per-frame wall-clock timing, reported in the DEBUG log line.
            t0 = time.monotonic()
            if inference_url:
                boxes = self._run_remote(frame, config, inference_url, job_id or "")
            else:
                boxes = self._run_local(frame, config)
            ms = (time.monotonic() - t0) * 1000

            all_boxes[frame.sequence] = boxes
            total_regions += len(boxes)

            emit.log(job_id, "EdgeDetection", "DEBUG",
                     f"Frame {frame.sequence}: {len(boxes)} regions in {ms:.0f}ms"
                     + (f" [{', '.join(b.label for b in boxes)}]" if boxes else ""))

            # Stream an annotated frame to the UI only when there is
            # something to show and a live job channel to send it on.
            if boxes and job_id:
                box_dicts = [
                    {"x": b.x, "y": b.y, "w": b.w, "h": b.h,
                     "confidence": b.confidence, "label": b.label,
                     "stage": "detect_edges"}
                    for b in boxes
                ]
                emit.frame_update(
                    job_id,
                    frame_ref=frame.sequence,
                    timestamp=frame.timestamp,
                    jpeg_b64=_frame_to_b64(frame),
                    boxes=box_dicts,
                )

        emit.log(job_id, "EdgeDetection", "INFO",
                 f"Found {total_regions} edge regions across {len(frames)} frames")
        emit.stats(job_id, cv_regions_detected=total_regions)

        return all_boxes

    def serialize(self, output: Any) -> bytes:
        """Serialize edge regions to JSON blob.

        Keys are stringified frame sequences (JSON objects need str keys);
        deserialize() converts them back to int.
        """
        serialized = {}
        for seq, boxes in output.items():
            serialized[str(seq)] = [
                {"x": b.x, "y": b.y, "w": b.w, "h": b.h,
                 "confidence": b.confidence, "label": b.label}
                for b in boxes
            ]
        return json.dumps(serialized).encode()

    def deserialize(self, data: bytes) -> dict[int, list[BoundingBox]]:
        """Deserialize edge regions from JSON blob (inverse of serialize)."""
        raw = json.loads(data)
        result = {}
        for seq_str, box_dicts in raw.items():
            boxes = [
                BoundingBox(x=b["x"], y=b["y"], w=b["w"], h=b["h"],
                            confidence=b["confidence"], label=b["label"])
                for b in box_dicts
            ]
            result[int(seq_str)] = boxes
        return result

    # --- Private helpers ---

    def _run_remote(self, frame: Frame, config: dict,
                    inference_url: str, job_id: str) -> list[BoundingBox]:
        """Detect edges for one frame via the GPU inference server."""
        from core.detect.inference import InferenceClient
        from core.detect.emit import _run_log_level

        client = InferenceClient(
            base_url=inference_url, job_id=job_id, log_level=_run_log_level,
        )
        results = client.detect_edges(
            image=frame.image,
            edge_canny_low=config.get("edge_canny_low", 50),
            edge_canny_high=config.get("edge_canny_high", 150),
            edge_hough_threshold=config.get("edge_hough_threshold", 80),
            edge_hough_min_length=config.get("edge_hough_min_length", 100),
            edge_hough_max_gap=config.get("edge_hough_max_gap", 10),
            edge_pair_max_distance=config.get("edge_pair_max_distance", 200),
            edge_pair_min_distance=config.get("edge_pair_min_distance", 15),
        )
        # Remote results arrive as attribute-style objects.
        boxes = []
        for r in results:
            box = BoundingBox(
                x=r.x, y=r.y, w=r.w, h=r.h,
                confidence=r.confidence, label=r.label,
            )
            boxes.append(box)
        return boxes

    def _run_local(self, frame: Frame, config: dict) -> list[BoundingBox]:
        """Detect edges for one frame with the locally-loaded OpenCV module."""
        detect_edges_fn = _load_cv_edges().detect_edges

        edge_results = detect_edges_fn(
            frame.image,
            canny_low=config.get("edge_canny_low", 50),
            canny_high=config.get("edge_canny_high", 150),
            hough_threshold=config.get("edge_hough_threshold", 80),
            hough_min_length=config.get("edge_hough_min_length", 100),
            hough_max_gap=config.get("edge_hough_max_gap", 10),
            pair_max_distance=config.get("edge_pair_max_distance", 200),
            pair_min_distance=config.get("edge_pair_min_distance", 15),
        )

        # Local results arrive as plain dicts (note the key difference
        # from _run_remote's attribute access).
        boxes = []
        for r in edge_results:
            box = BoundingBox(
                x=r["x"], y=r["y"], w=r["w"], h=r["h"],
                confidence=r["confidence"], label=r["label"],
            )
            boxes.append(box)
        return boxes
|
||||
|
||||
|
||||
# --- Module-level helpers ---
|
||||
|
||||
def _frame_to_b64(frame: Frame) -> str:
    """JPEG-encode a frame image (quality 70) and return it base64-encoded."""
    buffer = io.BytesIO()
    Image.fromarray(frame.image).save(buffer, format="JPEG", quality=70)
    return base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
|
||||
# Lazily-loaded CV edges module, cached for the life of the process.
_cv_edges_mod = None


def _load_cv_edges():
    """Load core/gpu/models/cv/edges.py by file path, caching the module.

    Loaded via importlib rather than a regular import — presumably because
    core/gpu is not importable as a package from this process; confirm.
    NOTE(review): the relative path resolves against the current working
    directory, so this only works when run from the repo root.
    """
    global _cv_edges_mod
    if _cv_edges_mod is None:
        import importlib.util
        from pathlib import Path
        spec = importlib.util.spec_from_file_location("cv_edges", Path("core/gpu/models/cv/edges.py"))
        _cv_edges_mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(_cv_edges_mod)
    return _cv_edges_mod
|
||||
|
||||
|
||||
# --- Backward compat: standalone function for graph.py ---
|
||||
|
||||
def _filter_by_field_mask(boxes, mask, margin_px=50):
    """
    Keep only boxes that are near the pitch boundary (hoarding zone).

    The field mask has 255=pitch, 0=not pitch. Hoardings sit just
    outside the pitch boundary. We dilate the mask to create a
    "boundary zone" and keep boxes whose center falls in the zone
    between the dilated mask edge and the original mask.

    Args:
        boxes: bounding boxes with .x/.y/.w/.h attributes.
        mask: uint8 field mask aligned with the boxes' frame, or None.
        margin_px: dilation radius — how far outside the pitch the
            boundary zone extends.

    Returns the surviving boxes (the input list unchanged when there is
    no mask or no boxes).
    """
    import cv2  # fix: dropped the unused `import numpy as np` the original carried

    if mask is None or not boxes:
        return boxes

    # Dilate the pitch mask — the expansion zone is where hoardings are
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (margin_px * 2, margin_px * 2))
    dilated = cv2.dilate(mask, kernel)

    # Boundary zone = dilated but NOT original pitch
    boundary_zone = cv2.bitwise_and(dilated, cv2.bitwise_not(mask))

    kept = []
    for box in boxes:
        cx = box.x + box.w // 2
        cy = box.y + box.h // 2
        # Clamp to image bounds.
        # NOTE(review): only the upper bound is clamped — a negative
        # center would index from the array end; presumably box coords
        # are always non-negative, confirm upstream.
        cy = min(cy, boundary_zone.shape[0] - 1)
        cx = min(cx, boundary_zone.shape[1] - 1)
        if boundary_zone[cy, cx] > 0:
            kept.append(box)

    return kept
|
||||
|
||||
|
||||
def detect_edge_regions(frames, config, inference_url=None, job_id=None, field_masks=None):
    """Convenience wrapper — calls EdgeDetectionStage.run(), optionally filters by field mask.

    Args:
        frames: frames to process.
        config: attribute-style edge config (not the dict the stage takes).
        inference_url: optional GPU inference server URL (local CV otherwise).
        job_id: event-channel id for emit logging.
        field_masks: optional {sequence: pitch mask}; when present, boxes
            outside the pitch-boundary zone are dropped per frame.

    Returns {frame sequence: [BoundingBox, ...]}.
    """
    stage = EdgeDetectionStage()
    # Adapt the attribute-style config object to the dict API of Stage.run().
    cfg = {
        "enabled": config.enabled,
        "edge_canny_low": config.edge_canny_low,
        "edge_canny_high": config.edge_canny_high,
        "edge_hough_threshold": config.edge_hough_threshold,
        "edge_hough_min_length": config.edge_hough_min_length,
        "edge_hough_max_gap": config.edge_hough_max_gap,
        "edge_pair_max_distance": config.edge_pair_max_distance,
        "edge_pair_min_distance": config.edge_pair_min_distance,
        "inference_url": inference_url,
        "job_id": job_id,
    }
    all_boxes = stage.run(frames, cfg)

    # Filter by field segmentation mask if available
    if field_masks:
        filtered_total = 0
        original_total = sum(len(b) for b in all_boxes.values())
        # Reassigning existing keys while iterating items() is safe —
        # the dict's size never changes here.
        for seq, boxes in all_boxes.items():
            mask = field_masks.get(seq)
            if mask is not None:
                all_boxes[seq] = _filter_by_field_mask(boxes, mask)
                filtered_total += len(all_boxes[seq])
            else:
                # No mask for this frame — keep its boxes untouched.
                filtered_total += len(boxes)

        if original_total != filtered_total:
            from core.detect import emit
            emit.log(job_id, "EdgeDetection", "INFO",
                     f"Field mask filter: {original_total} → {filtered_total} regions")

    return all_boxes
|
||||
141
core/detect/stages/field_segmentation.py
Normal file
141
core/detect/stages/field_segmentation.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Stage — Field Segmentation
|
||||
|
||||
Calls the GPU inference server to detect pitch boundaries via
|
||||
HSV green mask + morphology. The CV code lives in core/gpu/models/cv/.
|
||||
|
||||
Outputs a mask and boundary that downstream stages use as spatial priors.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import Frame
|
||||
from core.detect.stages.base import Stage
|
||||
from core.detect.stages.models import (
|
||||
FieldSegmentationConfig,
|
||||
StageConfigField,
|
||||
StageDefinition,
|
||||
StageIO,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FieldSegmentationStage(Stage):
    """HSV green-mask field segmentation stage definition.

    NOTE(review): this subclass only declares `definition` and does not
    override Stage.run() — the actual work currently lives in the
    module-level run_field_segmentation() in this file. Subclassing still
    registers the stage so it appears in the registry/palette.
    """

    definition = StageDefinition(
        name="field_segmentation",
        label="Field Segmentation",
        description="HSV green mask — detect pitch boundaries for spatial priors",
        category="cv_analysis",
        io=StageIO(
            reads=["filtered_frames"],
            writes=["field_mask"],
        ),
        config_fields=[
            StageConfigField(name="enabled", type="bool", default=True, description="Enable field segmentation"),
            StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180),
            StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180),
            StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255),
            StageConfigField(name="sat_high", type="int", default=255, description="HSV saturation upper bound", min=0, max=255),
            StageConfigField(name="val_low", type="int", default=30, description="HSV value lower bound", min=0, max=255),
            StageConfigField(name="val_high", type="int", default=255, description="HSV value upper bound", min=0, max=255),
            StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51),
            StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area as fraction of frame", min=0.01, max=0.5),
        ],
    )
|
||||
|
||||
|
||||
def _frame_to_b64(frame: Frame) -> str:
    """Encode frame image as base64 JPEG (quality 85)."""
    buffer = io.BytesIO()
    Image.fromarray(frame.image).save(buffer, format="JPEG", quality=85)
    return base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
|
||||
def _decode_mask_b64(mask_b64: str) -> np.ndarray:
    """Decode a base64 PNG mask back to a 2-D grayscale numpy array."""
    raw = base64.b64decode(mask_b64)
    grayscale = Image.open(io.BytesIO(raw)).convert("L")
    return np.array(grayscale)
|
||||
|
||||
|
||||
def run_field_segmentation(
    frames: list[Frame],
    config: FieldSegmentationConfig,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> dict:
    """
    Run field segmentation on all frames via the inference server.

    Returns dict with:
        field_masks: {seq: np.ndarray}
        field_boundaries: {seq: [(x,y), ...]}
        field_coverage: {seq: float}

    All three dicts come back empty when the stage is disabled or no
    inference URL is available (argument or INFERENCE_URL env var).
    """
    if not config.enabled:
        emit.log(job_id, "FieldSegmentation", "INFO", "Disabled, skipping")
        return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}

    import os
    url = inference_url or os.environ.get("INFERENCE_URL")
    if not url:
        emit.log(job_id, "FieldSegmentation", "WARNING",
                 "No INFERENCE_URL, skipping field segmentation")
        return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}

    emit.log(job_id, "FieldSegmentation", "INFO",
             f"Segmenting {len(frames)} frames (hue={config.hue_low}-{config.hue_high})")

    from core.detect.inference import InferenceClient
    from core.detect.emit import _run_log_level
    client = InferenceClient(base_url=url, job_id=job_id or "", log_level=_run_log_level)

    field_masks = {}
    field_boundaries = {}
    field_coverage = {}

    for frame in frames:
        image_b64 = _frame_to_b64(frame)

        resp = client.post("/segment_field", {
            "image": image_b64,
            "hue_low": config.hue_low,
            "hue_high": config.hue_high,
            "sat_low": config.sat_low,
            "sat_high": config.sat_high,
            "val_low": config.val_low,
            "val_high": config.val_high,
            "morph_kernel": config.morph_kernel,
            "min_area_ratio": config.min_area_ratio,
        })

        # NOTE(review): a None response is treated as a failed request and
        # the frame is skipped entirely — confirm InferenceClient.post's
        # contract (None on error vs raising).
        if resp is None:
            continue

        # Masks are optional in the response; boundary/coverage always recorded.
        mask_b64 = resp.get("mask_b64", "")
        if mask_b64:
            field_masks[frame.sequence] = _decode_mask_b64(mask_b64)

        field_boundaries[frame.sequence] = resp.get("boundary", [])
        field_coverage[frame.sequence] = resp.get("coverage", 0.0)

    # max(..., 1) guards the division when every frame failed.
    avg_coverage = sum(field_coverage.values()) / max(len(field_coverage), 1)
    emit.log(job_id, "FieldSegmentation", "INFO",
             f"Done: {len(frames)} frames, avg coverage {avg_coverage:.1%}")

    return {
        "field_masks": field_masks,
        "field_boundaries": field_boundaries,
        "field_coverage": field_coverage,
    }
|
||||
93
core/detect/stages/frame_extractor.py
Normal file
93
core/detect/stages/frame_extractor.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Stage 1 — Frame Extraction
|
||||
|
||||
Extracts frames from a video at a configurable FPS using the core ffmpeg module.
|
||||
Emits log + stats_update SSE events as it works.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from core.ffmpeg.probe import probe_file
|
||||
from core.detect import emit
|
||||
from core.detect.models import Frame
|
||||
from core.detect.stages.models import FrameExtractionConfig
|
||||
|
||||
|
||||
def _load_frames(tmpdir: Path, fps: float) -> list[Frame]:
    """Read the extracted frame_*.jpg files into sequence-numbered Frame objects."""
    loaded: list[Frame] = []
    for seq, path in enumerate(sorted(tmpdir.glob("frame_*.jpg"))):
        pixels = np.array(Image.open(path))
        loaded.append(Frame(
            sequence=seq,
            chunk_id=0,
            timestamp=seq / fps,
            image=pixels,
        ))
    return loaded
|
||||
|
||||
|
||||
def extract_frames(
    video_path: str,
    config: FrameExtractionConfig,
    job_id: str | None = None,
) -> list[Frame]:
    """
    Extract frames from video at the configured FPS.

    Uses ffmpeg-python to build the extraction pipeline,
    outputs JPEG files to a temp dir, then loads as numpy arrays.

    Raises:
        RuntimeError: if the ffmpeg subprocess fails (original error chained).
    """
    # Probe first so the INFO log can report duration/resolution up front.
    probe = probe_file(video_path)
    duration = probe.duration or 0.0

    emit.log(job_id, "FrameExtractor", "INFO",
             f"Starting extraction: {Path(video_path).name} "
             f"({duration:.1f}s, {probe.width}x{probe.height}, fps={config.fps})")
    emit.log(job_id, "FrameExtractor", "DEBUG",
             f"Probe: codec={probe.video_codec}, bitrate={probe.video_bitrate}, max_frames={config.max_frames}")

    # Temp dir (and the JPEGs in it) are deleted on exit; frames live on in memory.
    with tempfile.TemporaryDirectory() as tmpdir:
        pattern = str(Path(tmpdir) / "frame_%06d.jpg")

        # fps filter resamples the stream; qscale=2 ≈ high-quality JPEG;
        # frames= caps output at config.max_frames.
        stream = (
            ffmpeg
            .input(video_path)
            .filter("fps", fps=config.fps)
            .output(pattern, qscale=2, frames=config.max_frames)
            .overwrite_output()
        )

        t0 = time.monotonic()
        try:
            stream.run(capture_stdout=True, capture_stderr=True, quiet=True)
        except ffmpeg.Error as e:
            # stderr is bytes (or None) from the subprocess; truncate for the log.
            stderr = e.stderr.decode() if e.stderr else "unknown error"
            emit.log(job_id, "FrameExtractor", "ERROR", f"FFmpeg failed: {stderr[:200]}")
            raise RuntimeError(f"FFmpeg failed: {stderr}") from e
        ffmpeg_ms = (time.monotonic() - t0) * 1000
        emit.log(job_id, "FrameExtractor", "DEBUG", f"FFmpeg decode: {ffmpeg_ms:.0f}ms")

        # Load JPEGs into numpy-backed Frame objects while the dir still exists.
        t0 = time.monotonic()
        frames = _load_frames(Path(tmpdir), config.fps)
        load_ms = (time.monotonic() - t0) * 1000
        if frames:
            h, w = frames[0].image.shape[:2]
            mem_mb = sum(f.image.nbytes for f in frames) / (1024 * 1024)
            emit.log(job_id, "FrameExtractor", "DEBUG",
                     f"Loaded {len(frames)} frames ({w}x{h}) in {load_ms:.0f}ms, {mem_mb:.1f}MB in memory")

    emit.log(job_id, "FrameExtractor", "INFO", f"Extracted {len(frames)} frames")
    emit.stats(job_id, frames_extracted=len(frames))

    return frames
|
||||
106
core/detect/stages/models.py
Normal file
106
core/detect/stages/models.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Pydantic Models - GENERATED FILE
|
||||
|
||||
Do not edit directly. Regenerate using modelgen.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class StageConfigField(BaseModel):
    """A single tunable config parameter for the editor UI."""
    # Parameter name as it appears in the stage's config model.
    name: str
    # UI widget type hint, e.g. "int", "float", "bool", "str", "list[str]".
    # (Field names `type`/`min`/`max` shadow builtins — kept for the generated schema.)
    type: str
    # Default value shown/applied when the user hasn't overridden it.
    default: Any
    description: str = ""
    # Numeric bounds for slider/spinner widgets; None = unbounded.
    min: Optional[float] = None
    max: Optional[float] = None
    # Allowed choices for enum-style fields; None = free-form.
    options: Optional[List[str]] = None
|
||||
|
||||
class StageIO(BaseModel):
    """Declares what a stage reads and writes.

    Keys refer to entries in the shared pipeline state dict.
    """
    # State keys the stage requires as input.
    reads: List[str] = Field(default_factory=list)
    # State keys the stage produces.
    writes: List[str] = Field(default_factory=list)
    # Inputs the stage can use if present but does not require.
    optional_reads: List[str] = Field(default_factory=list)
|
||||
|
||||
class StageDefinition(BaseModel):
    """Complete metadata for a pipeline stage."""
    # Registry key, e.g. "detect_objects".
    name: str
    # Human-readable title for the editor UI.
    label: str
    description: str
    # Grouping bucket: "preprocessing", "cv_analysis", "detection",
    # "escalation", "output", ...
    category: str = "detection"
    io: StageIO
    # Tunable parameters exposed in the editor.
    config_fields: List[StageConfigField] = Field(default_factory=list)
    # Pipeline element this stage tracks, if any — TODO confirm semantics.
    tracks_element: Optional[str] = None
|
||||
|
||||
class FrameExtractionConfig(BaseModel):
    """FrameExtractionConfig(fps: float = 2.0, max_frames: int = 500)"""
    # Frames sampled per second of video.
    fps: float = 2.0
    # Hard cap on total frames extracted per job.
    max_frames: int = 500
|
||||
|
||||
class SceneFilterConfig(BaseModel):
    """SceneFilterConfig(hamming_threshold: int = 8, enabled: bool = True)"""
    # Max perceptual-hash Hamming distance for two frames to count as duplicates.
    hamming_threshold: int = 8
    enabled: bool = True
|
||||
|
||||
class DetectionConfig(BaseModel):
    """DetectionConfig(model_name: str = 'yolov8n.pt', confidence_threshold: float = 0.3, target_classes: List[str] = <factory>)"""
    # YOLO weights file to load.
    model_name: str = "yolov8n.pt"
    # Detections scoring below this are dropped.
    confidence_threshold: float = 0.3
    # Class-name filter; empty list = detect all classes.
    # Fix: the generated docstring documents a `<factory>` default, but the
    # field was emitted without one (making it required). Restore the list
    # factory to match the docstring and the registry default ([]).
    target_classes: List[str] = Field(default_factory=list)
|
||||
|
||||
class OCRConfig(BaseModel):
    """OCRConfig(languages: List[str] = <factory>, min_confidence: float = 0.5)"""
    # OCR languages; first entry selects the local PaddleOCR model.
    # Fix: the generated docstring documents a `<factory>` default, but the
    # field was emitted without one (making it required). Default to ["en"]
    # to match the registry's declared default — run_ocr indexes
    # config.languages[0], so an empty default would crash.
    languages: List[str] = Field(default_factory=lambda: ["en"])
    # Minimum OCR confidence to keep a text result.
    min_confidence: float = 0.5
|
||||
|
||||
class ResolverConfig(BaseModel):
    """ResolverConfig(fuzzy_threshold: int = 75)"""
    # Minimum fuzzy-match score (0–100 scale, presumably rapidfuzz-style —
    # TODO confirm) for a text candidate to resolve to a brand.
    fuzzy_threshold: int = 75
|
||||
|
||||
class RegionAnalysisConfig(BaseModel):
    """RegionAnalysisConfig(enabled: bool = True, edge_canny_low: int = 50, edge_canny_high: int = 150, edge_hough_threshold: int = 80, edge_hough_min_length: int = 100, edge_hough_max_gap: int = 10, edge_pair_max_distance: int = 200, edge_pair_min_distance: int = 15)"""
    enabled: bool = True
    # Canny hysteresis thresholds.
    edge_canny_low: int = 50
    edge_canny_high: int = 150
    # HoughLinesP accumulator threshold and segment geometry (pixels).
    edge_hough_threshold: int = 80
    edge_hough_min_length: int = 100
    edge_hough_max_gap: int = 10
    # Distance window (px) within which two lines are paired into a region.
    edge_pair_max_distance: int = 200
    edge_pair_min_distance: int = 15
|
||||
|
||||
class FieldSegmentationConfig(BaseModel):
    """FieldSegmentationConfig(enabled: bool = True, hue_low: int = 30, hue_high: int = 85, sat_low: int = 30, sat_high: int = 255, val_low: int = 30, val_high: int = 255, morph_kernel: int = 15, min_area_ratio: float = 0.05)"""
    enabled: bool = True
    # HSV inclusion band; defaults target green turf (hue 30–85).
    hue_low: int = 30
    hue_high: int = 85
    sat_low: int = 30
    sat_high: int = 255
    val_low: int = 30
    val_high: int = 255
    # Morphological kernel size (px) used to clean up the mask.
    morph_kernel: int = 15
    # Contours smaller than this fraction of the frame are discarded.
    min_area_ratio: float = 0.05
|
||||
|
||||
class StageRef(BaseModel):
    """Reference to a stage in the pipeline graph."""
    # Registered stage name (see the stage registry).
    name: str
    # Graph branch this stage runs on; "trunk" is the main path.
    branch: str = "trunk"
    # Where the stage executes, e.g. "local" vs a remote worker — TODO confirm values.
    execution_target: str = "local"
|
||||
|
||||
class Edge(BaseModel):
    """Connection between stages in the graph."""
    # Source/target are StageRef names.
    source: str
    target: str
    # Optional routing condition; empty string = unconditional edge.
    condition: str = ""
|
||||
|
||||
class PipelineConfig(BaseModel):
    """Pipeline graph topology + routing rules."""
    name: str
    # Profile this pipeline belongs to.
    profile_name: str
    stages: List[StageRef] = Field(default_factory=list)
    edges: List[Edge] = Field(default_factory=list)
    # NOTE(review): unlike stages/edges this field has no default, so it is
    # required at construction — confirm that is intentional and not a
    # modelgen omission of Field(default_factory=dict).
    routing_rules: Dict[str, Any]
|
||||
139
core/detect/stages/ocr_stage.py
Normal file
139
core/detect/stages/ocr_stage.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
Stage 4 — OCR
|
||||
|
||||
Reads text from detected regions (YOLO bounding box crops).
|
||||
Two modes:
|
||||
- remote: calls inference server over HTTP (separate GPU box, or localhost)
|
||||
- local: runs PaddleOCR in-process (single-box setup with enough VRAM)
|
||||
|
||||
The mode is selected by whether inference_url is provided.
|
||||
Model instances are cached at module level so they survive across pipeline runs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import BoundingBox, Frame, TextCandidate
|
||||
from core.detect.stages.models import OCRConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level cache — avoids reloading the model for every crop or pipeline run
|
||||
_local_ocr_cache: dict[str, object] = {}
|
||||
|
||||
|
||||
def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray:
|
||||
h, w = frame.image.shape[:2]
|
||||
x1 = max(0, box.x)
|
||||
y1 = max(0, box.y)
|
||||
x2 = min(w, box.x + box.w)
|
||||
y2 = min(h, box.y + box.h)
|
||||
return frame.image[y1:y2, x1:x2]
|
||||
|
||||
|
||||
def _get_local_model(lang: str):
    """Return the cached PaddleOCR instance for *lang*, loading it on first use."""
    cached = _local_ocr_cache.get(lang)
    if cached is None:
        from paddleocr import PaddleOCR
        logger.info("Loading PaddleOCR locally (lang=%s)", lang)
        cached = PaddleOCR(lang=lang)
        _local_ocr_cache[lang] = cached
    return cached
|
||||
|
||||
|
||||
def _parse_ocr_raw(raw, min_confidence: float) -> list[dict]:
|
||||
"""Parse PaddleOCR 3.x result — handles dict-based and nested-list layouts."""
|
||||
results = []
|
||||
for page in (raw or []):
|
||||
if not page:
|
||||
continue
|
||||
if isinstance(page, dict):
|
||||
for text, confidence in zip(page.get("rec_texts", []), page.get("rec_scores", [])):
|
||||
if float(confidence) >= min_confidence:
|
||||
results.append({"text": text, "confidence": float(confidence)})
|
||||
continue
|
||||
for line in page:
|
||||
if not line:
|
||||
continue
|
||||
rec = line[1]
|
||||
if isinstance(rec, (list, tuple)) and len(rec) >= 2:
|
||||
text, confidence = rec[0], rec[1]
|
||||
if float(confidence) >= min_confidence:
|
||||
results.append({"text": text, "confidence": float(confidence)})
|
||||
return results
|
||||
|
||||
|
||||
def run_ocr(
    frames: list[Frame],
    boxes_by_frame: dict[int, list[BoundingBox]],
    config: OCRConfig,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> list[TextCandidate]:
    """
    Run OCR on cropped regions from YOLO detections.

    inference_url=None → local in-process PaddleOCR (single-box)
    inference_url=str → remote inference server (split or localhost)
    """
    total_regions = sum(len(boxes) for boxes in boxes_by_frame.values())
    mode = "remote" if inference_url else "local"

    emit.log(job_id, "OCRStage", "INFO",
             f"Running OCR on {total_regions} regions (mode={mode})")

    # Build these once per pipeline run, not per crop
    if inference_url:
        from core.detect.inference import InferenceClient
        from core.detect.emit import _run_log_level
        client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
    else:
        # Local mode keys the cached model on the first configured language
        # only — assumes config.languages is non-empty (TODO confirm upstream
        # validation guarantees that).
        model = _get_local_model(config.languages[0])

    frame_map = {f.sequence: f for f in frames}
    candidates: list[TextCandidate] = []

    for seq, boxes in boxes_by_frame.items():
        # Boxes may reference frames dropped by the scene filter — skip them.
        frame = frame_map.get(seq)
        if not frame:
            continue

        for box in boxes:
            crop = _crop_region(frame, box)
            # Box entirely outside the frame yields an empty slice.
            if crop.size == 0:
                continue

            t0 = time.monotonic()
            if inference_url:
                raw_results = client.ocr(image=crop, languages=config.languages)
                texts = [{"text": r.text, "confidence": r.confidence} for r in raw_results]
            else:
                raw = model.ocr(crop)
                texts = _parse_ocr_raw(raw, config.min_confidence)
            ocr_ms = (time.monotonic() - t0) * 1000

            h, w = crop.shape[:2]
            text_preview = ", ".join(t["text"][:30] for t in texts) if texts else "(none)"
            emit.log(job_id, "OCRStage", "DEBUG",
                     f"Frame {seq} box {box.x},{box.y} ({w}x{h}): {ocr_ms:.0f}ms → {text_preview}")

            # One TextCandidate per recognized line, all sharing the same box.
            for t in texts:
                candidates.append(TextCandidate(
                    frame=frame,
                    bbox=box,
                    text=t["text"],
                    ocr_confidence=t["confidence"],
                ))

    # NOTE(review): len(candidates) counts text lines, not distinct regions —
    # the log/stat labels say "regions"; confirm that wording is intended.
    emit.log(job_id, "OCRStage", "INFO",
             f"Extracted text from {len(candidates)} regions")
    emit.stats(job_id, regions_resolved_by_ocr=len(candidates))

    return candidates
|
||||
128
core/detect/stages/preprocess.py
Normal file
128
core/detect/stages/preprocess.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Stage 3.5 — Preprocessing
|
||||
|
||||
Runs between YOLO detection and OCR. Applies configurable image
|
||||
preprocessing to each detected region crop: contrast enhancement,
|
||||
deskewing, binarization.
|
||||
|
||||
Operates on the crops derived from boxes_by_frame, produces
|
||||
preprocessed_crops keyed by (frame_sequence, box_index).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import BoundingBox, Frame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray:
|
||||
h, w = frame.image.shape[:2]
|
||||
x1 = max(0, box.x)
|
||||
y1 = max(0, box.y)
|
||||
x2 = min(w, box.x + box.w)
|
||||
y2 = min(h, box.y + box.h)
|
||||
return frame.image[y1:y2, x1:x2]
|
||||
|
||||
|
||||
def preprocess_regions(
    frames: list[Frame],
    boxes_by_frame: dict[int, list[BoundingBox]],
    do_contrast: bool = True,
    do_deskew: bool = False,
    do_binarize: bool = False,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> dict[str, np.ndarray]:
    """
    Preprocess cropped regions from YOLO detections.

    Returns dict keyed by "{frame_seq}_{box_idx}" → preprocessed crop.
    These are passed to the OCR stage instead of raw crops.
    """
    total_regions = sum(len(b) for b in boxes_by_frame.values())

    # No transform enabled → empty dict, OCR falls back to raw crops.
    if not (do_contrast or do_deskew or do_binarize):
        emit.log(job_id, "Preprocess", "INFO",
                 f"Preprocessing disabled, passing {total_regions} regions through")
        return {}

    mode = "remote" if inference_url else "local"
    emit.log(job_id, "Preprocess", "INFO",
             f"Preprocessing {total_regions} regions (mode={mode}, "
             f"contrast={do_contrast}, deskew={do_deskew}, binarize={do_binarize})")

    by_sequence = {f.sequence: f for f in frames}
    results: dict[str, np.ndarray] = {}

    for seq, boxes in boxes_by_frame.items():
        frame = by_sequence.get(seq)
        if not frame:
            continue

        for idx, box in enumerate(boxes):
            crop = _crop_region(frame, box)
            if crop.size == 0:
                continue

            if inference_url:
                processed = _preprocess_remote(crop, inference_url,
                                               do_contrast, do_deskew, do_binarize)
            else:
                processed = _preprocess_local(crop, do_contrast, do_deskew, do_binarize)

            results[f"{seq}_{idx}"] = processed

    # Keys are unique per (frame, box), so len(results) == processed count.
    emit.log(job_id, "Preprocess", "INFO",
             f"Preprocessed {len(results)} regions")

    return results
|
||||
|
||||
|
||||
def _preprocess_remote(crop: np.ndarray, inference_url: str,
                       do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
    """Call GPU server /preprocess endpoint."""
    import base64
    import io

    import requests
    from PIL import Image

    # JPEG-encode the crop for transport.
    jpeg_buf = io.BytesIO()
    Image.fromarray(crop).save(jpeg_buf, format="JPEG", quality=85)
    payload = {
        "image": base64.b64encode(jpeg_buf.getvalue()).decode(),
        "contrast": do_contrast,
        "deskew": do_deskew,
        "binarize": do_binarize,
    }

    resp = requests.post(
        f"{inference_url.rstrip('/')}/preprocess",
        json=payload,
        timeout=30,
    )
    resp.raise_for_status()
    body = resp.json()

    # Server returns the processed image, also base64 JPEG.
    decoded = base64.b64decode(body["image"])
    return np.array(Image.open(io.BytesIO(decoded)).convert("RGB"))
|
||||
|
||||
|
||||
def _preprocess_local(crop: np.ndarray,
                      do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
    """Run preprocessing in-process (requires opencv-python-headless).

    Thin wrapper over the shared GPU-side implementation so the local and
    remote modes apply the exact same transforms.
    """
    from core.gpu.models.preprocess import preprocess
    return preprocess(crop, do_binarize=do_binarize, do_deskew=do_deskew, do_contrast=do_contrast)
|
||||
31
core/detect/stages/registry/__init__.py
Normal file
31
core/detect/stages/registry/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Stage registry — registers all built-in stages.
|
||||
|
||||
Split by category:
|
||||
preprocessing.py — extract_frames, filter_scenes
|
||||
cv_analysis.py — detect_edges (+ future: detect_contours, detect_color, merge_regions)
|
||||
detection.py — detect_objects, run_ocr
|
||||
resolution.py — match_brands
|
||||
escalation.py — escalate_vlm, escalate_cloud
|
||||
output.py — compile_report
|
||||
_serializers.py — shared serialization helpers
|
||||
"""
|
||||
|
||||
from . import preprocessing
|
||||
from . import cv_analysis
|
||||
from . import detection
|
||||
from . import resolution
|
||||
from . import escalation
|
||||
from . import output
|
||||
|
||||
|
||||
def register_all():
    """Register every built-in stage category, in pipeline order."""
    for module in (preprocessing, cv_analysis, detection, resolution, escalation, output):
        module.register()


# Import-time registration: importing this package populates the registry.
register_all()
|
||||
25
core/detect/stages/registry/_serializers.py
Normal file
25
core/detect/stages/registry/_serializers.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Re-export serializers from core/schema/serializers/.
|
||||
|
||||
Stage registry modules import from here for convenience.
|
||||
All serialization logic lives in core/schema/serializers/.
|
||||
"""
|
||||
|
||||
from core.schema.serializers._common import (
|
||||
safe_construct,
|
||||
serialize_dataclass,
|
||||
serialize_dataclass_list,
|
||||
)
|
||||
from core.schema.serializers.pipeline import (
|
||||
serialize_frame_meta,
|
||||
serialize_frames_with_upload as serialize_frames,
|
||||
deserialize_frames_with_download as deserialize_frames,
|
||||
serialize_text_candidate,
|
||||
serialize_text_candidates,
|
||||
deserialize_text_candidate,
|
||||
deserialize_text_candidates,
|
||||
deserialize_bounding_box,
|
||||
deserialize_brand_detection,
|
||||
deserialize_pipeline_stats,
|
||||
deserialize_detection_report,
|
||||
)
|
||||
44
core/detect/stages/registry/cv_analysis.py
Normal file
44
core/detect/stages/registry/cv_analysis.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Registration for CV analysis stages: edge detection."""
|
||||
|
||||
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||
from core.detect.stages.base import register_stage
|
||||
from ._serializers import serialize_dataclass_list, deserialize_bounding_box
|
||||
|
||||
|
||||
def _ser_regions(state: dict, job_id: str) -> dict:
|
||||
regions = state.get("edge_regions_by_frame", {})
|
||||
serialized = {
|
||||
str(seq): serialize_dataclass_list(bl) for seq, bl in regions.items()
|
||||
}
|
||||
return {"edge_regions_by_frame": serialized}
|
||||
|
||||
|
||||
def _deser_regions(data: dict, job_id: str) -> dict:
|
||||
regions = {}
|
||||
for seq_str, box_dicts in data.get("edge_regions_by_frame", {}).items():
|
||||
regions[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
|
||||
return {"edge_regions_by_frame": regions}
|
||||
|
||||
|
||||
def register():
    """Register the edge-detection stage with its (de)serializers.

    Config field names/defaults mirror RegionAnalysisConfig in
    core.detect.stages.models.
    """
    edge_detection = StageDefinition(
        name="detect_edges",
        label="Edge Detection",
        description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
        category="cv_analysis",
        io=StageIO(
            reads=["filtered_frames"],
            writes=["edge_regions_by_frame"],
        ),
        config_fields=[
            StageConfigField(name="enabled", type="bool", default=True, description="Enable region analysis"),
            StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
            StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
            StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
            StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
            StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
            StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
            StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
        ],
    )
    register_stage(edge_detection, serialize_fn=_ser_regions, deserialize_fn=_deser_regions)
|
||||
60
core/detect/stages/registry/detection.py
Normal file
60
core/detect/stages/registry/detection.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Registration for detection stages: YOLO, OCR."""
|
||||
|
||||
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||
from core.detect.stages.base import register_stage
|
||||
from ._serializers import (
|
||||
serialize_dataclass_list,
|
||||
serialize_text_candidates,
|
||||
deserialize_bounding_box,
|
||||
)
|
||||
|
||||
|
||||
def _ser_detect(state: dict, job_id: str) -> dict:
|
||||
boxes = state.get("boxes_by_frame", {})
|
||||
serialized = {str(seq): serialize_dataclass_list(bl) for seq, bl in boxes.items()}
|
||||
return {"boxes_by_frame": serialized}
|
||||
|
||||
|
||||
def _deser_detect(data: dict, job_id: str) -> dict:
|
||||
boxes = {}
|
||||
for seq_str, box_dicts in data.get("boxes_by_frame", {}).items():
|
||||
boxes[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
|
||||
return {"boxes_by_frame": boxes}
|
||||
|
||||
|
||||
def _ser_ocr(state: dict, job_id: str) -> dict:
    """Serialize OCR text candidates for checkpointing."""
    found = state.get("text_candidates", [])
    return {"text_candidates": serialize_text_candidates(found)}
|
||||
|
||||
|
||||
def _deser_ocr(data: dict, job_id: str) -> dict:
|
||||
return {"_text_candidates_raw": data["text_candidates"]}
|
||||
|
||||
|
||||
def register():
    """Register the YOLO object-detection and OCR stages.

    Defaults mirror DetectionConfig / OCRConfig in core.detect.stages.models.
    """
    yolo = StageDefinition(
        name="detect_objects",
        label="Object Detection",
        description="YOLO object detection on filtered frames",
        category="detection",
        io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
        config_fields=[
            StageConfigField(name="model_name", type="str", default="yolov8n.pt", description="YOLO model file"),
            StageConfigField(name="confidence_threshold", type="float", default=0.3, description="Min detection confidence", min=0.0, max=1.0),
            StageConfigField(name="target_classes", type="list[str]", default=[], description="YOLO classes to detect (empty = all)"),
        ],
    )
    register_stage(yolo, serialize_fn=_ser_detect, deserialize_fn=_deser_detect)

    ocr = StageDefinition(
        name="run_ocr",
        label="OCR",
        description="Extract text from detected regions",
        category="detection",
        io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
        config_fields=[
            StageConfigField(name="languages", type="list[str]", default=["en"], description="OCR languages"),
            StageConfigField(name="min_confidence", type="float", default=0.5, description="Min OCR confidence", min=0.0, max=1.0),
        ],
    )
    register_stage(ocr, serialize_fn=_ser_ocr, deserialize_fn=_deser_ocr)
|
||||
60
core/detect/stages/registry/escalation.py
Normal file
60
core/detect/stages/registry/escalation.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Registration for escalation stages: local VLM, cloud LLM."""
|
||||
|
||||
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||
from core.detect.stages.base import register_stage
|
||||
from ._serializers import (
|
||||
serialize_dataclass_list,
|
||||
serialize_text_candidates,
|
||||
deserialize_brand_detection,
|
||||
)
|
||||
|
||||
|
||||
def _ser_escalation(state: dict, job_id: str) -> dict:
    """Serialize escalation output: resolved detections + still-unresolved candidates."""
    resolved = state.get("detections", [])
    pending = state.get("unresolved_candidates", [])
    return {
        "detections": serialize_dataclass_list(resolved),
        "unresolved_candidates": serialize_text_candidates(pending),
    }
|
||||
|
||||
|
||||
def _deser_escalation(data: dict, job_id: str) -> dict:
|
||||
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
|
||||
return {
|
||||
"detections": detections,
|
||||
"_unresolved_raw": data.get("unresolved_candidates", []),
|
||||
}
|
||||
|
||||
|
||||
def register():
    """Register the escalation stages (local VLM, then cloud LLM).

    Both share the same serializers; escalate_vlm can emit a reduced
    unresolved_candidates set, escalate_cloud only writes detections.
    """
    vlm = StageDefinition(
        name="escalate_vlm",
        label="Local VLM",
        description="Process unresolved crops with moondream2",
        category="escalation",
        io=StageIO(
            reads=["unresolved_candidates"],
            writes=["detections", "unresolved_candidates"],
            optional_reads=["source_asset_id"],
        ),
        config_fields=[
            StageConfigField(name="min_confidence", type="float", default=0.5, description="Min VLM confidence", min=0.0, max=1.0),
        ],
    )
    register_stage(vlm, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)

    cloud = StageDefinition(
        name="escalate_cloud",
        label="Cloud LLM",
        description="Escalate remaining crops to cloud provider",
        category="escalation",
        io=StageIO(
            reads=["unresolved_candidates"],
            writes=["detections"],
            optional_reads=["source_asset_id"],
        ),
        config_fields=[
            StageConfigField(name="min_confidence", type="float", default=0.4, description="Min cloud confidence", min=0.0, max=1.0),
        ],
    )
    register_stage(cloud, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)
|
||||
30
core/detect/stages/registry/output.py
Normal file
30
core/detect/stages/registry/output.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Registration for output stages: report compilation."""
|
||||
|
||||
from core.detect.stages.base import StageDefinition, StageIO, register_stage
|
||||
from ._serializers import serialize_dataclass, deserialize_detection_report
|
||||
|
||||
|
||||
def _ser_report(state: dict, job_id: str) -> dict:
|
||||
report = state.get("report")
|
||||
if report is None:
|
||||
return {"report": None}
|
||||
return {"report": serialize_dataclass(report)}
|
||||
|
||||
|
||||
def _deser_report(data: dict, job_id: str) -> dict:
|
||||
raw = data.get("report")
|
||||
if raw is None:
|
||||
return {"report": None}
|
||||
return {"report": deserialize_detection_report(raw)}
|
||||
|
||||
|
||||
def register():
    """Register the report-compilation output stage.

    NOTE(review): this module imports StageDefinition/StageIO from
    stages.base, while sibling registry modules import them from
    stages.models — presumably base re-exports them; confirm.
    """
    report = StageDefinition(
        name="compile_report",
        label="Report",
        description="Merge detections and compile final report",
        category="output",
        io=StageIO(reads=["detections"], writes=["report"]),
        config_fields=[],
    )
    register_stage(report, serialize_fn=_ser_report, deserialize_fn=_deser_report)
|
||||
82
core/detect/stages/registry/preprocessing.py
Normal file
82
core/detect/stages/registry/preprocessing.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Registration for preprocessing stages: frame extraction, scene filter, image preprocessing."""
|
||||
|
||||
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||
from core.detect.stages.base import register_stage
|
||||
from ._serializers import serialize_frames, deserialize_frames
|
||||
|
||||
|
||||
def _ser_extract(state: dict, job_id: str) -> dict:
    """Upload frame images and return their metadata plus the storage manifest."""
    extracted = state.get("frames", [])
    meta, manifest = serialize_frames(extracted, job_id)
    return {"frames_meta": meta, "frames_manifest": manifest}
|
||||
|
||||
|
||||
def _deser_extract(data: dict, job_id: str) -> dict:
    """Download frame images listed in the checkpoint manifest."""
    restored = deserialize_frames(data["frames_meta"], data["frames_manifest"], job_id)
    return {"frames": restored}
|
||||
|
||||
|
||||
def _ser_filter(state: dict, job_id: str) -> dict:
|
||||
filtered = state.get("filtered_frames", [])
|
||||
seqs = [f.sequence for f in filtered]
|
||||
return {"filtered_frame_sequences": seqs}
|
||||
|
||||
|
||||
def _deser_filter(data: dict, job_id: str) -> dict:
|
||||
return {"_filtered_sequences": data["filtered_frame_sequences"]}
|
||||
|
||||
|
||||
def _ser_preprocess(state: dict, job_id: str) -> dict:
|
||||
# Preprocessed crops are numpy arrays — regenerable from frames + boxes + config
|
||||
crops = state.get("preprocessed_crops", {})
|
||||
return {"crop_keys": list(crops.keys()), "count": len(crops)}
|
||||
|
||||
|
||||
def _deser_preprocess(data: dict, job_id: str) -> dict:
|
||||
# Crops are regenerable — no need to restore from checkpoint
|
||||
return {"preprocessed_crops": {}}
|
||||
|
||||
|
||||
def register():
    """Register the three preprocessing stages with the stage registry.

    Each stage gets a StageDefinition (UI metadata + declared state reads/writes
    + tunable config fields) plus checkpoint (de)serializers so a run can be
    resumed or replayed from after this stage.
    """
    # Stage 1 — frame extraction: reads the video path, writes raw frames.
    extract = StageDefinition(
        name="extract_frames",
        label="Frame Extraction",
        description="Extract frames from video at configurable FPS",
        category="preprocessing",
        io=StageIO(reads=["video_path"], writes=["frames"]),
        config_fields=[
            StageConfigField(name="fps", type="float", default=2.0, description="Frames per second", min=0.1, max=30.0),
            StageConfigField(name="max_frames", type="int", default=500, description="Maximum frames to extract", min=1, max=10000),
        ],
    )
    register_stage(extract, serialize_fn=_ser_extract, deserialize_fn=_deser_extract)

    # Stage 2 — scene filter: pHash-based near-duplicate removal.
    scene_filter = StageDefinition(
        name="filter_scenes",
        label="Scene Filter",
        description="Deduplicate similar frames using perceptual hashing",
        category="preprocessing",
        io=StageIO(reads=["frames"], writes=["filtered_frames"]),
        config_fields=[
            StageConfigField(name="hamming_threshold", type="int", default=8, description="Hamming distance threshold", min=0, max=64),
            StageConfigField(name="enabled", type="bool", default=True, description="Enable scene filtering"),
        ],
    )
    register_stage(scene_filter, serialize_fn=_ser_filter, deserialize_fn=_deser_filter)

    # Stage — image preprocessing on detected regions, applied before OCR.
    preprocess = StageDefinition(
        name="preprocess",
        label="Preprocess",
        description="Image preprocessing on detected regions before OCR",
        category="preprocessing",
        io=StageIO(
            reads=["filtered_frames", "boxes_by_frame"],
            writes=["preprocessed_crops"],
        ),
        config_fields=[
            StageConfigField(name="contrast", type="bool", default=True, description="CLAHE contrast enhancement"),
            StageConfigField(name="deskew", type="bool", default=False, description="Correct slight rotation"),
            StageConfigField(name="binarize", type="bool", default=False, description="Otsu binarization"),
        ],
    )
    register_stage(preprocess, serialize_fn=_ser_preprocess, deserialize_fn=_deser_preprocess)
|
||||
44
core/detect/stages/registry/resolution.py
Normal file
44
core/detect/stages/registry/resolution.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Registration for resolution stages: brand resolver."""
|
||||
|
||||
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||
from core.detect.stages.base import register_stage
|
||||
from ._serializers import (
|
||||
serialize_dataclass_list,
|
||||
serialize_text_candidates,
|
||||
deserialize_brand_detection,
|
||||
)
|
||||
|
||||
|
||||
def _ser_brands(state: dict, job_id: str) -> dict:
    """Checkpoint serializer for match_brands: detections plus still-unresolved candidates."""
    return {
        "detections": serialize_dataclass_list(state.get("detections", [])),
        "unresolved_candidates": serialize_text_candidates(
            state.get("unresolved_candidates", [])
        ),
    }
|
||||
|
||||
|
||||
def _deser_brands(data: dict, job_id: str) -> dict:
    """Checkpoint deserializer for match_brands.

    Unresolved candidates reference frames, so they stay raw here
    (under "_unresolved_raw") until frames are rehydrated elsewhere.
    """
    restored = list(map(deserialize_brand_detection, data.get("detections", [])))
    return {
        "detections": restored,
        "_unresolved_raw": data.get("unresolved_candidates", []),
    }
|
||||
|
||||
|
||||
def register():
    """Register the brand-resolver stage with its checkpoint (de)serializers."""
    resolver = StageDefinition(
        name="match_brands",
        label="Brand Resolver",
        description="Match OCR text against known brands (session + global DB)",
        category="resolution",
        io=StageIO(
            reads=["text_candidates"],
            writes=["detections", "unresolved_candidates"],
            # session_brands / source_asset_id are optional: resolution works
            # without them, they only enable session-level matching + DB sightings.
            optional_reads=["session_brands", "source_asset_id"],
        ),
        config_fields=[
            # Minimum fuzzy-match score (0-100) for an OCR string to count as a brand.
            StageConfigField(name="fuzzy_threshold", type="int", default=75, description="Fuzzy match threshold", min=0, max=100),
        ],
    )
    register_stage(resolver, serialize_fn=_ser_brands, deserialize_fn=_deser_brands)
|
||||
86
core/detect/stages/scene_filter.py
Normal file
86
core/detect/stages/scene_filter.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Stage 2 — Scene Filter
|
||||
|
||||
Removes near-duplicate frames using perceptual hashing (pHash).
|
||||
Frames with a hamming distance below the threshold are considered
|
||||
duplicates and dropped. This dramatically reduces work for downstream
|
||||
CV stages without losing unique visual content.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import imagehash
|
||||
from PIL import Image
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import Frame
|
||||
from core.detect.stages.models import SceneFilterConfig
|
||||
|
||||
|
||||
def _compute_hashes(frames: list[Frame]) -> list[imagehash.ImageHash]:
    """Compute a pHash per frame; also stores the hex digest on each Frame."""
    result: list[imagehash.ImageHash] = []
    for frame in frames:
        phash = imagehash.phash(Image.fromarray(frame.image))
        # Side effect: persist the digest on the frame for later inspection.
        frame.perceptual_hash = str(phash)
        result.append(phash)
    return result
|
||||
|
||||
|
||||
def _dedup(frames: list[Frame], hashes: list[imagehash.ImageHash], threshold: int) -> list[Frame]:
|
||||
"""Greedy dedup: keep a frame if it's sufficiently different from all kept frames."""
|
||||
kept = [frames[0]]
|
||||
kept_hashes = [hashes[0]]
|
||||
|
||||
for i in range(1, len(frames)):
|
||||
is_duplicate = any(hashes[i] - kh < threshold for kh in kept_hashes)
|
||||
if not is_duplicate:
|
||||
kept.append(frames[i])
|
||||
kept_hashes.append(hashes[i])
|
||||
|
||||
return kept
|
||||
|
||||
|
||||
def scene_filter(
    frames: list[Frame],
    config: SceneFilterConfig,
    job_id: str | None = None,
) -> list[Frame]:
    """
    Filter near-duplicate frames based on perceptual hash distance.

    Keeps the first frame in each group of similar frames.
    Returns a new list — does not mutate the input.

    Args:
        frames: frames to filter, in temporal order.
        config: enabled flag + hamming_threshold for duplicate detection.
        job_id: pipeline job id for SSE log/stat events; None disables per-job emits.

    Returns:
        Kept frames, in original order. Note: _compute_hashes stores each
        frame's pHash digest on frame.perceptual_hash as a side effect.
    """
    # Disabled → pass-through (the input list itself, per the config contract).
    if not config.enabled:
        emit.log(job_id, "SceneFilter", "INFO", "Scene filter disabled, passing all frames through")
        return frames

    if not frames:
        return []

    emit.log(job_id, "SceneFilter", "INFO",
             f"Filtering {len(frames)} frames (hamming_threshold={config.hamming_threshold})")

    # Phase 1: hash every frame (timed separately for per-frame cost visibility).
    t0 = time.monotonic()
    hashes = _compute_hashes(frames)
    hash_ms = (time.monotonic() - t0) * 1000
    emit.log(job_id, "SceneFilter", "DEBUG",
             f"Computed {len(hashes)} perceptual hashes in {hash_ms:.0f}ms ({hash_ms/max(len(hashes),1):.1f}ms/frame)")

    # Phase 2: greedy dedup against all previously-kept frames.
    t0 = time.monotonic()
    kept = _dedup(frames, hashes, config.hamming_threshold)
    dedup_ms = (time.monotonic() - t0) * 1000
    emit.log(job_id, "SceneFilter", "DEBUG", f"Dedup pass: {dedup_ms:.0f}ms")

    dropped = len(frames) - len(kept)
    pct = (dropped / len(frames) * 100) if frames else 0

    emit.log(job_id, "SceneFilter", "INFO",
             f"Kept {len(kept)} frames, dropped {dropped} ({pct:.0f}% reduction)")
    emit.stats(job_id, frames_extracted=len(frames), frames_after_scene_filter=len(kept))

    return kept
|
||||
201
core/detect/stages/vlm_cloud.py
Normal file
201
core/detect/stages/vlm_cloud.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Stage 7 — Cloud LLM escalation
|
||||
|
||||
Last resort for crops the local VLM couldn't resolve.
|
||||
Provider-agnostic — switch via CLOUD_LLM_PROVIDER env var.
|
||||
Each provider has its own file under detect/providers/.
|
||||
|
||||
Tracks token usage and cost.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import BrandDetection, PipelineStats, TextCandidate
|
||||
from core.detect.models import CropContext
|
||||
from core.detect.providers import get_provider, has_api_key
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ESTIMATED_TOKENS_PER_CROP = 500
|
||||
|
||||
|
||||
def _register_discovered_brand(brand: str, source_asset_id: str | None,
                               timestamp: float, confidence: float):
    """Best-effort: persist a cloud-confirmed brand (and a sighting) in the DB."""
    try:
        from core.detect.stages.brand_resolver import _record_sighting, _register_brand

        new_brand_id = _register_brand(brand, "cloud_llm")
        if new_brand_id and source_asset_id:
            _record_sighting(source_asset_id, new_brand_id, brand,
                             timestamp, confidence, "cloud_llm")
    except Exception as exc:
        # DB registration must never break the pipeline — log and move on.
        logger.debug("Failed to register brand %s: %s", brand, exc)
|
||||
|
||||
|
||||
def _encode_crop(crop: np.ndarray) -> str:
    """JPEG-encode a crop (quality 85) and return it as a base64 string."""
    buffer = io.BytesIO()
    Image.fromarray(crop).save(buffer, format="JPEG", quality=85)
    return base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
|
||||
def _crop_image(candidate: TextCandidate) -> np.ndarray:
|
||||
frame = candidate.frame
|
||||
box = candidate.bbox
|
||||
h, w = frame.image.shape[:2]
|
||||
x1 = max(0, box.x)
|
||||
y1 = max(0, box.y)
|
||||
x2 = min(w, box.x + box.w)
|
||||
y2 = min(h, box.y + box.h)
|
||||
return frame.image[y1:y2, x1:x2]
|
||||
|
||||
|
||||
def _parse_response(answer: str, total_tokens: int) -> dict:
|
||||
"""Parse LLM free-text response into structured output."""
|
||||
parts = [p.strip() for p in answer.split(",", 2)]
|
||||
|
||||
brand = parts[0] if parts else ""
|
||||
confidence = 0.5
|
||||
reasoning = answer
|
||||
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
confidence = float(parts[1])
|
||||
confidence = max(0.0, min(1.0, confidence))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if len(parts) >= 3:
|
||||
reasoning = parts[2]
|
||||
|
||||
return {
|
||||
"brand": brand,
|
||||
"confidence": confidence,
|
||||
"reasoning": reasoning,
|
||||
"tokens": total_tokens or ESTIMATED_TOKENS_PER_CROP,
|
||||
}
|
||||
|
||||
|
||||
def _call_cloud_api(image_b64: str, prompt: str) -> dict:
    """Route to the configured provider and parse the response.

    Provider selection happens inside get_provider() (env-driven). Returns the
    dict produced by _parse_response: {"brand", "confidence", "reasoning", "tokens"}.
    Exceptions from the provider call propagate to the caller.
    """
    provider = get_provider()
    result = provider.call(image_b64, prompt)
    return _parse_response(result.answer, result.total_tokens)
|
||||
|
||||
|
||||
def escalate_cloud(
    candidates: list[TextCandidate],
    vlm_prompt_fn,
    stats: PipelineStats,
    min_confidence: float = 0.4,
    content_type: str = "",
    source_asset_id: str | None = None,
    job_id: str | None = None,
) -> list[BrandDetection]:
    """
    Send remaining unresolved crops to cloud LLM.

    Provider is selected via CLOUD_LLM_PROVIDER env var (groq, gemini, openai).
    Updates stats with call count and cost.

    Args:
        candidates: crops the local VLM couldn't resolve.
        vlm_prompt_fn: callable(CropContext) -> str prompt for the model.
        stats: mutated in place (cloud_llm_calls, cost, escalation count).
        min_confidence: detections below this confidence are discarded.
        content_type: passed through onto each BrandDetection.
        source_asset_id: MediaAsset UUID for DB sighting records (optional).
        job_id: pipeline job id for SSE log/detection events.

    Returns:
        BrandDetections the cloud model confirmed at or above min_confidence.
    """
    if not candidates:
        return []

    # Escape hatch for offline runs / cost control.
    if os.environ.get("SKIP_CLOUD", "").strip() == "1":
        emit.log(job_id, "CloudLLM", "INFO",
                 f"SKIP_CLOUD=1, skipping {len(candidates)} crops")
        return []

    if not has_api_key():
        emit.log(job_id, "CloudLLM", "WARNING",
                 f"No API key set for cloud provider, skipping {len(candidates)} crops")
        return []

    provider = get_provider()
    emit.log(job_id, "CloudLLM", "INFO",
             f"Escalating {len(candidates)} crops to {provider.name}")

    matched: list[BrandDetection] = []
    total_cost = 0.0

    for i, candidate in enumerate(candidates):
        crop = _crop_image(candidate)
        if crop.size == 0:
            # Degenerate bbox (entirely outside the frame) — nothing to send.
            continue

        crop_context = CropContext(
            image=b"",  # image travels separately as base64; context is text-only
            surrounding_text=candidate.text,
            position_hint=f"frame {candidate.frame.sequence}",
        )
        prompt = vlm_prompt_fn(crop_context)
        image_b64 = _encode_crop(crop)

        t0 = time.monotonic()
        try:
            result = _call_cloud_api(image_b64, prompt)
        except Exception as e:
            # API failures are non-fatal: skip this crop and keep going.
            call_ms = (time.monotonic() - t0) * 1000
            emit.log(job_id, "CloudLLM", "DEBUG",
                     f"[{i+1}/{len(candidates)}] FAILED '{candidate.text[:30]}': {e} ({call_ms:.0f}ms)")
            continue
        call_ms = (time.monotonic() - t0) * 1000

        stats.cloud_llm_calls += 1
        model_info = provider.models.get(provider.model)
        # Fallback cost when the model isn't in the provider's price table.
        cost_per_token = model_info.cost_per_input_token if model_info else 0.00001
        call_cost = result["tokens"] * cost_per_token
        total_cost += call_cost

        brand = result["brand"]
        confidence = result["confidence"]

        emit.log(job_id, "CloudLLM", "DEBUG",
                 f"[{i+1}/{len(candidates)}] '{candidate.text[:30]}' → "
                 f"{'✓ ' + brand if brand else '✗'} "
                 f"(conf={confidence:.2f}, {result['tokens']}tok, ${call_cost:.4f}, {call_ms:.0f}ms)")

        if brand and confidence >= min_confidence:
            detection = BrandDetection(
                brand=brand,
                timestamp=candidate.frame.timestamp,
                duration=0.5,
                confidence=confidence,
                source="cloud_llm",
                bbox=candidate.bbox,
                frame_ref=candidate.frame.sequence,
                content_type=content_type,
            )
            matched.append(detection)

            emit.detection(
                job_id,
                brand=brand,
                confidence=confidence,
                source="cloud_llm",
                timestamp=candidate.frame.timestamp,
                content_type=content_type,
                frame_ref=candidate.frame.sequence,
            )

            # Register newly discovered brand in DB
            _register_discovered_brand(brand, source_asset_id,
                                       candidate.frame.timestamp, confidence)

    stats.estimated_cloud_cost_usd += total_cost
    stats.regions_escalated_to_cloud_llm = len(candidates)

    emit.log(job_id, "CloudLLM", "INFO",
             f"Cloud resolved {len(matched)}/{len(candidates)} — "
             f"cost ${total_cost:.4f} ({stats.cloud_llm_calls} calls total)")

    return matched
|
||||
157
core/detect/stages/vlm_local.py
Normal file
157
core/detect/stages/vlm_local.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Stage 6 — Local VLM escalation (moondream2)
|
||||
|
||||
Processes unresolved text candidates by sending crop images + prompt
|
||||
to the local VLM on the inference server. Produces BrandDetection
|
||||
objects for crops the VLM can identify.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import BrandDetection, TextCandidate
|
||||
from core.detect.models import CropContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _register_discovered_brand(brand: str, source_asset_id: str | None,
                               timestamp: float, confidence: float, source: str):
    """Best-effort: persist a VLM-confirmed brand (and a sighting) in the DB."""
    try:
        from core.detect.stages.brand_resolver import _record_sighting, _register_brand

        new_brand_id = _register_brand(brand, source)
        if new_brand_id and source_asset_id:
            _record_sighting(source_asset_id, new_brand_id, brand,
                             timestamp, confidence, source)
    except Exception as exc:
        # DB registration must never break the pipeline — log and move on.
        logger.debug("Failed to register brand %s: %s", brand, exc)
|
||||
|
||||
|
||||
def _crop_image(candidate: TextCandidate) -> np.ndarray:
|
||||
frame = candidate.frame
|
||||
box = candidate.bbox
|
||||
h, w = frame.image.shape[:2]
|
||||
x1 = max(0, box.x)
|
||||
y1 = max(0, box.y)
|
||||
x2 = min(w, box.x + box.w)
|
||||
y2 = min(h, box.y + box.h)
|
||||
return frame.image[y1:y2, x1:x2]
|
||||
|
||||
|
||||
def escalate_vlm(
    candidates: list[TextCandidate],
    vlm_prompt_fn,
    inference_url: str | None = None,
    min_confidence: float = 0.5,
    content_type: str = "",
    source_asset_id: str | None = None,
    job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
    """
    Send unresolved crops to local VLM for brand identification.

    Args:
        candidates: unresolved text candidates (crop bbox + frame + OCR text).
        vlm_prompt_fn: callable(CropContext) -> str prompt for the VLM.
        inference_url: remote inference server base URL; None runs moondream2 in-process.
        min_confidence: detections below this confidence stay unresolved.
        content_type: passed through onto each BrandDetection.
        source_asset_id: MediaAsset UUID for DB sighting records (optional).
        job_id: pipeline job id for SSE log/detection events.

    Returns:
        - matched: BrandDetections the VLM confirmed
        - still_unresolved: candidates the VLM couldn't resolve (→ cloud escalation)
    """
    if not candidates:
        return [], []

    # Escape hatch for offline runs / tests.
    if os.environ.get("SKIP_VLM", "").strip() == "1":
        emit.log(job_id, "VLMLocal", "INFO",
                 f"SKIP_VLM=1, skipping {len(candidates)} crops")
        return [], candidates

    emit.log(job_id, "VLMLocal", "INFO",
             f"Processing {len(candidates)} unresolved crops with moondream2")

    matched: list[BrandDetection] = []
    still_unresolved: list[TextCandidate] = []

    # One client instance is reused for all crops in remote mode.
    if inference_url:
        from core.detect.inference import InferenceClient
        from core.detect.emit import _run_log_level
        client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)

    for i, candidate in enumerate(candidates):
        crop = _crop_image(candidate)
        if crop.size == 0:
            # Degenerate bbox (entirely outside the frame) — cannot query the VLM.
            still_unresolved.append(candidate)
            continue

        crop_context = CropContext(
            image=b"",  # not used for prompt generation
            surrounding_text=candidate.text,
            position_hint=f"frame {candidate.frame.sequence}",
        )
        prompt = vlm_prompt_fn(crop_context)

        t0 = time.monotonic()
        try:
            if inference_url:
                result = client.vlm(image=crop, prompt=prompt)
                brand = result.brand
                confidence = result.confidence
                reasoning = result.reasoning
            else:
                brand, confidence, reasoning = _vlm_local(crop, prompt)
        except Exception as e:
            # VLM failures are non-fatal: the candidate falls through to cloud.
            vlm_ms = (time.monotonic() - t0) * 1000
            emit.log(job_id, "VLMLocal", "DEBUG",
                     f"[{i+1}/{len(candidates)}] FAILED '{candidate.text[:30]}': {e} ({vlm_ms:.0f}ms)")
            still_unresolved.append(candidate)
            continue
        vlm_ms = (time.monotonic() - t0) * 1000
        emit.log(job_id, "VLMLocal", "DEBUG",
                 f"[{i+1}/{len(candidates)}] '{candidate.text[:30]}' → "
                 f"{'✓ ' + brand if brand else '✗ unresolved'} "
                 f"(conf={confidence:.2f}, {vlm_ms:.0f}ms)")

        if brand and confidence >= min_confidence:
            detection = BrandDetection(
                brand=brand,
                timestamp=candidate.frame.timestamp,
                duration=0.5,
                confidence=confidence,
                source="local_vlm",
                bbox=candidate.bbox,
                frame_ref=candidate.frame.sequence,
                content_type=content_type,
            )
            matched.append(detection)

            emit.detection(
                job_id,
                brand=brand,
                confidence=confidence,
                source="local_vlm",
                timestamp=candidate.frame.timestamp,
                content_type=content_type,
                frame_ref=candidate.frame.sequence,
            )

            # Register newly discovered brand in DB
            _register_discovered_brand(brand, source_asset_id,
                                       candidate.frame.timestamp, confidence, "local_vlm")

            logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning)
        else:
            still_unresolved.append(candidate)

    emit.log(job_id, "VLMLocal", "INFO",
             f"VLM resolved {len(matched)}, unresolved {len(still_unresolved)} → cloud")

    return matched, still_unresolved
|
||||
|
||||
|
||||
def _vlm_local(crop: np.ndarray, prompt: str) -> tuple[str, float, str]:
    """Run moondream2 in-process (single-box mode)."""
    # Imported lazily so remote-only deployments need no GPU model code.
    from core.gpu.models.vlm import query
    answer = query(crop, prompt)
    return answer["brand"], answer["confidence"], answer["reasoning"]
|
||||
138
core/detect/stages/yolo_detector.py
Normal file
138
core/detect/stages/yolo_detector.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
Stage 3 — YOLO Object Detection
|
||||
|
||||
Detects regions of interest (logos, text, banners) in frames.
|
||||
Two modes:
|
||||
- Remote: calls inference server over HTTP (GPU on another machine)
|
||||
- Local: imports ultralytics directly (GPU on same machine)
|
||||
|
||||
Emits frame_update events with bounding boxes for the UI.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import BoundingBox, Frame
|
||||
from core.detect.stages.models import DetectionConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _frame_to_b64(frame: Frame) -> str:
    """Encode frame as base64 JPEG for SSE frame_update events."""
    buffer = io.BytesIO()
    # Quality 70: preview-grade is enough for the UI overlay stream.
    Image.fromarray(frame.image).save(buffer, format="JPEG", quality=70)
    return base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
|
||||
def _detect_remote(frame: Frame, config: DetectionConfig, inference_url: str,
                   job_id: str = "", log_level: str = "INFO") -> list[BoundingBox]:
    """Call the inference server over HTTP and convert results to BoundingBoxes."""
    from core.detect.inference import InferenceClient

    client = InferenceClient(base_url=inference_url, job_id=job_id, log_level=log_level)
    detections = client.detect(
        image=frame.image,
        model=config.model_name,
        confidence=config.confidence_threshold,
        target_classes=config.target_classes,
    )
    return [
        BoundingBox(x=d.x, y=d.y, w=d.w, h=d.h,
                    confidence=d.confidence, label=d.label)
        for d in detections
    ]
|
||||
|
||||
|
||||
def _detect_local(frame: Frame, config: DetectionConfig) -> list[BoundingBox]:
    """Run YOLO in-process (requires ultralytics installed).

    The model is loaded once per model name and cached on the function —
    the original re-instantiated YOLO(config.model_name) on every call,
    which reloads the weights for every frame in the caller's loop.

    Args:
        frame: frame whose .image (numpy array) is scanned.
        config: model name, confidence threshold, and optional class whitelist.

    Returns:
        One BoundingBox per detection surviving the class filter.
    """
    cache = getattr(_detect_local, "_model_cache", None)
    if cache is None:
        cache = {}
        _detect_local._model_cache = cache
    model = cache.get(config.model_name)
    if model is None:
        from ultralytics import YOLO
        model = YOLO(config.model_name)
        cache[config.model_name] = model

    results = model(frame.image, conf=config.confidence_threshold, verbose=False)

    boxes = []
    for r in results:
        for det in r.boxes:
            x1, y1, x2, y2 = det.xyxy[0].tolist()
            label = r.names[int(det.cls[0])]

            # Respect the configured class whitelist (empty = keep everything).
            if config.target_classes and label not in config.target_classes:
                continue

            boxes.append(BoundingBox(
                x=int(x1), y=int(y1),
                w=int(x2 - x1), h=int(y2 - y1),
                confidence=float(det.conf[0]),
                label=label,
            ))
    return boxes
|
||||
|
||||
|
||||
def detect_objects(
    frames: list[Frame],
    config: DetectionConfig,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> dict[int, list[BoundingBox]]:
    """
    Run object detection on all frames.

    If inference_url is provided, calls the remote GPU server.
    Otherwise, imports ultralytics and runs locally.

    Args:
        frames: frames to scan (each carries .sequence, .timestamp, .image).
        config: model name, confidence threshold, target classes.
        inference_url: base URL of the remote inference server, or None for local.
        job_id: pipeline job id for SSE log/stat/frame_update events.

    Returns a dict mapping frame sequence → list of bounding boxes.
    """
    mode = "remote" if inference_url else "local"
    emit.log(job_id, "YOLODetector", "INFO",
             f"Detecting objects in {len(frames)} frames "
             f"(model={config.model_name}, conf={config.confidence_threshold}, mode={mode})")

    all_boxes: dict[int, list[BoundingBox]] = {}
    total_regions = 0

    for frame in frames:
        t0 = time.monotonic()
        if inference_url:
            # Read the per-run log level each iteration in case it changes mid-run.
            from core.detect.emit import _run_log_level
            boxes = _detect_remote(frame, config, inference_url,
                                   job_id=job_id or "", log_level=_run_log_level)
        else:
            boxes = _detect_local(frame, config)
        det_ms = (time.monotonic() - t0) * 1000

        all_boxes[frame.sequence] = boxes
        total_regions += len(boxes)

        # Build the message explicitly. The original passed a bare conditional
        # expression with implicit f-string concatenation as the log argument —
        # correct, but very easy to misread (and to break when edited).
        msg = f"Frame {frame.sequence}: {len(boxes)} regions in {det_ms:.0f}ms"
        if boxes:
            msg += f" [{', '.join(b.label for b in boxes)}]"
        emit.log(job_id, "YOLODetector", "DEBUG", msg)

        if boxes and job_id:
            box_dicts = [{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
                          "confidence": b.confidence, "label": b.label}
                         for b in boxes]
            emit.frame_update(
                job_id,
                frame_ref=frame.sequence,
                timestamp=frame.timestamp,
                jpeg_b64=_frame_to_b64(frame),
                boxes=box_dicts,
            )

    emit.log(job_id, "YOLODetector", "INFO",
             f"Detected {total_regions} regions across {len(frames)} frames")
    emit.stats(job_id, regions_detected=total_regions)

    return all_boxes
|
||||
43
core/detect/state.py
Normal file
43
core/detect/state.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
LangGraph state definition for the detection pipeline.
|
||||
|
||||
This TypedDict flows through all graph nodes. Each node reads what
|
||||
it needs and writes its outputs. LangGraph manages the state transitions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
from core.detect.models import BoundingBox, BrandDetection, DetectionReport, Frame, PipelineStats, TextCandidate
|
||||
|
||||
|
||||
class DetectState(TypedDict, total=False):
    """State dict flowing through all detection-graph nodes.

    total=False: every key is optional — each node writes only its outputs.
    """
    # Input
    video_path: str  # local path to the source video
    job_id: str  # pipeline run identifier, used for SSE/event routing
    profile_name: str  # name of the detection profile driving stage configs
    source_asset_id: str  # UUID of the source MediaAsset

    # Stage outputs
    frames: list[Frame]  # raw frames from extract_frames
    filtered_frames: list[Frame]  # frames surviving the scene filter
    field_masks: dict  # {seq: np.ndarray} — pitch mask per frame
    field_boundaries: dict  # {seq: [(x,y), ...]} — pitch boundary per frame
    field_coverage: dict  # {seq: float} — pitch coverage ratio per frame
    edge_regions_by_frame: dict[int, list[BoundingBox]]  # presumably field-edge regions — confirm with producing stage
    boxes_by_frame: dict[int, list[BoundingBox]]  # YOLO boxes keyed by frame sequence
    preprocessed_crops: dict  # "{frame_seq}_{box_idx}" → np.ndarray
    text_candidates: list[TextCandidate]  # candidates awaiting brand resolution
    unresolved_candidates: list[TextCandidate]  # left over after resolution → VLM/cloud escalation
    detections: list[BrandDetection]  # confirmed brand detections
    report: DetectionReport  # final compiled report

    # Session brands (accumulated during the run, persisted to DB)
    session_brands: dict  # {normalized_name: canonical_name}

    # Running stats (updated by each stage)
    stats: PipelineStats

    # Config overrides for replay (merged into profile configs dict)
    config_overrides: dict
|
||||
131
core/detect/tracing.py
Normal file
131
core/detect/tracing.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Langfuse tracing for the detection pipeline.
|
||||
|
||||
Provides span helpers that graph nodes use to record timing, frame counts,
|
||||
and stage-level metadata. The Langfuse client is optional — if not configured
|
||||
(no LANGFUSE_SECRET_KEY), tracing is a no-op.
|
||||
|
||||
Usage in graph nodes:
|
||||
from core.detect.tracing import trace_node
|
||||
|
||||
def node_extract_frames(state):
|
||||
with trace_node(state, "extract_frames") as span:
|
||||
...
|
||||
span.set_output({"frames": len(frames)})
|
||||
return {...}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_client = None
|
||||
_enabled: bool | None = None
|
||||
|
||||
|
||||
def _get_client():
    """Lazy-init Langfuse client. Returns None if not configured.

    Caches the outcome in module globals: _enabled=False pins "permanently
    disabled" (missing key or failed init), _client holds the singleton once
    initialization succeeds.
    """
    global _client, _enabled
    # Explicit `is False` check: None means "not decided yet", False means "off".
    if _enabled is False:
        return None
    if _client is not None:
        return _client

    secret = os.environ.get("LANGFUSE_SECRET_KEY", "")
    if not secret:
        _enabled = False
        logger.info("Langfuse not configured (no LANGFUSE_SECRET_KEY), tracing disabled")
        return None

    try:
        # Lazy import: langfuse is an optional dependency.
        from langfuse import Langfuse
        _client = Langfuse()
        _enabled = True
        logger.info("Langfuse tracing enabled")
        return _client
    except Exception as e:
        _enabled = False
        logger.warning("Langfuse init failed: %s — tracing disabled", e)
        return None
|
||||
|
||||
|
||||
@dataclass
class SpanContext:
    """Wraps a Langfuse span with convenience methods.

    Works even with no span attached (_span=None): timing and metadata are
    still collected locally, the Langfuse update is simply skipped.
    """
    _span: object | None = None  # underlying Langfuse span, or None when tracing is off
    _start: float = field(default_factory=time.monotonic)  # span start time
    metadata: dict = field(default_factory=dict)  # accumulated output/metadata

    def set_output(self, output: dict) -> None:
        """Merge node output fields into the span metadata."""
        self.metadata.update(output)

    def set_error(self, error: str) -> None:
        """Record an error string in the span metadata."""
        self.metadata["error"] = error

    def _finish(self, status: str = "ok") -> None:
        """Stamp duration/status and close the underlying span (best-effort)."""
        elapsed = time.monotonic() - self._start
        self.metadata["duration_seconds"] = round(elapsed, 3)
        self.metadata["status"] = status

        if self._span is not None:
            try:
                self._span.update(
                    output=self.metadata,
                    level="ERROR" if status == "error" else "DEFAULT",
                )
                self._span.end()
            except Exception as e:
                # Tracing must never break the pipeline.
                logger.debug("Failed to end Langfuse span: %s", e)
|
||||
|
||||
|
||||
@contextmanager
def trace_node(state: dict, node_name: str):
    """
    Context manager that creates a Langfuse span for a pipeline node.

    Tracing is best-effort: if the client is unavailable or a Langfuse call
    fails, the node still runs and timing is kept locally in SpanContext.
    Exceptions from the node body are marked on the span and re-raised.

    Usage:
        with trace_node(state, "extract_frames") as span:
            frames = do_work()
            span.set_output({"frames": len(frames)})
    """
    job_id = state.get("job_id", "unknown")
    profile = state.get("profile_name", "")
    client = _get_client()

    span_obj = None
    if client is not None:
        try:
            # NOTE(review): this creates a NEW trace per node (all sharing
            # session_id=job_id) rather than one trace for the whole run —
            # confirm that's the intended Langfuse grouping.
            trace = client.trace(
                name=f"detect:{job_id}",
                session_id=job_id,
                metadata={"profile": profile},
            )
            span_obj = trace.span(
                name=node_name,
                input={"job_id": job_id, "profile": profile},
            )
        except Exception as e:
            logger.debug("Failed to create Langfuse span: %s", e)

    ctx = SpanContext(_span=span_obj)
    try:
        yield ctx
        ctx._finish("ok")
    except Exception:
        # Mark the span failed, then let the error propagate to the caller.
        ctx._finish("error")
        raise
|
||||
|
||||
|
||||
def flush():
    """Flush pending Langfuse events. Call at pipeline end."""
    if _client is None:
        return
    try:
        _client.flush()
    except Exception as exc:
        # Best-effort: a failed flush should never fail the pipeline.
        logger.debug("Langfuse flush failed: %s", exc)
|
||||
Reference in New Issue
Block a user