277 lines
9.2 KiB
Python
277 lines
9.2 KiB
Python
"""
|
|
Pipeline replay — re-run from any stage with different config.
|
|
|
|
Loads a checkpoint, applies config overrides, builds a subgraph
|
|
starting from the target stage, and invokes it.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
import uuid
|
|
|
|
from detect import emit
|
|
# TODO: migrate to Timeline/Branch/Checkpoint model
|
|
# These old functions no longer exist — replay needs rework
|
|
def _not_migrated(*args, **kwargs):
|
|
raise NotImplementedError("Replay not yet migrated to Timeline/Branch/Checkpoint model")
|
|
|
|
load_checkpoint = _not_migrated
|
|
list_checkpoints = _not_migrated
|
|
from detect.graph import NODES, build_graph
|
|
|
|
# Module-level logger; the replay entry points report what they load at INFO.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OverrideProfile:
    """
    Wraps a ContentTypeProfile and patches config methods with overrides.

    Each ``*_config()`` method asks the wrapped profile for its config object
    and applies the overrides stored under the matching section key. Override
    keys the config object has no attribute for are skipped (with a warning).
    Every other attribute lookup is delegated to the wrapped profile.

    Override dict structure:
        {
            "frame_extraction": {"fps": 1.0},
            "scene_filter": {"hamming_threshold": 12},
            "region_analysis": {"edge_canny_low": 30, "edge_canny_high": 120},
            "detection": {"confidence_threshold": 0.5},
            "ocr": {"languages": ["en", "es"], "min_confidence": 0.3},
            "resolver": {"fuzzy_threshold": 60},
        }
    """

    def __init__(self, base, overrides: dict):
        """
        Args:
            base: The profile to wrap; unknown attributes are looked up on it.
            overrides: Mapping of section key -> {config attribute: value}.
                ``None`` is treated as "no overrides".
        """
        self._base = base
        self._overrides = overrides if overrides is not None else {}

    def __getattr__(self, name):
        # Only called when normal lookup fails. The wrapper's own attributes
        # must never be delegated: on an instance created via __new__ (copy,
        # pickle) ``_base`` is not set yet, and ``self._base`` would re-enter
        # __getattr__("_base") forever, ending in RecursionError.
        if name in ("_base", "_overrides"):
            raise AttributeError(name)
        return getattr(self._base, name)

    def _patch(self, config, key: str):
        """
        Return ``config`` with the overrides for section ``key`` applied.

        When there are overrides, the config is shallow-copied before patching
        so the values never leak into an object the base profile may hand out
        again (whether base configs are fresh per call is not visible here).
        Without overrides for ``key`` the base config is returned untouched.
        """
        # ``or {}`` also covers an explicit ``None`` section value.
        patches = self._overrides.get(key) or {}
        if not patches:
            return config

        import copy

        patched = copy.copy(config)
        for attr, value in patches.items():
            if hasattr(patched, attr):
                setattr(patched, attr, value)
            else:
                # Skipped, but surfaced: a typo'd key would otherwise make a
                # replay silently run with the unmodified value.
                logging.getLogger(__name__).warning(
                    "Ignoring unknown %s override %r (not an attribute of %s)",
                    key, attr, type(config).__name__,
                )
        return patched

    def frame_extraction_config(self):
        return self._patch(self._base.frame_extraction_config(), "frame_extraction")

    def scene_filter_config(self):
        return self._patch(self._base.scene_filter_config(), "scene_filter")

    def region_analysis_config(self):
        return self._patch(self._base.region_analysis_config(), "region_analysis")

    def detection_config(self):
        return self._patch(self._base.detection_config(), "detection")

    def ocr_config(self):
        return self._patch(self._base.ocr_config(), "ocr")

    def resolver_config(self):
        return self._patch(self._base.resolver_config(), "resolver")

    # Explicit pass-throughs (no overrides apply to these).

    def vlm_prompt(self, crop_context):
        return self._base.vlm_prompt(crop_context)

    def aggregate(self, detections):
        return self._base.aggregate(detections)

    def auxiliary_detections(self, source):
        return self._base.auxiliary_detections(source)
|
|
|
|
|
|
def replay_from(
    job_id: str,
    start_stage: str,
    config_overrides: dict | None = None,
    checkpoint: bool = True,
) -> dict:
    """
    Replay the pipeline from a specific stage.

    Loads the checkpoint from the stage immediately before start_stage,
    applies config overrides, and runs the subgraph from start_stage onward.

    Args:
        job_id: Job whose checkpoint is replayed; also recorded as the parent
            job of the new replay run.
        start_stage: First stage to re-run. It must be in ``NODES`` and must
            not be the first stage.
        config_overrides: Optional per-stage overrides, stored on the state as
            ``state["config_overrides"]`` when non-empty.
        checkpoint: Forwarded to ``build_graph(checkpoint=...)``.

    Returns:
        The final state dict returned by the compiled subgraph.

    Raises:
        ValueError: ``start_stage`` is unknown or is the first stage, or the
            previous stage has no checkpoint for ``job_id``.
        NotImplementedError: Raised while ``list_checkpoints`` /
            ``load_checkpoint`` are bound to the ``_not_migrated`` placeholder.
    """
    if start_stage not in NODES:
        raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")

    start_idx = NODES.index(start_stage)

    # Replay resumes from the previous stage's checkpoint, so the first stage
    # has nothing to resume from.
    if start_idx == 0:
        raise ValueError("Cannot replay from the first stage — just run the full pipeline")

    previous_stage = NODES[start_idx - 1]

    available = list_checkpoints(job_id)
    if previous_stage not in available:
        raise ValueError(
            f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
            f"Available: {available}"
        )

    logger.info("Replaying job %s from %s (loading checkpoint: %s)",
                job_id, start_stage, previous_stage)

    state = load_checkpoint(job_id, previous_stage)

    if config_overrides:
        state["config_overrides"] = config_overrides

    # Tag SSE events from this run as a replay of the parent job.
    run_id = str(uuid.uuid4())[:8]
    emit.set_run_context(
        run_id=run_id,
        parent_job_id=job_id,
        run_type="replay",
    )

    # Everything after set_run_context lives inside the try. A failure while
    # building or compiling the subgraph must not leave the replay context set
    # for later, unrelated runs.
    try:
        graph = build_graph(checkpoint=checkpoint, start_from=start_stage)
        pipeline = graph.compile()
        return pipeline.invoke(state)
    finally:
        emit.clear_run_context()
|
|
|
|
|
|
def replay_single_stage(
    job_id: str,
    stage: str,
    frame_refs: list[int] | None = None,
    config_overrides: dict | None = None,
    debug: bool = False,
) -> dict:
    """
    Replay a single stage on specific frames (or all frames from checkpoint).

    Fast path for interactive parameter tuning — runs only the target stage
    function, not the full pipeline tail. Returns the stage output directly.

    When debug=True and stage is detect_edges, returns additional overlay
    data (Canny edges, Hough lines) for visual feedback in the editor.

    For detect_edges: returns {"edge_regions_by_frame": {seq: [box, ...]}}
    With debug=True, also returns {"debug": {seq: {edge_overlay_b64, lines_overlay_b64, ...}}}

    Raises:
        ValueError: ``stage`` is unknown, is the first stage, has no
            single-stage implementation (only ``detect_edges`` has one), or
            the previous stage has no checkpoint for ``job_id``.
    """
    if stage not in NODES:
        raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}")

    stage_idx = NODES.index(stage)
    if stage_idx == 0:
        raise ValueError("Cannot replay the first stage — just run the full pipeline")

    # Fail fast before any checkpoint I/O or profile construction: only
    # detect_edges has a direct (non-graph) runner.
    if stage != "detect_edges":
        raise ValueError(
            f"Single-stage replay not yet implemented for {stage!r}. "
            "Use replay_from() for full pipeline replay."
        )

    previous_stage = NODES[stage_idx - 1]

    available = list_checkpoints(job_id)
    if previous_stage not in available:
        raise ValueError(
            f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
            f"Available: {available}"
        )

    logger.info("Single-stage replay: job %s, stage %s (loading checkpoint: %s, debug=%s)",
                job_id, stage, previous_stage, debug)

    state = load_checkpoint(job_id, previous_stage)

    # Rebuild the profile recorded in the checkpoint, layering overrides on top.
    from detect.profiles import get_profile

    profile = get_profile(state.get("profile_name", "soccer_broadcast"))
    if config_overrides:
        profile = OverrideProfile(profile, config_overrides)

    # Run the stage function directly (not through the graph).
    return _replay_detect_edges(state, profile, frame_refs, job_id, debug)
|
|
|
|
|
|
def _replay_detect_edges(
    state: dict,
    profile,
    frame_refs: list[int] | None,
    job_id: str,
    debug: bool,
) -> dict:
    """
    Re-run edge detection on the checkpointed frames, optionally with overlays.

    Args:
        state: Checkpoint state; frames come from ``state["filtered_frames"]``
            (an empty list when the key is absent).
        profile: Profile (possibly an ``OverrideProfile``) that provides
            ``region_analysis_config()``.
        frame_refs: Frame sequence numbers to keep; ``None`` or empty keeps
            every frame.
        job_id: Forwarded to ``detect_edge_regions`` and the inference client.
        debug: When true and at least one frame is selected, also collect
            per-frame overlay data. The data comes from the inference service
            when ``INFERENCE_URL`` is set, otherwise from the local CV module.

    Returns:
        ``{"edge_regions_by_frame": ...}``, plus ``"debug"`` mapping each frame
        sequence to its overlay fields when overlays were produced.
    """
    import os
    from detect.stages.edge_detector import detect_edge_regions

    config = profile.region_analysis_config()

    selected = state.get("filtered_frames", [])
    if frame_refs:
        wanted = set(frame_refs)
        selected = [fr for fr in selected if fr.sequence in wanted]

    inference_url = os.environ.get("INFERENCE_URL")

    # The plain run always happens: its boxes are the primary output.
    output = {
        "edge_regions_by_frame": detect_edge_regions(
            frames=selected,
            config=config,
            inference_url=inference_url,
            job_id=job_id,
        )
    }

    if not (debug and selected):
        return output

    # The config attribute names double as the remote endpoint's keyword
    # names. The local module takes the same values without the "edge_" prefix.
    param_names = (
        "edge_canny_low",
        "edge_canny_high",
        "edge_hough_threshold",
        "edge_hough_min_length",
        "edge_hough_max_gap",
        "edge_pair_max_distance",
        "edge_pair_min_distance",
    )
    overlay_fields = ("edge_overlay_b64", "lines_overlay_b64", "horizontal_count", "pair_count")
    prefix_len = len("edge_")
    per_frame = {}

    if inference_url:
        # Remote mode: results come back as objects with attribute access.
        from detect.inference import InferenceClient

        client = InferenceClient(base_url=inference_url, job_id=job_id)
        for fr in selected:
            kwargs = {name: getattr(config, name) for name in param_names}
            response = client.detect_edges_debug(image=fr.image, **kwargs)
            per_frame[fr.sequence] = {field: getattr(response, field) for field in overlay_fields}
    else:
        # Local mode: import the GPU module directly; results are dicts.
        from detect.stages.edge_detector import _load_cv_edges

        edges_mod = _load_cv_edges()
        for fr in selected:
            kwargs = {name[prefix_len:]: getattr(config, name) for name in param_names}
            response = edges_mod.detect_edges_debug(fr.image, **kwargs)
            per_frame[fr.sequence] = {field: response[field] for field in overlay_fields}

    output["debug"] = per_frame
    return output
|