124 lines
4.0 KiB
Python
124 lines
4.0 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
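Test checkpoint + replay flow end-to-end.

1. Runs the pipeline with checkpointing enabled on a test video.
2. Lists the checkpoints saved for the job.
3. Replays from run_ocr with a different OCR config.
4. Compares detection counts between the initial run and the replay.

Usage:
    MPR_CHECKPOINT=1 INFERENCE_URL=http://mcrndeb:8000 python tests/detect/manual/test_replay.py [--job JOB_ID] [--port REDIS_PORT]

The script sets MPR_CHECKPOINT=1 itself, so exporting it is optional. It also
points REDIS_URL at redis://localhost:<port>/0 (default port 6382), overriding
the value from ctrl/.env.

Requires: inference server running, MinIO/S3 running, Redis reachable on
localhost, test video available.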
|
|
"""
|
|
|
|
import argparse
import logging
import os
import sys
import time
from pathlib import Path
|
|
|
|
def load_env_file(path):
    """Load ``KEY=VALUE`` lines from *path* into ``os.environ`` without overriding.

    Blank lines, ``#`` comments and lines without ``=`` are skipped. An optional
    leading ``export `` is accepted, and a value wrapped in one pair of matching
    single or double quotes is unquoted. Variables already present in the
    environment take precedence (``os.environ.setdefault``). A missing file is
    a no-op.

    Args:
        path: ``pathlib.Path`` of the env file to read (decoded as UTF-8).
    """
    if not path.exists():
        return
    for raw in path.read_text(encoding="utf-8").splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        # Shell-style env files often prefix assignments with `export`.
        if line.startswith("export "):
            line = line[len("export "):]
        key, _, value = line.partition("=")
        key, value = key.strip(), value.strip()
        # Strip one pair of matching surrounding quotes: KEY="a b" -> a b
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
            value = value[1:-1]
        if key:
            os.environ.setdefault(key, value)


# Load ctrl/.env from three directories above this file's directory
# (per the Usage path, the directory that contains tests/). The length check
# avoids an IndexError if the script is ever run from a shallow location.
_script_parents = Path(__file__).resolve().parents
if len(_script_parents) > 3:
    load_env_file(_script_parents[3] / "ctrl" / ".env")

# Make the `detect` package importable; presumably the script is launched
# from the directory that contains it — confirm.
sys.path.insert(0, ".")

logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
logger = logging.getLogger(__name__)

# Force checkpointing on, regardless of the shell environment or ctrl/.env.
os.environ["MPR_CHECKPOINT"] = "1"

# Test clip used when --video is not given (relative to the working directory).
_DEFAULT_VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4"


def _summarize(result):
    """Return ``(detection_count, brand_count)`` for a pipeline/replay result.

    *result* is read via ``.get``: a missing or ``None`` ``detections`` counts
    as zero, and ``report.brands`` is only consulted when a report is present.
    """
    detections = result.get("detections") or []
    report = result.get("report")
    return len(detections), (len(report.brands) if report else 0)


def _pause(prompt, interactive):
    """Wait for Enter with *prompt* when *interactive*; otherwise continue immediately."""
    if interactive:
        input(prompt)


def main():
    """Run the pipeline, replay it from ``run_ocr`` with new OCR settings, and compare.

    Returns:
        Process exit status: 0 when the replay ran, 1 when the checkpoint the
        replay depends on was not produced by the initial run.
    """
    parser = argparse.ArgumentParser(description="Manual checkpoint + replay test.")
    # Time-derived default so repeated runs get distinct job ids.
    parser.add_argument("--job", default=f"replay-{int(time.time()) % 100000}",
                        help="job id used for the run and its checkpoints")
    parser.add_argument("--port", type=int, default=6382,
                        help="local Redis port (REDIS_URL is rewritten to localhost)")
    parser.add_argument("--video", default=_DEFAULT_VIDEO, help="input video path")
    parser.add_argument("-y", "--yes", action="store_true",
                        help="do not wait for Enter between steps")
    args = parser.parse_args()
    interactive = not args.yes

    # Override Redis to localhost (ctrl/.env has k8s hostname). The detect imports
    # come after this on purpose — presumably detect reads REDIS_URL at import time.
    os.environ["REDIS_URL"] = f"redis://localhost:{args.port}/0"

    from detect.checkpoint import list_checkpoints, replay_from
    from detect.graph import get_pipeline
    from detect.state import DetectState

    logger.info("Job: %s", args.job)
    logger.info("Checkpoint: enabled")
    logger.info("Video: %s", args.video)
    logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job)
    _pause("\nPress Enter to run initial pipeline...", interactive)

    # --- Initial run ---
    pipeline = get_pipeline(checkpoint=True)
    initial_state = DetectState(
        video_path=args.video,
        job_id=args.job,
        profile_name="soccer_broadcast",
    )
    logger.info("Running initial pipeline...")
    initial_count, initial_brands = _summarize(pipeline.invoke(initial_state))
    logger.info("Initial run: %d detections, %d brands", initial_count, initial_brands)

    # --- List checkpoints ---
    stages = list_checkpoints(args.job)
    logger.info("Available checkpoints: %s", stages)

    # NOTE(review): replay starts at run_ocr; this assumes detect_objects is the
    # checkpointed stage that replay resumes from — confirm against the graph.
    if "detect_objects" not in stages:
        logger.error("Expected checkpoint for detect_objects — aborting replay test")
        return 1

    _pause("\nPress Enter to replay from run_ocr with different config...", interactive)

    # --- Replay with different OCR config ---
    overrides = {"ocr": {"min_confidence": 0.3, "languages": ["en", "es"]}}
    logger.info("Replaying from run_ocr with overrides: %s", overrides)

    replay_result = replay_from(
        job_id=args.job,
        start_stage="run_ocr",
        config_overrides=overrides,
    )
    replay_count, replay_brands = _summarize(replay_result)
    logger.info("Replay run: %d detections, %d brands", replay_count, replay_brands)

    # --- Compare ---
    logger.info("--- Comparison ---")
    logger.info("Initial: %d detections", initial_count)
    # 0.5 is presumably the soccer_broadcast profile's baseline — confirm.
    logger.info("Replay: %d detections (min_confidence 0.5 → %s)",
                replay_count, overrides["ocr"]["min_confidence"])

    diff = replay_count - initial_count
    if diff > 0:
        logger.info("Replay found %d more detections with lower threshold", diff)
    elif diff == 0:
        logger.info("Same count — threshold change didn't affect this video")
    else:
        logger.warning("Replay found fewer detections — unexpected")

    logger.info("Done.")
    return 0

if __name__ == "__main__":
    # Propagate main()'s return value as the exit status (None -> 0).
    raise SystemExit(main())