major refactor
This commit is contained in:
@@ -11,7 +11,7 @@ that don't belong to any stage.
|
||||
from __future__ import annotations
|
||||
|
||||
from core.schema.serializers._common import serialize_dataclass
|
||||
from core.schema.serializers.detect_pipeline import (
|
||||
from core.schema.serializers.pipeline import (
|
||||
deserialize_pipeline_stats,
|
||||
deserialize_text_candidates,
|
||||
)
|
||||
|
||||
@@ -33,83 +33,81 @@ def create_timeline(
|
||||
|
||||
Returns (timeline_id, checkpoint_id).
|
||||
"""
|
||||
from core.db.detect import create_timeline as db_create_timeline
|
||||
from core.db.detect import save_checkpoint
|
||||
|
||||
# Create timeline
|
||||
timeline = db_create_timeline(
|
||||
source_video=source_video,
|
||||
profile_name=profile_name,
|
||||
source_asset_id=source_asset_id,
|
||||
fps=fps,
|
||||
)
|
||||
tid = str(timeline.id)
|
||||
|
||||
# Upload frames to MinIO
|
||||
manifest = save_frames(tid, frames)
|
||||
|
||||
# Store frame metadata on the timeline
|
||||
frames_meta = [
|
||||
{
|
||||
"sequence": f.sequence,
|
||||
"chunk_id": getattr(f, "chunk_id", 0),
|
||||
"timestamp": f.timestamp,
|
||||
"perceptual_hash": getattr(f, "perceptual_hash", ""),
|
||||
}
|
||||
for f in frames
|
||||
]
|
||||
|
||||
timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
|
||||
timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
|
||||
timeline.frames_meta = frames_meta
|
||||
|
||||
from core.db.tables import Timeline, Checkpoint
|
||||
from core.db.connection import get_session
|
||||
|
||||
with get_session() as session:
|
||||
timeline = Timeline(
|
||||
source_video=source_video,
|
||||
profile_name=profile_name,
|
||||
source_asset_id=source_asset_id,
|
||||
fps=fps,
|
||||
)
|
||||
session.add(timeline)
|
||||
session.flush()
|
||||
tid = str(timeline.id)
|
||||
|
||||
# Upload frames to MinIO
|
||||
manifest = save_frames(tid, frames)
|
||||
|
||||
frames_meta = [
|
||||
{
|
||||
"sequence": f.sequence,
|
||||
"chunk_id": getattr(f, "chunk_id", 0),
|
||||
"timestamp": f.timestamp,
|
||||
"perceptual_hash": getattr(f, "perceptual_hash", ""),
|
||||
}
|
||||
for f in frames
|
||||
]
|
||||
|
||||
timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
|
||||
timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
|
||||
timeline.frames_meta = frames_meta
|
||||
|
||||
checkpoint = Checkpoint(
|
||||
timeline_id=timeline.id,
|
||||
parent_id=None,
|
||||
stage_outputs={},
|
||||
stats={"frames_extracted": len(frames)},
|
||||
)
|
||||
session.add(checkpoint)
|
||||
session.commit()
|
||||
session.refresh(checkpoint)
|
||||
cid = str(checkpoint.id)
|
||||
|
||||
# Create root checkpoint (no parent, no stage outputs yet)
|
||||
checkpoint = save_checkpoint(
|
||||
timeline_id=timeline.id,
|
||||
parent_id=None,
|
||||
stage_outputs={},
|
||||
stats={"frames_extracted": len(frames)},
|
||||
)
|
||||
|
||||
logger.info("Timeline created: %s (%d frames, root checkpoint %s)",
|
||||
tid, len(frames), checkpoint.id)
|
||||
return tid, str(checkpoint.id)
|
||||
logger.info("Timeline created: %s (%d frames, root checkpoint %s)", tid, len(frames), cid)
|
||||
return tid, cid
|
||||
|
||||
|
||||
def get_timeline_frames(timeline_id: str) -> list:
|
||||
"""Load frames from a timeline (from MinIO) as Frame objects."""
|
||||
from core.db.detect import get_timeline
|
||||
from core.db.tables import Timeline
|
||||
from core.db.connection import get_session
|
||||
|
||||
timeline = get_timeline(timeline_id)
|
||||
with get_session() as session:
|
||||
timeline = session.get(Timeline, UUID(timeline_id))
|
||||
if not timeline:
|
||||
raise ValueError(f"Timeline not found: {timeline_id}")
|
||||
|
||||
raw_manifest = timeline.frames_manifest or {}
|
||||
manifest = {int(k): v for k, v in raw_manifest.items()}
|
||||
frame_metadata = timeline.frames_meta or []
|
||||
|
||||
return load_frames(manifest, frame_metadata)
|
||||
return load_frames(manifest, timeline.frames_meta or [])
|
||||
|
||||
|
||||
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
|
||||
"""Load frames as base64 JPEG (lightweight, no numpy)."""
|
||||
from core.db.detect import get_timeline
|
||||
from core.db.tables import Timeline
|
||||
from core.db.connection import get_session
|
||||
from .frames import load_frames_b64
|
||||
|
||||
timeline = get_timeline(timeline_id)
|
||||
with get_session() as session:
|
||||
timeline = session.get(Timeline, UUID(timeline_id))
|
||||
if not timeline:
|
||||
raise ValueError(f"Timeline not found: {timeline_id}")
|
||||
|
||||
raw_manifest = timeline.frames_manifest or {}
|
||||
manifest = {int(k): v for k, v in raw_manifest.items()}
|
||||
frame_metadata = timeline.frames_meta or []
|
||||
|
||||
return load_frames_b64(manifest, frame_metadata)
|
||||
return load_frames_b64(manifest, timeline.frames_meta or [])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -132,47 +130,46 @@ def save_stage_output(
|
||||
Carries forward stage outputs from parent + adds the new one.
|
||||
Returns the new checkpoint ID.
|
||||
"""
|
||||
from core.db.detect import get_checkpoint, save_checkpoint
|
||||
from core.db.tables import Checkpoint
|
||||
from core.db.connection import get_session
|
||||
|
||||
# Carry forward from parent
|
||||
parent_outputs = {}
|
||||
parent_stats = {}
|
||||
parent_config = {}
|
||||
if parent_checkpoint_id:
|
||||
parent = get_checkpoint(parent_checkpoint_id)
|
||||
if parent:
|
||||
parent_outputs = dict(parent.stage_outputs or {})
|
||||
parent_stats = dict(parent.stats or {})
|
||||
parent_config = dict(parent.config_overrides or {})
|
||||
with get_session() as session:
|
||||
parent_outputs = {}
|
||||
parent_stats = {}
|
||||
parent_config = {}
|
||||
if parent_checkpoint_id:
|
||||
parent = session.get(Checkpoint, UUID(parent_checkpoint_id))
|
||||
if parent:
|
||||
parent_outputs = dict(parent.stage_outputs or {})
|
||||
parent_stats = dict(parent.stats or {})
|
||||
parent_config = dict(parent.config_overrides or {})
|
||||
|
||||
# Add new stage output
|
||||
stage_outputs = {**parent_outputs, stage_name: output_json}
|
||||
|
||||
# Merge stats and config
|
||||
merged_stats = {**parent_stats, **(stats or {})}
|
||||
merged_config = {**parent_config, **(config_overrides or {})}
|
||||
|
||||
checkpoint = save_checkpoint(
|
||||
timeline_id=timeline_id,
|
||||
parent_id=parent_checkpoint_id,
|
||||
stage_outputs=stage_outputs,
|
||||
config_overrides=merged_config,
|
||||
stats=merged_stats,
|
||||
is_scenario=is_scenario,
|
||||
scenario_label=scenario_label,
|
||||
)
|
||||
checkpoint = Checkpoint(
|
||||
timeline_id=UUID(timeline_id),
|
||||
parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
|
||||
stage_outputs={**parent_outputs, stage_name: output_json},
|
||||
config_overrides={**parent_config, **(config_overrides or {})},
|
||||
stats={**parent_stats, **(stats or {})},
|
||||
is_scenario=is_scenario,
|
||||
scenario_label=scenario_label,
|
||||
)
|
||||
session.add(checkpoint)
|
||||
session.commit()
|
||||
session.refresh(checkpoint)
|
||||
cid = str(checkpoint.id)
|
||||
|
||||
logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
|
||||
checkpoint.id, timeline_id, stage_name, parent_checkpoint_id)
|
||||
return str(checkpoint.id)
|
||||
cid, timeline_id, stage_name, parent_checkpoint_id)
|
||||
return cid
|
||||
|
||||
|
||||
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
|
||||
"""Load a stage's output from a checkpoint."""
|
||||
from core.db.detect import get_checkpoint
|
||||
from core.db.tables import Checkpoint
|
||||
from core.db.connection import get_session
|
||||
|
||||
checkpoint = get_checkpoint(checkpoint_id)
|
||||
with get_session() as session:
|
||||
checkpoint = session.get(Checkpoint, UUID(checkpoint_id))
|
||||
if not checkpoint:
|
||||
return None
|
||||
|
||||
return (checkpoint.stage_outputs or {}).get(stage_name)
|
||||
|
||||
@@ -326,6 +326,7 @@ def node_compile_report(state: DetectState) -> dict:
|
||||
|
||||
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
|
||||
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
|
||||
_latest_checkpoint: dict[str, str] = {} # job_id → latest checkpoint_id
|
||||
|
||||
|
||||
class PipelineCancelled(Exception):
|
||||
@@ -361,17 +362,33 @@ def _checkpointing_node(node_name: str, node_fn):
|
||||
if not job_id:
|
||||
return result
|
||||
|
||||
from detect.checkpoint import save_checkpoint, save_frames
|
||||
from detect.checkpoint import save_stage_output, save_frames
|
||||
from detect.stages.base import _REGISTRY
|
||||
|
||||
merged = {**state, **result}
|
||||
|
||||
# Save frames once (first checkpoint), reuse manifest after
|
||||
# Save frames once (first node), reuse manifest after
|
||||
manifest = _frames_manifest.get(job_id)
|
||||
if manifest is None and node_name == "extract_frames":
|
||||
manifest = save_frames(job_id, merged.get("frames", []))
|
||||
_frames_manifest[job_id] = manifest
|
||||
|
||||
save_checkpoint(job_id, node_name, stage_index, merged, frames_manifest=manifest)
|
||||
# Serialize stage output using the stage's serialize_fn if available
|
||||
stage_cls = _REGISTRY.get(node_name)
|
||||
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
|
||||
if serialize_fn:
|
||||
output_json = serialize_fn(merged, job_id)
|
||||
else:
|
||||
output_json = {}
|
||||
|
||||
parent_id = _latest_checkpoint.get(job_id)
|
||||
new_checkpoint_id = save_stage_output(
|
||||
timeline_id=job_id,
|
||||
parent_checkpoint_id=parent_id,
|
||||
stage_name=node_name,
|
||||
output_json=output_json,
|
||||
)
|
||||
_latest_checkpoint[job_id] = new_checkpoint_id
|
||||
return result
|
||||
|
||||
wrapper.__name__ = node_fn.__name__
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
"""
|
||||
Re-export pipeline runtime models from core/schema/models/detect_pipeline.py.
|
||||
"""Re-export pipeline runtime models from core/schema/models/pipeline.py."""
|
||||
|
||||
All models are defined in core/schema/ — this module exists for backward
|
||||
compatibility so existing imports (from detect.models import Frame) keep working.
|
||||
"""
|
||||
|
||||
from core.schema.models.detect_pipeline import (
|
||||
from core.schema.models.pipeline import (
|
||||
BoundingBox,
|
||||
BrandDetection,
|
||||
BrandStats,
|
||||
|
||||
@@ -4,14 +4,10 @@ Stage 5 — Brand Resolver (discovery mode)
|
||||
Discovery-first brand matching. No static dictionary — all brands live in the DB.
|
||||
|
||||
Flow:
|
||||
1. Check session sightings first (brands already seen in this source)
|
||||
1. Check session brands first (brands already seen in this run, in-memory)
|
||||
2. Check global known brands (accumulated across all runs)
|
||||
3. Unresolved candidates → escalate to VLM/cloud
|
||||
4. Confirmed brands get added to DB for future runs
|
||||
|
||||
The resolver is an enricher, not a gatekeeper. Every OCR text candidate
|
||||
passes through — the question is whether we can resolve it cheaply (DB lookup)
|
||||
or need to escalate (VLM/cloud).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -33,41 +29,30 @@ def _normalize(text: str) -> str:
|
||||
|
||||
def _has_db() -> bool:
|
||||
try:
|
||||
from core.db.detect import find_brand_by_text as _
|
||||
from admin.mpr.media_assets.models import KnownBrand as _
|
||||
from core.db import find_brand_by_text as _
|
||||
return True
|
||||
except (ImportError, Exception):
|
||||
return False
|
||||
|
||||
|
||||
def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
|
||||
"""
|
||||
Check against session brands (already seen in this source).
|
||||
|
||||
session_brands: {normalized_name: canonical_name, ...}
|
||||
Includes aliases.
|
||||
"""
|
||||
normalized = _normalize(text)
|
||||
return session_brands.get(normalized)
|
||||
return session_brands.get(_normalize(text))
|
||||
|
||||
|
||||
def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
Check against global known brands in DB.
|
||||
|
||||
Returns (canonical_name, brand_id) or (None, None).
|
||||
"""
|
||||
"""Check against global known brands in DB. Returns (canonical_name, brand_id) or (None, None)."""
|
||||
if not _has_db():
|
||||
return None, None
|
||||
|
||||
from core.db.detect import find_brand_by_text
|
||||
brand = find_brand_by_text(text)
|
||||
if brand:
|
||||
return brand.canonical_name, str(brand.id)
|
||||
from core.db import find_brand_by_text, list_brands
|
||||
from core.db.connection import get_session
|
||||
|
||||
# Fuzzy match against all known brands
|
||||
from core.db.detect import list_all_brands
|
||||
all_brands = list_all_brands()
|
||||
with get_session() as session:
|
||||
brand = find_brand_by_text(session, text)
|
||||
if brand:
|
||||
return brand.canonical_name, str(brand.id)
|
||||
|
||||
all_brands = list_brands(session)
|
||||
|
||||
normalized = _normalize(text)
|
||||
best_brand = None
|
||||
@@ -92,58 +77,62 @@ def _register_brand(canonical_name: str, source: str) -> str | None:
|
||||
if not _has_db():
|
||||
return None
|
||||
|
||||
from core.db.detect import get_or_create_brand
|
||||
brand, created = get_or_create_brand(canonical_name, source=source)
|
||||
from core.db import get_or_create_brand
|
||||
from core.db.connection import get_session
|
||||
|
||||
with get_session() as session:
|
||||
brand, created = get_or_create_brand(session, canonical_name, source=source)
|
||||
session.commit()
|
||||
if created:
|
||||
logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
|
||||
return str(brand.id)
|
||||
|
||||
|
||||
def _record_sighting(source_asset_id: str | None, brand_id: str,
|
||||
brand_name: str, timestamp: float,
|
||||
confidence: float, source: str):
|
||||
"""Record a brand sighting for this source."""
|
||||
if not _has_db() or not source_asset_id:
|
||||
def _record_airing(timeline_id: str | None, brand_id: str,
|
||||
frame_seq: int, confidence: float, source: str):
|
||||
"""Record a brand airing on a timeline."""
|
||||
if not _has_db() or not timeline_id:
|
||||
return
|
||||
|
||||
from core.db.detect import record_sighting
|
||||
import uuid
|
||||
asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id
|
||||
brand_uuid = uuid.UUID(brand_id) if isinstance(brand_id, str) else brand_id
|
||||
record_sighting(asset_id, brand_uuid, brand_name, timestamp, confidence, source)
|
||||
from core.db import record_airing
|
||||
from core.db.connection import get_session
|
||||
from uuid import UUID
|
||||
|
||||
with get_session() as session:
|
||||
record_airing(
|
||||
session,
|
||||
brand_id=UUID(brand_id),
|
||||
timeline_id=UUID(timeline_id),
|
||||
frame_start=frame_seq,
|
||||
frame_end=frame_seq,
|
||||
confidence=confidence,
|
||||
source=source,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
||||
def build_session_dict(source_asset_id: str | None) -> dict[str, str]:
|
||||
def build_session_dict(source_asset_id: str | None = None) -> dict[str, str]:
|
||||
"""
|
||||
Load session brands from DB for this source.
|
||||
Load known brands from DB as a session lookup dict.
|
||||
|
||||
Returns {normalized_name: canonical_name, ...} including aliases.
|
||||
"""
|
||||
if not _has_db() or not source_asset_id:
|
||||
if not _has_db():
|
||||
return {}
|
||||
|
||||
from core.db.detect import get_source_sightings
|
||||
import uuid
|
||||
from core.db import list_brands
|
||||
from core.db.connection import get_session
|
||||
|
||||
asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id
|
||||
sightings = get_source_sightings(asset_id)
|
||||
with get_session() as session:
|
||||
all_brands = list_brands(session)
|
||||
|
||||
session = {}
|
||||
for s in sightings:
|
||||
canonical = s.brand_name
|
||||
session[_normalize(canonical)] = canonical
|
||||
session_dict = {}
|
||||
for brand in all_brands:
|
||||
session_dict[_normalize(brand.canonical_name)] = brand.canonical_name
|
||||
for alias in (brand.aliases or []):
|
||||
session_dict[_normalize(alias)] = brand.canonical_name
|
||||
|
||||
# Also load aliases from KnownBrand for each sighted brand
|
||||
if _has_db():
|
||||
from core.db.detect import list_all_brands
|
||||
all_brands = list_all_brands()
|
||||
sighted_names = {s.brand_name for s in sightings}
|
||||
for brand in all_brands:
|
||||
if brand.canonical_name in sighted_names:
|
||||
for alias in (brand.aliases or []):
|
||||
session[_normalize(alias)] = brand.canonical_name
|
||||
|
||||
return session
|
||||
return session_dict
|
||||
|
||||
|
||||
def resolve_brands(
|
||||
@@ -158,7 +147,7 @@ def resolve_brands(
|
||||
Match text candidates against known brands (session → global → unresolved).
|
||||
|
||||
session_brands: pre-loaded session dict (from build_session_dict)
|
||||
source_asset_id: for recording new sightings in DB
|
||||
job_id: timeline_id — used to record airings
|
||||
"""
|
||||
if session_brands is None:
|
||||
session_brands = {}
|
||||
@@ -187,7 +176,6 @@ def resolve_brands(
|
||||
brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
|
||||
if brand_name:
|
||||
known_hits += 1
|
||||
# Add to session for subsequent candidates in this run
|
||||
session_brands[_normalize(brand_name)] = brand_name
|
||||
|
||||
if brand_name:
|
||||
@@ -203,11 +191,10 @@ def resolve_brands(
|
||||
)
|
||||
matched.append(detection)
|
||||
|
||||
# Record sighting in DB
|
||||
if brand_id:
|
||||
_record_sighting(
|
||||
source_asset_id, brand_id, brand_name,
|
||||
candidate.frame.timestamp, candidate.ocr_confidence, match_source,
|
||||
_record_airing(
|
||||
job_id, brand_id,
|
||||
candidate.frame.sequence, candidate.ocr_confidence, match_source,
|
||||
)
|
||||
|
||||
emit.detection(
|
||||
|
||||
@@ -10,7 +10,7 @@ from core.schema.serializers._common import (
|
||||
serialize_dataclass,
|
||||
serialize_dataclass_list,
|
||||
)
|
||||
from core.schema.serializers.detect_pipeline import (
|
||||
from core.schema.serializers.pipeline import (
|
||||
serialize_frame_meta,
|
||||
serialize_frames_with_upload as serialize_frames,
|
||||
deserialize_frames_with_download as deserialize_frames,
|
||||
|
||||
Reference in New Issue
Block a user