""" Pipeline replay — re-run from any stage with different config. Loads stage outputs from DB, frames from timeline cache, reconstitutes state, and runs from a target stage onward. Creates a new Job (run_type=REPLAY) for each replay invocation. """ from __future__ import annotations import logging import os import uuid from core.detect import emit from core.detect.graph import NODES, get_pipeline from core.detect.graph.runner import PipelineRunner logger = logging.getLogger(__name__) def _build_state_for_replay( job_id: str, up_to_stage: str, ) -> dict: """ Reconstitute pipeline state from a completed job's stage outputs, up to (but not including) the target stage. Loads frames from timeline cache + stage outputs from DB. """ from .storage import load_stage_outputs_for_job, get_checkpoints_for_job from .frames import load_cached_frames from core.db.connection import get_session from core.db.job import get_job # Load the job to get timeline_id and profile with get_session() as session: job = get_job(session, uuid.UUID(job_id)) if not job: raise ValueError(f"Job not found: {job_id}") timeline_id = str(job.timeline_id) if job.timeline_id else "" if not timeline_id: raise ValueError(f"Job {job_id} has no timeline") # Load frames from timeline cache frames = load_cached_frames(timeline_id) if not frames: raise ValueError(f"No cached frames for timeline {timeline_id}. Run the pipeline first.") # Load all stage outputs for this job all_outputs = load_stage_outputs_for_job(job_id) # Build state with envelope + frames state = { "job_id": job_id, "timeline_id": timeline_id, "video_path": job.video_path, "profile_name": job.profile_name, "source_asset_id": str(job.source_asset_id), "frames": frames, "config_overrides": {}, } # Apply stage outputs in pipeline order, up to the target stage target_idx = NODES.index(up_to_stage) for stage_name in NODES[:target_idx]: output = all_outputs.get(stage_name) if output: # Stage outputs contain serialized data — merge into state # The stage registry's deserialize_fn can reconstitute if needed for key, value in output.items(): state[key] = value # Filtered frames: reconstruct from sequence list if present filtered_seqs = state.get("filtered_frame_sequences") if filtered_seqs: seq_set = set(filtered_seqs) state["filtered_frames"] = [f for f in frames if f.sequence in seq_set] elif "filtered_frames" not in state: state["filtered_frames"] = frames return state def replay_from( job_id: str, start_stage: str, config_overrides: dict | None = None, checkpoint: bool = True, ) -> dict: """ Replay the pipeline from a specific stage. Loads state from the original job's stage outputs up to start_stage, applies config overrides, and runs from start_stage onward. Creates a new Job (run_type=REPLAY). Returns the final state dict. """ if start_stage not in NODES: raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}") start_idx = NODES.index(start_stage) if start_idx == 0: raise ValueError("Cannot replay from the first stage — just run the full pipeline") logger.info("Replaying job %s from %s", job_id, start_stage) state = _build_state_for_replay(job_id, start_stage) # Apply config overrides if config_overrides: state["config_overrides"] = config_overrides # Create replay job from core.db.connection import get_session from core.db.job import create_job, get_job with get_session() as session: original = get_job(session, uuid.UUID(job_id)) replay_job = create_job( session, source_asset_id=original.source_asset_id, video_path=original.video_path, timeline_id=original.timeline_id, profile_name=original.profile_name, run_type="replay", parent_id=original.id, config_overrides=config_overrides, ) replay_job_id = str(replay_job.id) # Update state with new job ID state["job_id"] = replay_job_id # Set run context for SSE events emit.set_run_context( run_id=replay_job_id, parent_job_id=job_id, run_type="replay", ) # Run from start_stage onward pipeline = get_pipeline( checkpoint=checkpoint, profile_name=state["profile_name"], start_from=start_stage, ) try: result = pipeline.invoke(state) finally: emit.clear_run_context() return result def replay_single_stage( job_id: str, stage: str, frame_refs: list[int] | None = None, config_overrides: dict | None = None, debug: bool = False, ) -> dict: """ Replay a single stage on specific frames (or all frames from checkpoint). Fast path for interactive parameter tuning — runs only the target stage function, not the full pipeline tail. Returns the stage output directly. """ if stage not in NODES: raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}") stage_idx = NODES.index(stage) if stage_idx == 0: raise ValueError("Cannot replay the first stage — just run the full pipeline") logger.info("Single-stage replay: job %s, stage %s (debug=%s)", job_id, stage, debug) state = _build_state_for_replay(job_id, stage) # Build profile with overrides from core.detect.profile import get_profile, get_stage_config profile = get_profile(state.get("profile_name", "soccer_broadcast")) if config_overrides: merged_configs = dict(profile.get("configs", {})) for sname, soverrides in config_overrides.items(): if sname in merged_configs: merged_configs[sname] = {**merged_configs[sname], **soverrides} else: merged_configs[sname] = soverrides profile = {**profile, "configs": merged_configs} # Subset frames if requested frames = state.get("filtered_frames", state.get("frames", [])) if frame_refs: ref_set = set(frame_refs) frames = [f for f in frames if f.sequence in ref_set] # Run the specific stage if stage == "detect_edges": return _replay_detect_edges(state, profile, frames, job_id, debug) elif stage == "field_segmentation": return _replay_field_segmentation(state, profile, frames, job_id, debug) else: raise ValueError( f"Single-stage replay not yet implemented for {stage!r}. " f"Use replay_from() for full pipeline replay." ) def _replay_detect_edges( state: dict, profile, frames: list, job_id: str, debug: bool, ) -> dict: """Run edge detection on checkpoint frames, optionally with debug overlays.""" from core.detect.stages.edge_detector import detect_edge_regions from core.detect.profile import get_stage_config from core.detect.stages.models import RegionAnalysisConfig config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges")) inference_url = os.environ.get("INFERENCE_URL") field_masks = state.get("field_masks", {}) result = detect_edge_regions( frames=frames, config=config, inference_url=inference_url, job_id=job_id, field_masks=field_masks, ) output = {"edge_regions_by_frame": result} if debug and frames: debug_data = {} if inference_url: from core.detect.inference import InferenceClient client = InferenceClient(base_url=inference_url, job_id=job_id) for frame in frames: dr = client.detect_edges_debug( image=frame.image, edge_canny_low=config.edge_canny_low, edge_canny_high=config.edge_canny_high, edge_hough_threshold=config.edge_hough_threshold, edge_hough_min_length=config.edge_hough_min_length, edge_hough_max_gap=config.edge_hough_max_gap, edge_pair_max_distance=config.edge_pair_max_distance, edge_pair_min_distance=config.edge_pair_min_distance, ) debug_data[frame.sequence] = { "edge_overlay_b64": dr.edge_overlay_b64, "lines_overlay_b64": dr.lines_overlay_b64, "horizontal_count": dr.horizontal_count, "pair_count": dr.pair_count, } else: from core.detect.stages.edge_detector import _load_cv_edges edges_mod = _load_cv_edges() for frame in frames: dr = edges_mod.detect_edges_debug( frame.image, canny_low=config.edge_canny_low, canny_high=config.edge_canny_high, hough_threshold=config.edge_hough_threshold, hough_min_length=config.edge_hough_min_length, hough_max_gap=config.edge_hough_max_gap, pair_max_distance=config.edge_pair_max_distance, pair_min_distance=config.edge_pair_min_distance, ) debug_data[frame.sequence] = { "edge_overlay_b64": dr["edge_overlay_b64"], "lines_overlay_b64": dr["lines_overlay_b64"], "horizontal_count": dr["horizontal_count"], "pair_count": dr["pair_count"], } output["debug"] = debug_data return output def _replay_field_segmentation( state: dict, profile, frames: list, job_id: str, debug: bool, ) -> dict: """Run field segmentation on checkpoint frames.""" from core.detect.stages.field_segmentation import run_field_segmentation from core.detect.profile import get_stage_config from core.detect.stages.models import FieldSegmentationConfig config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation")) inference_url = os.environ.get("INFERENCE_URL") result = run_field_segmentation( frames=frames, config=config, inference_url=inference_url, job_id=job_id, ) return result