72 lines
1.9 KiB
Python
72 lines
1.9 KiB
Python
"""
Celery tasks for asynchronous detection-pipeline operations.

retry_candidates: re-run VLM/cloud escalation with a different config.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
|
|
from celery import shared_task
|
|
|
|
# Per-module logger, named after this module so handlers/levels can be
# configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@shared_task(bind=True, max_retries=1, default_retry_delay=30)
def retry_candidates(
    self,
    job_id: str,
    config_overrides: dict | None = None,
    start_stage: str = "escalate_vlm",
):
    """Replay a detection job from ``start_stage`` with an adjusted config.

    Intended for candidates left unresolved by an earlier run. The work is
    delegated to ``replay_from``, which loads the checkpoint written by the
    stage preceding ``start_stage``, applies ``config_overrides`` (for
    example, a different cloud provider) and continues from ``start_stage``
    onward.

    Args:
        job_id: Identifier of the detection job to replay.
        config_overrides: Optional config overrides forwarded to
            ``replay_from``; ``None`` means no overrides.
        start_stage: Name of the pipeline stage to resume from.

    Returns:
        A summary dict.

        * On success: ``status="completed"``, ``run_id``, ``job_id``,
          ``detections`` (the count of detections) and ``brands_found``.
        * Once retries are exhausted: ``status="failed"``, ``run_id``,
          ``job_id`` and ``error`` (the stringified exception).

        NOTE(review): the failure path returns a value rather than raising,
        so Celery records the task itself as successful.

    Raises:
        The retry signal produced by ``self.retry`` while retries remain
        (``max_retries=1`` with a 30 s delay, per the decorator).
    """
    # Deferred import -- presumably to avoid import cost or an import cycle
    # at worker start-up; confirm before hoisting to module level.
    from detect.checkpoint.replay import replay_from

    # Short correlation tag for this execution's log lines. It is computed
    # anew each time the task body runs, so every retry logs under a new id.
    run_id = uuid.uuid4().hex[:8]
    logger.info(
        "Retry task %s: job=%s, from=%s, overrides=%s",
        run_id, job_id, start_stage, config_overrides,
    )

    # NOTE(review): this try also covers the summary below, so an error while
    # summarising a successful replay triggers a full re-run via self.retry.
    try:
        outcome = replay_from(
            job_id=job_id,
            start_stage=start_stage,
            config_overrides=config_overrides,
        )

        detections = outcome.get("detections", [])
        report = outcome.get("report")
        # A missing/falsy report counts as zero brands.
        brands_found = 0 if not report else len(report.brands)
        detection_count = len(detections)

        logger.info(
            "Retry %s complete: %d detections, %d brands",
            run_id, detection_count, brands_found,
        )
        return dict(
            status="completed",
            run_id=run_id,
            job_id=job_id,
            detections=detection_count,
            brands_found=brands_found,
        )

    except Exception as exc:
        logger.exception("Retry %s failed: %s", run_id, exc)

        # Out of retries: report the failure as a result dict instead of
        # re-queueing the task.
        if self.request.retries >= self.max_retries:
            return dict(
                status="failed",
                run_id=run_id,
                job_id=job_id,
                error=str(exc),
            )
        raise self.retry(exc=exc)
|