Files
mediaproc/core/detect/checkpoint/replay.py
2026-03-30 07:22:14 -03:00

233 lines
8.0 KiB
Python

"""
Pipeline replay — re-run from any stage with different config.
Loads a checkpoint, applies config overrides, builds a subgraph
starting from the target stage, and invokes it.
"""
from __future__ import annotations
import logging
import uuid
from core.detect import emit
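# TODO: migrate replay to the Timeline/Branch/Checkpoint model.
# The legacy load_checkpoint()/list_checkpoints() helpers no longer exist, so
# both names are bound to a stub that raises NotImplementedError until replay
# is reworked.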
def _not_migrated(*args, **kwargs):
raise NotImplementedError("Replay not yet migrated to Timeline/Branch/Checkpoint model")
load_checkpoint = _not_migrated
list_checkpoints = _not_migrated
from core.detect.graph import NODES, build_graph
logger = logging.getLogger(__name__)
# OverrideProfile removed — config overrides are now handled by dict merging
# in _load_profile() (nodes.py) and replay_single_stage (below).
def replay_from(
    job_id: str,
    start_stage: str,
    config_overrides: dict | None = None,
    checkpoint: bool = True,
) -> dict:
    """
    Replay the pipeline from a specific stage.

    Loads the checkpoint saved for the stage immediately before
    ``start_stage``, stores any config overrides on the restored state,
    and runs the subgraph from ``start_stage`` onward.

    Args:
        job_id: Job whose checkpoints are replayed; also passed as the
            parent job id of the replay run context.
        start_stage: Name of the stage to resume from. Must be in NODES
            and must not be the first entry.
        config_overrides: Optional overrides stored under
            ``state["config_overrides"]``. Ignored when None or empty.
        checkpoint: Forwarded to ``build_graph(checkpoint=...)``.

    Returns:
        The final state dict returned by invoking the compiled subgraph.

    Raises:
        ValueError: If ``start_stage`` is unknown, is the first stage, or
            the preceding stage has no checkpoint for this job.
    """
    if start_stage not in NODES:
        raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")

    start_idx = NODES.index(start_stage)
    # Replay needs the checkpoint of the preceding stage, so the first
    # stage has nothing to resume from.
    if start_idx == 0:
        raise ValueError("Cannot replay from the first stage — just run the full pipeline")

    previous_stage = NODES[start_idx - 1]
    available = list_checkpoints(job_id)
    if previous_stage not in available:
        raise ValueError(
            f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
            f"Available: {available}"
        )

    logger.info("Replaying job %s from %s (loading checkpoint: %s)",
                job_id, start_stage, previous_stage)
    state = load_checkpoint(job_id, previous_stage)

    if config_overrides:
        state["config_overrides"] = config_overrides

    # Set run context for SSE events
    run_id = str(uuid.uuid4())[:8]
    emit.set_run_context(
        run_id=run_id,
        parent_job_id=job_id,
        run_type="replay",
    )
    # Everything after set_run_context() is guarded so the context is
    # cleared even when building or compiling the graph fails, not only
    # when invoke() does.
    try:
        graph = build_graph(checkpoint=checkpoint, start_from=start_stage)
        pipeline = graph.compile()
        return pipeline.invoke(state)
    finally:
        emit.clear_run_context()
def replay_single_stage(
    job_id: str,
    stage: str,
    frame_refs: list[int] | None = None,
    config_overrides: dict | None = None,
    debug: bool = False,
) -> dict:
    """
    Replay a single stage on specific frames (or all frames from checkpoint).

    Fast path for interactive parameter tuning — runs only the target stage
    function, not the full pipeline tail, and returns the stage output
    directly. Only ``detect_edges`` is currently supported.

    For detect_edges: returns {"edge_regions_by_frame": {seq: [box, ...]}}
    With debug=True, also returns
    {"debug": {seq: {edge_overlay_b64, lines_overlay_b64, ...}}}

    Args:
        job_id: Job whose checkpoint is loaded.
        stage: Stage to run. Must be in NODES and must not be the first
            entry.
        frame_refs: Optional frame sequence numbers to restrict the run
            to; None or empty means all checkpoint frames.
        config_overrides: Optional ``{stage_name: {key: value}}`` mapping
            merged over the profile's per-stage configs. Neither the
            profile nor the caller's dicts are mutated.
        debug: When True, the stage also returns overlay data.

    Returns:
        The output dict of the stage replay.

    Raises:
        ValueError: If ``stage`` is unknown, is the first stage, has no
            preceding checkpoint for this job, or has no single-stage
            replay implementation.
    """
    if stage not in NODES:
        raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}")

    stage_idx = NODES.index(stage)
    if stage_idx == 0:
        raise ValueError("Cannot replay the first stage — just run the full pipeline")

    previous_stage = NODES[stage_idx - 1]
    available = list_checkpoints(job_id)
    if previous_stage not in available:
        raise ValueError(
            f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
            f"Available: {available}"
        )

    logger.info("Single-stage replay: job %s, stage %s (loading checkpoint: %s, debug=%s)",
                job_id, stage, previous_stage, debug)
    state = load_checkpoint(job_id, previous_stage)

    # Function-scope import kept from the original.
    # NOTE(review): presumably avoids an import cycle — confirm.
    from core.detect.profile import get_profile

    profile = get_profile(state.get("profile_name", "soccer_broadcast"))
    if config_overrides:
        merged_configs = dict(profile.get("configs", {}))
        for stage_name, stage_overrides in config_overrides.items():
            # Always build a fresh dict: overrides win over profile values,
            # and a stage missing from the profile gets a copy rather than
            # a reference to the caller's dict.
            merged_configs[stage_name] = {
                **merged_configs.get(stage_name, {}),
                **stage_overrides,
            }
        profile = {**profile, "configs": merged_configs}

    # Run the stage function directly (not through the graph)
    if stage == "detect_edges":
        return _replay_detect_edges(state, profile, frame_refs, job_id, debug)

    raise ValueError(
        f"Single-stage replay not yet implemented for {stage!r}. "
        f"Use replay_from() for full pipeline replay."
    )
def _replay_detect_edges(
    state: dict,
    profile,
    frame_refs: list[int] | None,
    job_id: str,
    debug: bool,
) -> dict:
    """Detect edge regions on checkpoint frames, optionally adding debug overlays.

    Frames come from ``state["filtered_frames"]``; when ``frame_refs`` is
    non-empty only frames whose ``sequence`` is listed are processed. The
    stage config is built from the profile's ``detect_edges`` settings.

    Returns:
        ``{"edge_regions_by_frame": ...}`` holding the detection result.
        When ``debug`` is true and at least one frame was selected, a
        ``"debug"`` entry maps each frame sequence to its overlay images
        and counts. The overlays come from the inference service when the
        ``INFERENCE_URL`` environment variable is set, otherwise from the
        locally loaded edge module.
    """
    import os

    from core.detect.stages.edge_detector import detect_edge_regions
    from core.detect.profile import get_stage_config
    from core.detect.stages.models import RegionAnalysisConfig

    config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))

    frames = state.get("filtered_frames", [])
    if frame_refs:
        wanted = set(frame_refs)
        frames = [frame for frame in frames if frame.sequence in wanted]

    inference_url = os.environ.get("INFERENCE_URL")

    # The regular detection pass always runs; the boxes are needed with or
    # without debug output.
    output = {
        "edge_regions_by_frame": detect_edge_regions(
            frames=frames,
            config=config,
            inference_url=inference_url,
            job_id=job_id,
        )
    }

    if not (debug and frames):
        return output

    # Fields copied from each per-frame debug result into the response.
    overlay_keys = (
        "edge_overlay_b64",
        "lines_overlay_b64",
        "horizontal_count",
        "pair_count",
    )
    overlays = {}

    if inference_url:
        # Remote mode — the debug endpoint returns an object with attributes.
        from core.detect.inference import InferenceClient

        client = InferenceClient(base_url=inference_url, job_id=job_id)
        for frame in frames:
            report = client.detect_edges_debug(
                image=frame.image,
                edge_canny_low=config.edge_canny_low,
                edge_canny_high=config.edge_canny_high,
                edge_hough_threshold=config.edge_hough_threshold,
                edge_hough_min_length=config.edge_hough_min_length,
                edge_hough_max_gap=config.edge_hough_max_gap,
                edge_pair_max_distance=config.edge_pair_max_distance,
                edge_pair_min_distance=config.edge_pair_min_distance,
            )
            overlays[frame.sequence] = {
                key: getattr(report, key) for key in overlay_keys
            }
    else:
        # Local mode — the module from _load_cv_edges() returns a mapping.
        from core.detect.stages.edge_detector import _load_cv_edges

        edges_mod = _load_cv_edges()
        for frame in frames:
            report = edges_mod.detect_edges_debug(
                frame.image,
                canny_low=config.edge_canny_low,
                canny_high=config.edge_canny_high,
                hough_threshold=config.edge_hough_threshold,
                hough_min_length=config.edge_hough_min_length,
                hough_max_gap=config.edge_hough_max_gap,
                pair_max_distance=config.edge_pair_max_distance,
                pair_min_distance=config.edge_pair_min_distance,
            )
            overlays[frame.sequence] = {key: report[key] for key in overlay_keys}

    output["debug"] = overlays
    return output