"""
Pipeline replay — re-run from any stage with different config.

Loads stage outputs from DB, frames from timeline cache,
reconstitutes state, and runs from a target stage onward.

Creates a new Job (run_type=REPLAY) for each replay invocation.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
import uuid
|
|
|
|
from core.detect import emit
|
|
from core.detect.graph import NODES, get_pipeline
|
|
from core.detect.graph.runner import PipelineRunner
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _build_state_for_replay(
    job_id: str,
    up_to_stage: str,
) -> dict:
    """
    Reconstitute pipeline state from a completed job's stage outputs,
    up to (but not including) the target stage.

    Loads frames from the timeline cache and stage outputs from the DB,
    then merges them into a state dict shaped like what the pipeline
    would have produced just before ``up_to_stage`` ran.

    Args:
        job_id: UUID string of the completed job to rebuild state from.
        up_to_stage: Stage name the replay will start at; outputs of all
            earlier stages (in NODES order) are merged into the state.

    Returns:
        State dict with envelope fields (job_id, timeline_id, video_path,
        profile_name, source_asset_id), cached frames, merged stage
        outputs, and a guaranteed ``filtered_frames`` entry.

    Raises:
        ValueError: If the job does not exist, has no timeline, or there
            are no cached frames for its timeline.
    """
    # NOTE: get_checkpoints_for_job was imported here but never used — removed.
    from .storage import load_stage_outputs_for_job
    from .frames import load_cached_frames
    from core.db.connection import get_session
    from core.db.job import get_job

    # Load the job to get timeline_id and profile
    with get_session() as session:
        job = get_job(session, uuid.UUID(job_id))
        if not job:
            raise ValueError(f"Job not found: {job_id}")

        timeline_id = str(job.timeline_id) if job.timeline_id else ""
        if not timeline_id:
            raise ValueError(f"Job {job_id} has no timeline")

        # Load frames from timeline cache
        frames = load_cached_frames(timeline_id)
        if not frames:
            raise ValueError(f"No cached frames for timeline {timeline_id}. Run the pipeline first.")

        # Load all stage outputs for this job
        all_outputs = load_stage_outputs_for_job(job_id)

        # Build state with envelope + frames. Job attributes are read while
        # the session is still open so lazy/detached-instance access is safe.
        state = {
            "job_id": job_id,
            "timeline_id": timeline_id,
            "video_path": job.video_path,
            "profile_name": job.profile_name,
            "source_asset_id": str(job.source_asset_id),
            "frames": frames,
            "config_overrides": {},
        }

        # Apply stage outputs in pipeline order, up to the target stage
        target_idx = NODES.index(up_to_stage)
        for stage_name in NODES[:target_idx]:
            output = all_outputs.get(stage_name)
            if output:
                # Stage outputs contain serialized data — merge into state
                # The stage registry's deserialize_fn can reconstitute if needed
                for key, value in output.items():
                    state[key] = value

        # Filtered frames: reconstruct from sequence list if present;
        # otherwise fall back to all frames so downstream stages always
        # have a filtered_frames entry to work with.
        filtered_seqs = state.get("filtered_frame_sequences")
        if filtered_seqs:
            seq_set = set(filtered_seqs)
            state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
        elif "filtered_frames" not in state:
            state["filtered_frames"] = frames

        return state
|
|
|
|
|
|
def replay_from(
    job_id: str,
    start_stage: str,
    config_overrides: dict | None = None,
    checkpoint: bool = True,
) -> dict:
    """
    Replay the pipeline from a specific stage.

    Loads state from the original job's stage outputs up to start_stage,
    applies config overrides, and runs from start_stage onward.

    Creates a new Job (run_type=REPLAY) parented to the original job.

    Args:
        job_id: UUID string of the completed job to replay.
        start_stage: Stage name to resume from; must be in NODES and not
            the first stage.
        config_overrides: Optional per-stage config patches stored on the
            replay job and placed into pipeline state.
        checkpoint: Whether the replayed pipeline should checkpoint its
            stage outputs.

    Returns:
        The final state dict produced by the pipeline.

    Raises:
        ValueError: Unknown stage, first stage, or missing job/timeline/frames.
    """
    if start_stage not in NODES:
        raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")

    start_idx = NODES.index(start_stage)
    if start_idx == 0:
        raise ValueError("Cannot replay from the first stage — just run the full pipeline")

    logger.info("Replaying job %s from %s", job_id, start_stage)

    state = _build_state_for_replay(job_id, start_stage)

    # Apply config overrides
    if config_overrides:
        state["config_overrides"] = config_overrides

    # Create replay job
    from core.db.connection import get_session
    from core.db.job import create_job, get_job

    with get_session() as session:
        original = get_job(session, uuid.UUID(job_id))
        if not original:
            # _build_state_for_replay saw this job moments ago; guard against
            # a concurrent delete instead of crashing with AttributeError.
            raise ValueError(f"Job not found: {job_id}")
        replay_job = create_job(
            session,
            source_asset_id=original.source_asset_id,
            video_path=original.video_path,
            timeline_id=original.timeline_id,
            profile_name=original.profile_name,
            run_type="replay",
            parent_id=original.id,
            config_overrides=config_overrides,
        )
        replay_job_id = str(replay_job.id)

    # Update state with new job ID so downstream stages attribute their
    # outputs to the replay job, not the original.
    state["job_id"] = replay_job_id

    # Set run context for SSE events
    emit.set_run_context(
        run_id=replay_job_id,
        parent_job_id=job_id,
        run_type="replay",
    )

    # Run from start_stage onward
    pipeline = get_pipeline(
        checkpoint=checkpoint,
        profile_name=state["profile_name"],
        start_from=start_stage,
    )

    try:
        result = pipeline.invoke(state)
    finally:
        # Always clear the run context, even if the pipeline raises.
        emit.clear_run_context()

    return result
|
|
|
|
|
|
def replay_single_stage(
    job_id: str,
    stage: str,
    frame_refs: list[int] | None = None,
    config_overrides: dict | None = None,
    debug: bool = False,
) -> dict:
    """
    Replay a single stage on specific frames (or all frames from checkpoint).

    Fast path for interactive parameter tuning — runs only the target stage
    function, not the full pipeline tail. Returns the stage output directly.

    Args:
        job_id: UUID string of the source job whose checkpoints are reused.
        stage: Stage name; must be in NODES and not the first stage.
        frame_refs: Optional frame sequence numbers to restrict the run to.
        config_overrides: Optional ``{stage_name: {key: value}}`` patches
            shallow-merged over the job profile's per-stage configs.
        debug: When True, the stage runner may attach debug overlays.

    Returns:
        The stage's output dict.

    Raises:
        ValueError: Unknown stage, first stage, or a stage without a
            single-stage replay implementation.
    """
    if stage not in NODES:
        raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}")

    stage_idx = NODES.index(stage)
    if stage_idx == 0:
        raise ValueError("Cannot replay the first stage — just run the full pipeline")

    logger.info("Single-stage replay: job %s, stage %s (debug=%s)", job_id, stage, debug)

    state = _build_state_for_replay(job_id, stage)

    # Build profile with overrides. (get_stage_config was imported here but
    # never used in this function — the _replay_* helpers import it themselves.)
    from core.detect.profile import get_profile

    profile = get_profile(state.get("profile_name", "soccer_broadcast"))
    if config_overrides:
        # Shallow-merge each stage's overrides over the profile's configs.
        merged_configs = dict(profile.get("configs", {}))
        for sname, soverrides in config_overrides.items():
            if sname in merged_configs:
                merged_configs[sname] = {**merged_configs[sname], **soverrides}
            else:
                merged_configs[sname] = soverrides
        profile = {**profile, "configs": merged_configs}

    # Subset frames if requested
    frames = state.get("filtered_frames", state.get("frames", []))
    if frame_refs:
        ref_set = set(frame_refs)
        frames = [f for f in frames if f.sequence in ref_set]

    # Dispatch to the stage-specific replay runner
    if stage == "detect_edges":
        return _replay_detect_edges(state, profile, frames, job_id, debug)
    elif stage == "field_segmentation":
        return _replay_field_segmentation(state, profile, frames, job_id, debug)
    else:
        raise ValueError(
            f"Single-stage replay not yet implemented for {stage!r}. "
            f"Use replay_from() for full pipeline replay."
        )
|
|
|
|
|
|
def _replay_detect_edges(
    state: dict,
    profile,
    frames: list,
    job_id: str,
    debug: bool,
) -> dict:
    """Run edge detection on checkpoint frames, optionally with debug overlays."""
    from core.detect.stages.edge_detector import detect_edge_regions
    from core.detect.profile import get_stage_config
    from core.detect.stages.models import RegionAnalysisConfig

    cfg = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
    inference_url = os.environ.get("INFERENCE_URL")

    regions = detect_edge_regions(
        frames=frames,
        config=cfg,
        inference_url=inference_url,
        job_id=job_id,
        field_masks=state.get("field_masks", {}),
    )
    output = {"edge_regions_by_frame": regions}

    if debug and frames:
        if inference_url:
            # Remote path: the inference service returns an object with
            # overlay/count attributes.
            from core.detect.inference import InferenceClient

            client = InferenceClient(base_url=inference_url, job_id=job_id)

            def frame_debug(frame):
                res = client.detect_edges_debug(
                    image=frame.image,
                    edge_canny_low=cfg.edge_canny_low,
                    edge_canny_high=cfg.edge_canny_high,
                    edge_hough_threshold=cfg.edge_hough_threshold,
                    edge_hough_min_length=cfg.edge_hough_min_length,
                    edge_hough_max_gap=cfg.edge_hough_max_gap,
                    edge_pair_max_distance=cfg.edge_pair_max_distance,
                    edge_pair_min_distance=cfg.edge_pair_min_distance,
                )
                return {
                    "edge_overlay_b64": res.edge_overlay_b64,
                    "lines_overlay_b64": res.lines_overlay_b64,
                    "horizontal_count": res.horizontal_count,
                    "pair_count": res.pair_count,
                }
        else:
            # Local CV fallback: the module-level debug helper returns a
            # plain dict with the same keys.
            from core.detect.stages.edge_detector import _load_cv_edges

            cv_edges = _load_cv_edges()

            def frame_debug(frame):
                res = cv_edges.detect_edges_debug(
                    frame.image,
                    canny_low=cfg.edge_canny_low,
                    canny_high=cfg.edge_canny_high,
                    hough_threshold=cfg.edge_hough_threshold,
                    hough_min_length=cfg.edge_hough_min_length,
                    hough_max_gap=cfg.edge_hough_max_gap,
                    pair_max_distance=cfg.edge_pair_max_distance,
                    pair_min_distance=cfg.edge_pair_min_distance,
                )
                return {
                    "edge_overlay_b64": res["edge_overlay_b64"],
                    "lines_overlay_b64": res["lines_overlay_b64"],
                    "horizontal_count": res["horizontal_count"],
                    "pair_count": res["pair_count"],
                }

        # Keyed by frame sequence number, same as the original per-frame loop.
        output["debug"] = {frame.sequence: frame_debug(frame) for frame in frames}

    return output
|
|
|
|
|
|
def _replay_field_segmentation(
    state: dict,
    profile,
    frames: list,
    job_id: str,
    debug: bool,
) -> dict:
    """Run field segmentation on checkpoint frames."""
    from core.detect.stages.field_segmentation import run_field_segmentation
    from core.detect.profile import get_stage_config
    from core.detect.stages.models import FieldSegmentationConfig

    seg_cfg = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))
    # Inference URL comes from the environment, same as the full pipeline run.
    return run_field_segmentation(
        frames=frames,
        config=seg_cfg,
        inference_url=os.environ.get("INFERENCE_URL"),
        job_id=job_id,
    )
|