phase 10
This commit is contained in:
123
tests/detect/manual/test_replay.py
Normal file
123
tests/detect/manual/test_replay.py
Normal file
@@ -0,0 +1,123 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test checkpoint + replay flow end-to-end.
|
||||
|
||||
1. Runs the pipeline with checkpointing enabled on a test video
|
||||
2. Lists available checkpoints
|
||||
3. Replays from run_ocr with different config
|
||||
4. Compares detection counts
|
||||
|
||||
Usage:
|
||||
MPR_CHECKPOINT=1 INFERENCE_URL=http://mcrndeb:8000 python tests/detect/manual/test_replay.py [--job JOB_ID]
|
||||
|
||||
Requires: inference server running, MinIO/S3 running, test video available
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Load ctrl/.env
|
||||
env_file = Path(__file__).resolve().parents[3] / "ctrl" / ".env"
|
||||
if env_file.exists():
|
||||
for line in env_file.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
key, _, value = line.partition("=")
|
||||
os.environ.setdefault(key.strip(), value.strip())
|
||||
|
||||
sys.path.insert(0, ".")
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Force checkpointing on
|
||||
os.environ["MPR_CHECKPOINT"] = "1"
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
import time
|
||||
default_job = f"replay-{int(time.time()) % 100000}"
|
||||
parser.add_argument("--job", default=default_job)
|
||||
parser.add_argument("--port", type=int, default=6382)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Override Redis to localhost (ctrl/.env has k8s hostname)
|
||||
os.environ["REDIS_URL"] = f"redis://localhost:{args.port}/0"
|
||||
|
||||
from detect.graph import get_pipeline, NODES
|
||||
from detect.checkpoint import list_checkpoints
|
||||
from detect.checkpoint import replay_from
|
||||
from detect.state import DetectState
|
||||
|
||||
VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4"
|
||||
|
||||
logger.info("Job: %s", args.job)
|
||||
logger.info("Checkpoint: enabled")
|
||||
logger.info("Video: %s", VIDEO)
|
||||
logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job)
|
||||
input("\nPress Enter to run initial pipeline...")
|
||||
|
||||
# --- Initial run ---
|
||||
pipeline = get_pipeline(checkpoint=True)
|
||||
initial_state = DetectState(
|
||||
video_path=VIDEO,
|
||||
job_id=args.job,
|
||||
profile_name="soccer_broadcast",
|
||||
)
|
||||
|
||||
logger.info("Running initial pipeline...")
|
||||
result = pipeline.invoke(initial_state)
|
||||
|
||||
detections = result.get("detections", [])
|
||||
report = result.get("report")
|
||||
logger.info("Initial run: %d detections, %d brands",
|
||||
len(detections), len(report.brands) if report else 0)
|
||||
|
||||
# --- List checkpoints ---
|
||||
stages = list_checkpoints(args.job)
|
||||
logger.info("Available checkpoints: %s", stages)
|
||||
|
||||
if "detect_objects" not in stages:
|
||||
logger.error("Expected checkpoint for detect_objects — aborting replay test")
|
||||
return
|
||||
|
||||
input("\nPress Enter to replay from run_ocr with different config...")
|
||||
|
||||
# --- Replay with different OCR config ---
|
||||
overrides = {"ocr": {"min_confidence": 0.3, "languages": ["en", "es"]}}
|
||||
logger.info("Replaying from run_ocr with overrides: %s", overrides)
|
||||
|
||||
replay_result = replay_from(
|
||||
job_id=args.job,
|
||||
start_stage="run_ocr",
|
||||
config_overrides=overrides,
|
||||
)
|
||||
|
||||
replay_detections = replay_result.get("detections", [])
|
||||
replay_report = replay_result.get("report")
|
||||
logger.info("Replay run: %d detections, %d brands",
|
||||
len(replay_detections),
|
||||
len(replay_report.brands) if replay_report else 0)
|
||||
|
||||
# --- Compare ---
|
||||
logger.info("--- Comparison ---")
|
||||
logger.info("Initial: %d detections", len(detections))
|
||||
logger.info("Replay: %d detections (min_confidence 0.5 → 0.3)", len(replay_detections))
|
||||
|
||||
diff = len(replay_detections) - len(detections)
|
||||
if diff > 0:
|
||||
logger.info("Replay found %d more detections with lower threshold", diff)
|
||||
elif diff == 0:
|
||||
logger.info("Same count — threshold change didn't affect this video")
|
||||
else:
|
||||
logger.warning("Replay found fewer detections — unexpected")
|
||||
|
||||
logger.info("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user