compare view

This commit is contained in:
2026-03-30 13:05:28 -03:00
parent aac27b8504
commit 55e83e4203
23 changed files with 1321 additions and 201 deletions

View File

@@ -1,32 +1,88 @@
"""
Pipeline replay — re-run from any stage with different config.
Loads a checkpoint, applies config overrides, builds a subgraph
starting from the target stage, and invokes it.
Loads stage outputs from DB, frames from timeline cache,
reconstitutes state, and runs from a target stage onward.
Creates a new Job (run_type=REPLAY) for each replay invocation.
"""
from __future__ import annotations
import logging
import os
import uuid
from core.detect import emit
# TODO: migrate to Timeline/Branch/Checkpoint model
# These old functions no longer exist — replay needs rework
def _not_migrated(*args, **kwargs):
    """Placeholder for the removed checkpoint helpers; always raises.

    Accepts any arguments so it can stand in for several old call
    signatures without changing call sites.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("Replay not yet migrated to Timeline/Branch/Checkpoint model")

# Keep the old public names bound so existing callers fail with a clear
# NotImplementedError instead of a NameError/ImportError.
load_checkpoint = _not_migrated
list_checkpoints = _not_migrated
from core.detect.graph import NODES, build_graph
from core.detect.graph import NODES, get_pipeline
from core.detect.graph.runner import PipelineRunner
logger = logging.getLogger(__name__)
def _build_state_for_replay(
    job_id: str,
    up_to_stage: str,
) -> dict:
    """
    Reconstitute pipeline state from a completed job's stage outputs,
    up to (but not including) the target stage.

    Loads frames from the timeline cache and stage outputs from the DB,
    then merges each prior stage's serialized output into a single state
    dict in pipeline (NODES) order.

    NOTE: config overrides are no longer applied here (OverrideProfile was
    removed) — they are handled by dict merging in _load_profile()
    (nodes.py) and in replay_single_stage.

    Args:
        job_id: UUID string of the completed job whose outputs to load.
        up_to_stage: Target stage name (must be a member of NODES); the
            outputs of every stage *before* it are merged into the state.

    Returns:
        State dict containing the job envelope, cached frames, prior
        stage outputs, and an empty "config_overrides" for callers to fill.

    Raises:
        ValueError: if the job does not exist, has no timeline, or the
            timeline has no cached frames.
    """
    # Local imports — presumably to avoid import cycles between the replay
    # module and the storage/db layers; TODO confirm.
    from .storage import load_stage_outputs_for_job, get_checkpoints_for_job
    from .frames import load_cached_frames
    from core.db.connection import get_session
    from core.db.job import get_job

    # Load the job to get timeline_id and profile. Job attributes are read
    # inside the session so the ORM instance is still attached.
    with get_session() as session:
        job = get_job(session, uuid.UUID(job_id))
        if not job:
            raise ValueError(f"Job not found: {job_id}")
        timeline_id = str(job.timeline_id) if job.timeline_id else ""
        if not timeline_id:
            raise ValueError(f"Job {job_id} has no timeline")
        # Load frames from timeline cache
        frames = load_cached_frames(timeline_id)
        if not frames:
            raise ValueError(f"No cached frames for timeline {timeline_id}. Run the pipeline first.")
        # Load all stage outputs for this job
        all_outputs = load_stage_outputs_for_job(job_id)
        # Build state with envelope + frames
        state = {
            "job_id": job_id,
            "timeline_id": timeline_id,
            "video_path": job.video_path,
            "profile_name": job.profile_name,
            "source_asset_id": str(job.source_asset_id),
            "frames": frames,
            "config_overrides": {},
        }
    # Apply stage outputs in pipeline order, up to the target stage
    target_idx = NODES.index(up_to_stage)
    for stage_name in NODES[:target_idx]:
        output = all_outputs.get(stage_name)
        if output:
            # Stage outputs contain serialized data — merge into state.
            # The stage registry's deserialize_fn can reconstitute if needed.
            for key, value in output.items():
                state[key] = value
    # Filtered frames: reconstruct from the persisted sequence-number list if
    # present; otherwise fall back to the full frame set.
    filtered_seqs = state.get("filtered_frame_sequences")
    if filtered_seqs:
        seq_set = set(filtered_seqs)
        state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
    elif "filtered_frames" not in state:
        state["filtered_frames"] = frames
    return state
def replay_from(
@@ -38,49 +94,60 @@ def replay_from(
"""
Replay the pipeline from a specific stage.
Loads the checkpoint from the stage immediately before start_stage,
applies config overrides, and runs the subgraph from start_stage onward.
Loads state from the original job's stage outputs up to start_stage,
applies config overrides, and runs from start_stage onward.
Creates a new Job (run_type=REPLAY).
Returns the final state dict.
"""
if start_stage not in NODES:
raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")
start_idx = NODES.index(start_stage)
# Load checkpoint from the stage before start_stage
if start_idx == 0:
raise ValueError("Cannot replay from the first stage — just run the full pipeline")
previous_stage = NODES[start_idx - 1]
logger.info("Replaying job %s from %s", job_id, start_stage)
available = list_checkpoints(job_id)
if previous_stage not in available:
raise ValueError(
f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
f"Available: {available}"
)
logger.info("Replaying job %s from %s (loading checkpoint: %s)",
job_id, start_stage, previous_stage)
state = load_checkpoint(job_id, previous_stage)
state = _build_state_for_replay(job_id, start_stage)
# Apply config overrides
if config_overrides:
state["config_overrides"] = config_overrides
# Create replay job
from core.db.connection import get_session
from core.db.job import create_job, get_job
with get_session() as session:
original = get_job(session, uuid.UUID(job_id))
replay_job = create_job(
session,
source_asset_id=original.source_asset_id,
video_path=original.video_path,
timeline_id=original.timeline_id,
profile_name=original.profile_name,
run_type="replay",
parent_id=original.id,
config_overrides=config_overrides,
)
replay_job_id = str(replay_job.id)
# Update state with new job ID
state["job_id"] = replay_job_id
# Set run context for SSE events
run_id = str(uuid.uuid4())[:8]
emit.set_run_context(
run_id=run_id,
run_id=replay_job_id,
parent_job_id=job_id,
run_type="replay",
)
# Build subgraph starting from start_stage
graph = build_graph(checkpoint=checkpoint, start_from=start_stage)
pipeline = graph.compile()
# Run from start_stage onward
pipeline = get_pipeline(
checkpoint=checkpoint,
profile_name=state["profile_name"],
start_from=start_stage,
)
try:
result = pipeline.invoke(state)
@@ -102,12 +169,6 @@ def replay_single_stage(
Fast path for interactive parameter tuning — runs only the target stage
function, not the full pipeline tail. Returns the stage output directly.
When debug=True and stage is detect_edges, returns additional overlay
data (Canny edges, Hough lines) for visual feedback in the editor.
For detect_edges: returns {"edge_regions_by_frame": {seq: [box, ...]}}
With debug=True, also returns {"debug": {seq: {edge_overlay_b64, lines_overlay_b64, ...}}}
"""
if stage not in NODES:
raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}")
@@ -116,19 +177,9 @@ def replay_single_stage(
if stage_idx == 0:
raise ValueError("Cannot replay the first stage — just run the full pipeline")
previous_stage = NODES[stage_idx - 1]
logger.info("Single-stage replay: job %s, stage %s (debug=%s)", job_id, stage, debug)
available = list_checkpoints(job_id)
if previous_stage not in available:
raise ValueError(
f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
f"Available: {available}"
)
logger.info("Single-stage replay: job %s, stage %s (loading checkpoint: %s, debug=%s)",
job_id, stage, previous_stage, debug)
state = load_checkpoint(job_id, previous_stage)
state = _build_state_for_replay(job_id, stage)
# Build profile with overrides
from core.detect.profile import get_profile, get_stage_config
@@ -142,9 +193,17 @@ def replay_single_stage(
merged_configs[sname] = soverrides
profile = {**profile, "configs": merged_configs}
# Run the stage function directly (not through the graph)
# Subset frames if requested
frames = state.get("filtered_frames", state.get("frames", []))
if frame_refs:
ref_set = set(frame_refs)
frames = [f for f in frames if f.sequence in ref_set]
# Run the specific stage
if stage == "detect_edges":
return _replay_detect_edges(state, profile, frame_refs, job_id, debug)
return _replay_detect_edges(state, profile, frames, job_id, debug)
elif stage == "field_segmentation":
return _replay_field_segmentation(state, profile, frames, job_id, debug)
else:
raise ValueError(
f"Single-stage replay not yet implemented for {stage!r}. "
@@ -155,35 +214,28 @@ def replay_single_stage(
def _replay_detect_edges(
state: dict,
profile,
frame_refs: list[int] | None,
frames: list,
job_id: str,
debug: bool,
) -> dict:
"""Run edge detection on checkpoint frames, optionally with debug overlays."""
import os
from core.detect.stages.edge_detector import detect_edge_regions
from core.detect.profile import get_stage_config
from core.detect.stages.models import RegionAnalysisConfig
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
frames = state.get("filtered_frames", [])
if frame_refs:
ref_set = set(frame_refs)
frames = [f for f in frames if f.sequence in ref_set]
inference_url = os.environ.get("INFERENCE_URL")
field_masks = state.get("field_masks", {})
# Normal run — always needed for the boxes
result = detect_edge_regions(
frames=frames,
config=config,
inference_url=inference_url,
job_id=job_id,
field_masks=field_masks,
)
output = {"edge_regions_by_frame": result}
# Debug overlays — call debug endpoint (remote) or local debug function
if debug and frames:
debug_data = {}
if inference_url:
@@ -207,7 +259,6 @@ def _replay_detect_edges(
"pair_count": dr.pair_count,
}
else:
# Local mode — import GPU module directly
from core.detect.stages.edge_detector import _load_cv_edges
edges_mod = _load_cv_edges()
for frame in frames:
@@ -230,3 +281,27 @@ def _replay_detect_edges(
output["debug"] = debug_data
return output
def _replay_field_segmentation(
    state: dict,
    profile,
    frames: list,
    job_id: str,
    debug: bool,
) -> dict:
    """Single-stage replay of field segmentation over the given frames.

    Builds the stage config from *profile* and delegates to
    run_field_segmentation; the raw stage result is returned unchanged.
    (The ``state`` and ``debug`` arguments are accepted for signature
    parity with the other single-stage replay helpers.)
    """
    from core.detect.stages.field_segmentation import run_field_segmentation
    from core.detect.profile import get_stage_config
    from core.detect.stages.models import FieldSegmentationConfig

    stage_params = get_stage_config(profile, "field_segmentation")
    return run_field_segmentation(
        frames=frames,
        config=FieldSegmentationConfig(**stage_params),
        inference_url=os.environ.get("INFERENCE_URL"),
        job_id=job_id,
    )