phase 10
This commit is contained in:
132
detect/checkpoint/replay.py
Normal file
132
detect/checkpoint/replay.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
Pipeline replay — re-run from any stage with different config.
|
||||
|
||||
Loads a checkpoint, applies config overrides, builds a subgraph
|
||||
starting from the target stage, and invokes it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import uuid
|
||||
|
||||
from detect import emit
|
||||
from detect.checkpoint import load_checkpoint, list_checkpoints
|
||||
from detect.graph import NODES, build_graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OverrideProfile:
|
||||
"""
|
||||
Wraps a ContentTypeProfile and patches config methods with overrides.
|
||||
|
||||
Override dict structure:
|
||||
{
|
||||
"frame_extraction": {"fps": 1.0},
|
||||
"scene_filter": {"hamming_threshold": 12},
|
||||
"detection": {"confidence_threshold": 0.5},
|
||||
"ocr": {"languages": ["en", "es"], "min_confidence": 0.3},
|
||||
"resolver": {"fuzzy_threshold": 60},
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, base, overrides: dict):
|
||||
self._base = base
|
||||
self._overrides = overrides
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._base, name)
|
||||
|
||||
def _patch(self, config, key: str):
|
||||
patches = self._overrides.get(key, {})
|
||||
for k, v in patches.items():
|
||||
if hasattr(config, k):
|
||||
setattr(config, k, v)
|
||||
return config
|
||||
|
||||
def frame_extraction_config(self):
|
||||
return self._patch(self._base.frame_extraction_config(), "frame_extraction")
|
||||
|
||||
def scene_filter_config(self):
|
||||
return self._patch(self._base.scene_filter_config(), "scene_filter")
|
||||
|
||||
def detection_config(self):
|
||||
return self._patch(self._base.detection_config(), "detection")
|
||||
|
||||
def ocr_config(self):
|
||||
return self._patch(self._base.ocr_config(), "ocr")
|
||||
|
||||
def resolver_config(self):
|
||||
return self._patch(self._base.resolver_config(), "resolver")
|
||||
|
||||
def vlm_prompt(self, crop_context):
|
||||
return self._base.vlm_prompt(crop_context)
|
||||
|
||||
def aggregate(self, detections):
|
||||
return self._base.aggregate(detections)
|
||||
|
||||
def auxiliary_detections(self, source):
|
||||
return self._base.auxiliary_detections(source)
|
||||
|
||||
|
||||
def replay_from(
|
||||
job_id: str,
|
||||
start_stage: str,
|
||||
config_overrides: dict | None = None,
|
||||
checkpoint: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
Replay the pipeline from a specific stage.
|
||||
|
||||
Loads the checkpoint from the stage immediately before start_stage,
|
||||
applies config overrides, and runs the subgraph from start_stage onward.
|
||||
|
||||
Returns the final state dict.
|
||||
"""
|
||||
if start_stage not in NODES:
|
||||
raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")
|
||||
|
||||
start_idx = NODES.index(start_stage)
|
||||
|
||||
# Load checkpoint from the stage before start_stage
|
||||
if start_idx == 0:
|
||||
raise ValueError("Cannot replay from the first stage — just run the full pipeline")
|
||||
|
||||
previous_stage = NODES[start_idx - 1]
|
||||
|
||||
available = list_checkpoints(job_id)
|
||||
if previous_stage not in available:
|
||||
raise ValueError(
|
||||
f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
|
||||
f"Available: {available}"
|
||||
)
|
||||
|
||||
logger.info("Replaying job %s from %s (loading checkpoint: %s)",
|
||||
job_id, start_stage, previous_stage)
|
||||
|
||||
state = load_checkpoint(job_id, previous_stage)
|
||||
|
||||
# Apply config overrides
|
||||
if config_overrides:
|
||||
state["config_overrides"] = config_overrides
|
||||
|
||||
# Set run context for SSE events
|
||||
run_id = str(uuid.uuid4())[:8]
|
||||
emit.set_run_context(
|
||||
run_id=run_id,
|
||||
parent_job_id=job_id,
|
||||
run_type="replay",
|
||||
)
|
||||
|
||||
# Build subgraph starting from start_stage
|
||||
graph = build_graph(checkpoint=checkpoint, start_from=start_stage)
|
||||
pipeline = graph.compile()
|
||||
|
||||
try:
|
||||
result = pipeline.invoke(state)
|
||||
finally:
|
||||
emit.clear_run_context()
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user