refactor stage 1
This commit is contained in:
@@ -9,7 +9,8 @@ from sqlmodel import select
|
||||
|
||||
from .connection import get_session
|
||||
from .models import (
|
||||
DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting,
|
||||
DetectJob, Timeline, Checkpoint,
|
||||
KnownBrand, SourceBrandSighting,
|
||||
)
|
||||
|
||||
|
||||
@@ -55,72 +56,86 @@ def list_detect_jobs(
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StageCheckpoint
|
||||
# Timeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def save_stage_checkpoint(**fields) -> StageCheckpoint:
|
||||
def create_timeline(**fields) -> Timeline:
|
||||
timeline = Timeline(**fields)
|
||||
with get_session() as session:
|
||||
# Upsert: replace if same job_id + stage
|
||||
job_id = fields.get("job_id")
|
||||
stage = fields.get("stage")
|
||||
if job_id and stage:
|
||||
stmt = select(StageCheckpoint).where(
|
||||
StageCheckpoint.job_id == job_id,
|
||||
StageCheckpoint.stage == stage,
|
||||
)
|
||||
existing = session.exec(stmt).first()
|
||||
if existing:
|
||||
for k, v in fields.items():
|
||||
setattr(existing, k, v)
|
||||
session.commit()
|
||||
session.refresh(existing)
|
||||
return existing
|
||||
session.add(timeline)
|
||||
session.commit()
|
||||
session.refresh(timeline)
|
||||
return timeline
|
||||
|
||||
checkpoint = StageCheckpoint(**fields)
|
||||
|
||||
def get_timeline(timeline_id: UUID) -> Timeline | None:
|
||||
with get_session() as session:
|
||||
return session.get(Timeline, timeline_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Checkpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def save_checkpoint(**fields) -> Checkpoint:
|
||||
checkpoint = Checkpoint(**fields)
|
||||
with get_session() as session:
|
||||
session.add(checkpoint)
|
||||
session.commit()
|
||||
session.refresh(checkpoint)
|
||||
return checkpoint
|
||||
|
||||
|
||||
def get_stage_checkpoint(job_id: UUID, stage: str) -> StageCheckpoint | None:
|
||||
def get_checkpoint(checkpoint_id: UUID) -> Checkpoint | None:
|
||||
with get_session() as session:
|
||||
stmt = select(StageCheckpoint).where(
|
||||
StageCheckpoint.job_id == job_id,
|
||||
StageCheckpoint.stage == stage,
|
||||
return session.get(Checkpoint, checkpoint_id)
|
||||
|
||||
|
||||
def get_latest_checkpoint(timeline_id: UUID, parent_id: UUID | None = None) -> Checkpoint | None:
|
||||
"""Get the most recent checkpoint for a timeline, optionally from a specific parent."""
|
||||
with get_session() as session:
|
||||
stmt = (
|
||||
select(Checkpoint)
|
||||
.where(Checkpoint.timeline_id == timeline_id)
|
||||
)
|
||||
if parent_id is not None:
|
||||
stmt = stmt.where(Checkpoint.parent_id == parent_id)
|
||||
stmt = stmt.order_by(Checkpoint.created_at.desc())
|
||||
return session.exec(stmt).first()
|
||||
|
||||
|
||||
def list_checkpoints(timeline_id: UUID) -> list[Checkpoint]:
|
||||
"""List all checkpoints for a timeline."""
|
||||
with get_session() as session:
|
||||
stmt = (
|
||||
select(Checkpoint)
|
||||
.where(Checkpoint.timeline_id == timeline_id)
|
||||
.order_by(Checkpoint.created_at)
|
||||
)
|
||||
return list(session.exec(stmt).all())
|
||||
|
||||
|
||||
def get_root_checkpoint(timeline_id: UUID) -> Checkpoint | None:
|
||||
"""Get the root checkpoint (no parent) for a timeline."""
|
||||
with get_session() as session:
|
||||
stmt = select(Checkpoint).where(
|
||||
Checkpoint.timeline_id == timeline_id,
|
||||
Checkpoint.parent_id == None,
|
||||
)
|
||||
return session.exec(stmt).first()
|
||||
|
||||
|
||||
def list_stage_checkpoints(job_id: UUID) -> list[str]:
|
||||
with get_session() as session:
|
||||
stmt = (
|
||||
select(StageCheckpoint.stage)
|
||||
.where(StageCheckpoint.job_id == job_id)
|
||||
.order_by(StageCheckpoint.stage_index)
|
||||
)
|
||||
return list(session.exec(stmt).all())
|
||||
|
||||
|
||||
def list_scenarios() -> list[StageCheckpoint]:
|
||||
def list_scenarios() -> list[Checkpoint]:
|
||||
"""List all checkpoints marked as scenarios."""
|
||||
with get_session() as session:
|
||||
stmt = (
|
||||
select(StageCheckpoint)
|
||||
.where(StageCheckpoint.is_scenario == True)
|
||||
.order_by(StageCheckpoint.created_at.desc())
|
||||
select(Checkpoint)
|
||||
.where(Checkpoint.is_scenario == True)
|
||||
.order_by(Checkpoint.created_at.desc())
|
||||
)
|
||||
return list(session.exec(stmt).all())
|
||||
|
||||
|
||||
def delete_stage_checkpoints(job_id: UUID) -> None:
|
||||
with get_session() as session:
|
||||
stmt = select(StageCheckpoint).where(StageCheckpoint.job_id == job_id)
|
||||
for cp in session.exec(stmt).all():
|
||||
session.delete(cp)
|
||||
session.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KnownBrand
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -181,24 +181,30 @@ class DetectJob(SQLModel, table=True):
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
class StageCheckpoint(SQLModel, table=True):
|
||||
"""A checkpoint saved after a pipeline stage completes."""
|
||||
__tablename__ = "stage_checkpoints"
|
||||
class Timeline(SQLModel, table=True):
|
||||
"""Frame sequence from a source video. Independent of stages."""
|
||||
__tablename__ = "timelines"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
job_id: UUID = Field(index=True)
|
||||
stage: str
|
||||
stage_index: int
|
||||
source_asset_id: Optional[UUID] = Field(default=None, index=True)
|
||||
source_video: str = ""
|
||||
profile_name: str = ""
|
||||
fps: float = 2.0
|
||||
frames_prefix: str = ""
|
||||
frames_manifest: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
frames_meta: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
||||
filtered_frame_sequences: List[int] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
||||
stage_output_key: str = "" # s3 key: checkpoints/{job_id}/stages/{stage}.bson
|
||||
stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
config_snapshot: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||
|
||||
class Checkpoint(SQLModel, table=True):
|
||||
"""Snapshot of pipeline state. parent_id forms a tree."""
|
||||
__tablename__ = "checkpoints"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
timeline_id: UUID = Field(index=True)
|
||||
parent_id: Optional[UUID] = Field(default=None, index=True)
|
||||
stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
video_path: str = ""
|
||||
profile_name: str = ""
|
||||
stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
is_scenario: bool = False
|
||||
scenario_label: str = ""
|
||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||
|
||||
Reference in New Issue
Block a user