compare view

This commit is contained in:
2026-03-30 13:05:28 -03:00
parent aac27b8504
commit 55e83e4203
23 changed files with 1321 additions and 201 deletions

View File

@@ -30,12 +30,21 @@ class ConfigUpdate(BaseModel):
preprocessing: dict | None = None
class StageOutputHintInfo(BaseModel):
    """API mirror of the stage registry's StageOutputHint — tells the UI how to
    render one stage output (boxes vs. overlay, opacity, image format)."""
    key: str                      # key in the stage output JSON
    type: str                     # render type, e.g. "boxes_by_frame" or "overlay"
    label: str = ""               # human-readable name shown in the viewer
    default_opacity: float = 0.5  # presumably initial layer opacity — confirm against viewer
    src_format: str = "png"       # image format for overlay payloads
class StageConfigInfo(BaseModel):
    """API view of a pipeline stage's metadata: identity, tunable config
    fields, render hints, and the state keys it reads/writes."""
    name: str
    label: str
    description: str
    category: str
    config_fields: list[dict]                       # serialized StageConfigField entries
    output_hints: list[StageOutputHintInfo] = []    # render hints; empty for stages without visual output
    reads: list[str]                                # state keys consumed by the stage
    writes: list[str]                               # state keys produced by the stage
@@ -121,6 +130,13 @@ def _stage_to_info(stage) -> StageConfigInfo:
}
for f in stage.config_fields
],
output_hints=[
StageOutputHintInfo(
key=h.key, type=h.type, label=h.label,
default_opacity=h.default_opacity, src_format=h.src_format,
)
for h in getattr(stage, "output_hints", [])
],
reads=stage.io.reads,
writes=stage.io.writes,
)

View File

@@ -41,34 +41,24 @@ class ScenarioInfo(BaseModel):
class ReplayRequest(BaseModel):
    """Request body for POST /replay — re-run the pipeline from a stage."""
    timeline_id: str
    job_id: str                         # original job whose stage outputs seed the replay
    start_stage: str                    # pipeline stage to resume from
    config_overrides: dict | None = None  # per-stage config overrides, merged over the profile
class ReplayResponse(BaseModel):
    """Response for POST /replay."""
    status: str
    timeline_id: str
    job_id: str
    replay_job_id: str      # id of the newly created replay job
    start_stage: str
    detections: int = 0     # number of detections produced by the replay run
    brands_found: int = 0
class RetryRequest(BaseModel):
    """Request body for POST /retry — queue an async retry of unresolved candidates."""
    timeline_id: str
    config_overrides: dict | None = None
    start_stage: str = "escalate_vlm"
    schedule_seconds: float | None = None  # delay before execution (off-peak)
class RetryResponse(BaseModel):
    """Response for POST /retry — the queued Celery task handle."""
    status: str
    task_id: str       # Celery task id, usable for status polling
    timeline_id: str
class ReplaySingleStageRequest(BaseModel):
    """Request body for POST /replay-stage — run exactly one stage for tuning."""
    timeline_id: str
    job_id: str
    stage: str                            # the single stage to execute
    frame_refs: list[int] | None = None   # optional subset of frame sequence numbers
    config_overrides: dict | None = None
@@ -103,16 +93,24 @@ class ReplaySingleStageResponse(BaseModel):
# --- Endpoints ---
@router.get("/checkpoints/{timeline_id}")
def list_checkpoints(timeline_id: str) -> list[CheckpointInfo]:
"""List available checkpoint stages for a job."""
from core.detect.checkpoint import list_checkpoints as _list
def list_checkpoints_endpoint(timeline_id: str) -> list[CheckpointInfo]:
"""List available checkpoint stages for a timeline."""
from core.detect.checkpoint.storage import get_checkpoints_for_timeline
try:
stages = _list(timeline_id)
checkpoints = get_checkpoints_for_timeline(timeline_id)
except Exception as e:
raise HTTPException(status_code=404, detail=f"No checkpoints for job {timeline_id}: {e}")
raise HTTPException(status_code=404, detail=f"No checkpoints for timeline {timeline_id}: {e}")
result = [CheckpointInfo(stage=s) for s in stages]
result = [
CheckpointInfo(
stage=c["stage_name"],
is_scenario=c.get("is_scenario", False),
scenario_label=c.get("scenario_label", ""),
)
for c in checkpoints
if c["stage_name"]
]
return result
@@ -211,11 +209,11 @@ def list_scenarios_endpoint():
@router.post("/replay", response_model=ReplayResponse)
def replay(req: ReplayRequest):
"""Replay pipeline from a specific stage with optional config overrides."""
from core.detect.checkpoint import replay_from
from core.detect.checkpoint.replay import replay_from
try:
result = replay_from(
timeline_id=req.timeline_id,
job_id=req.job_id,
start_stage=req.start_stage,
config_overrides=req.config_overrides,
)
@@ -230,7 +228,8 @@ def replay(req: ReplayRequest):
response = ReplayResponse(
status="completed",
timeline_id=req.timeline_id,
job_id=req.job_id,
replay_job_id=result.get("job_id", ""),
start_stage=req.start_stage,
detections=len(detections),
brands_found=brands_found,
@@ -238,29 +237,6 @@ def replay(req: ReplayRequest):
return response
@router.post("/retry", response_model=RetryResponse)
def retry(req: RetryRequest):
    """Queue an async retry of unresolved candidates with different config."""
    from core.detect.checkpoint.tasks import retry_candidates

    task_kwargs = {
        "timeline_id": req.timeline_id,
        "config_overrides": req.config_overrides,
        "start_stage": req.start_stage,
    }
    if req.schedule_seconds:
        # Off-peak scheduling: defer execution by the requested number of seconds.
        task = retry_candidates.apply_async(kwargs=task_kwargs, countdown=req.schedule_seconds)
    else:
        task = retry_candidates.delay(**task_kwargs)
    return RetryResponse(
        status="queued",
        task_id=task.id,
        timeline_id=req.timeline_id,
    )
@router.post("/replay-stage", response_model=ReplaySingleStageResponse)
def replay_single_stage(req: ReplaySingleStageRequest):
@@ -269,7 +245,7 @@ def replay_single_stage(req: ReplaySingleStageRequest):
try:
result = _replay(
timeline_id=req.timeline_id,
job_id=req.job_id,
stage=req.stage,
frame_refs=req.frame_refs,
config_overrides=req.config_overrides,
@@ -324,6 +300,151 @@ def _gpu_url() -> str:
return url.rstrip("/")
# --- Overlay cache — save/load debug overlay images ---
class SaveOverlaysRequest(BaseModel):
    """Request body for POST /overlays — cache debug overlay images for one frame."""
    timeline_id: str
    job_id: str
    stage: str
    seq: int                   # frame sequence number the overlays belong to
    overlays: dict[str, str]  # {overlay_key: base64_png}
@router.post("/overlays")
def save_overlays_endpoint(req: SaveOverlaysRequest):
    """Save debug overlay images to blob storage cache.

    Thin wrapper over checkpoint.frames.save_overlays; returns the number
    of overlay images received so the client can confirm the upload.
    """
    from core.detect.checkpoint.frames import save_overlays
    save_overlays(req.timeline_id, req.job_id, req.stage, req.seq, req.overlays)
    return {"status": "saved", "count": len(req.overlays)}
@router.get("/overlays/{timeline_id}/{job_id}/{stage}/{seq}")
def load_overlays_endpoint(timeline_id: str, job_id: str, stage: str, seq: int):
    """Load cached debug overlay images.

    Returns {"overlays": {overlay_key: base64_png}}; an empty dict when
    nothing is cached for this frame.
    """
    from core.detect.checkpoint.frames import load_overlays
    overlays = load_overlays(timeline_id, job_id, stage, seq)
    # load_overlays returns None when nothing is cached — normalize to {}.
    return {"overlays": overlays or {}}
def _generate_debug_overlays(job_id: str, stage: str, frame) -> dict[str, str] | None:
    """Generate debug overlay images for a single frame.

    Resolves the frame's job to obtain the profile/stage config, then produces
    base64-encoded overlay PNGs either via the remote inference server (when
    INFERENCE_URL is set) or via the local CV module.

    Returns a dict of {overlay_key: base64_png} on success, or None when the
    job is missing, the stage is unsupported, or generation fails.
    """
    import os
    inference_url = os.environ.get("INFERENCE_URL")
    if stage == "detect_edges":
        from core.detect.profile import get_profile, get_stage_config
        from core.detect.stages.models import RegionAnalysisConfig
        from core.db.connection import get_session
        from core.db.job import get_job
        from uuid import UUID
        with get_session() as session:
            job = get_job(session, UUID(job_id))
            if not job:
                return None
            profile = get_profile(job.profile_name)
        config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
        if inference_url:
            # Remote path — the GPU inference server renders the overlays.
            from core.detect.inference import InferenceClient
            client = InferenceClient(base_url=inference_url, job_id=job_id)
            dr = client.detect_edges_debug(
                image=frame.image,
                edge_canny_low=config.edge_canny_low,
                edge_canny_high=config.edge_canny_high,
                edge_hough_threshold=config.edge_hough_threshold,
                edge_hough_min_length=config.edge_hough_min_length,
                edge_hough_max_gap=config.edge_hough_max_gap,
                edge_pair_max_distance=config.edge_pair_max_distance,
                edge_pair_min_distance=config.edge_pair_min_distance,
            )
            return {
                "edge_overlay_b64": dr.edge_overlay_b64,
                "lines_overlay_b64": dr.lines_overlay_b64,
            }
        else:
            # Local path — load the CV edges module directly (no GPU server).
            from core.detect.stages.edge_detector import _load_cv_edges
            edges_mod = _load_cv_edges()
            dr = edges_mod.detect_edges_debug(
                frame.image,
                canny_low=config.edge_canny_low,
                canny_high=config.edge_canny_high,
                hough_threshold=config.edge_hough_threshold,
                hough_min_length=config.edge_hough_min_length,
                hough_max_gap=config.edge_hough_max_gap,
                pair_max_distance=config.edge_pair_max_distance,
                pair_min_distance=config.edge_pair_min_distance,
            )
            # Local debug function returns a plain dict, not a response object.
            return {
                "edge_overlay_b64": dr["edge_overlay_b64"],
                "lines_overlay_b64": dr["lines_overlay_b64"],
            }
    elif stage == "field_segmentation":
        from core.detect.profile import get_profile, get_stage_config
        from core.detect.stages.models import FieldSegmentationConfig
        from core.db.connection import get_session
        from core.db.job import get_job
        from uuid import UUID
        with get_session() as session:
            job = get_job(session, UUID(job_id))
            if not job:
                return None
            profile = get_profile(job.profile_name)
        config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))
        if inference_url:
            import httpx, json, base64, io
            from PIL import Image
            import numpy as np
            # Encode the frame as JPEG for the HTTP round-trip.
            # NOTE(review): assumes frame.image is a numpy array — confirm.
            buf = io.BytesIO()
            Image.fromarray(frame.image).save(buf, format="JPEG", quality=85)
            img_b64 = base64.b64encode(buf.getvalue()).decode()
            resp = httpx.post(
                f"{inference_url.rstrip('/')}/segment_field/debug",
                json={
                    "image_b64": img_b64,
                    "hue_low": config.hue_low,
                    "hue_high": config.hue_high,
                    "sat_low": config.sat_low,
                    "sat_high": config.sat_high,
                    "val_low": config.val_low,
                    "val_high": config.val_high,
                    "morph_kernel": config.morph_kernel,
                    "min_area_ratio": config.min_area_ratio,
                },
                timeout=30.0,
            )
            if resp.status_code == 200:
                data = resp.json()
                return {"mask_overlay_b64": data.get("mask_b64", "")}
            # Non-200 from the inference server — no overlay produced.
            return None
        # NOTE(review): unlike detect_edges, there is no local fallback for
        # field segmentation here — without INFERENCE_URL this returns None.
    return None
@router.get("/overlays/{timeline_id}/{job_id}/{stage}")
def list_overlay_frames_endpoint(timeline_id: str, job_id: str, stage: str):
    """List frame sequences that have cached overlays.

    Returns {"frames": [seq, ...]} with sequences in ascending order.
    """
    from core.detect.checkpoint.frames import list_overlay_frames
    seqs = list_overlay_frames(timeline_id, job_id, stage)
    return {"frames": seqs}
# --- GPU proxy — thin passthrough to inference server for interactive editor ---
@router.post("/gpu/detect_edges")
async def gpu_detect_edges(request: Request):
"""Proxy to GPU inference server — browser can't reach it directly."""

View File

@@ -147,6 +147,108 @@ def load_cached_frames_b64(timeline_id: str) -> list[dict]:
return result
# ---------------------------------------------------------------------------
# Debug overlay storage — per job/stage/frame
# ---------------------------------------------------------------------------
def _overlay_prefix(timeline_id: str, job_id: str, stage: str) -> str:
    """Blob-storage key prefix under which one job/stage's overlays live."""
    return "/".join([CACHE_PREFIX, timeline_id, "overlays", job_id, stage]) + "/"
def _overlay_key(timeline_id: str, job_id: str, stage: str, seq: int, name: str) -> str:
    """Full blob-storage key for one overlay image: '{prefix}{seq}_{name}.png'."""
    return f"{_overlay_prefix(timeline_id, job_id, stage)}{seq}_{name}.png"
def save_overlays(
    timeline_id: str,
    job_id: str,
    stage: str,
    seq: int,
    overlays: dict[str, str],
):
    """
    Save debug overlay images (base64 PNG) to blob storage.

    overlays: {overlay_key: base64_png_string}
    e.g. {"edge_overlay_b64": "iVBOR...", "lines_overlay_b64": "iVBOR..."}
    """
    from core.storage.s3 import upload_file
    import tempfile

    for overlay_name, encoded in overlays.items():
        blob_key = _overlay_key(timeline_id, job_id, stage, seq, overlay_name)
        png_bytes = base64.b64decode(encoded)
        # upload_file takes a filesystem path, so stage the bytes in a temp file.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp.write(png_bytes)
            tmp_path = tmp.name
        try:
            upload_file(tmp_path, BUCKET, blob_key)
        finally:
            # Always remove the temp file, even when the upload fails.
            os.unlink(tmp_path)
    logger.info("Saved %d overlays for timeline %s job %s stage %s frame %d",
                len(overlays), timeline_id, job_id, stage, seq)
def load_overlays(
    timeline_id: str,
    job_id: str,
    stage: str,
    seq: int,
) -> dict[str, str] | None:
    """
    Load debug overlay images from blob storage as base64 strings.

    Returns {overlay_key: base64_png_string} or None if no overlays cached.
    """
    from core.storage.s3 import list_objects, download_to_temp

    wanted_prefix = f"{seq}_"
    result: dict[str, str] = {}
    for entry in list_objects(BUCKET, _overlay_prefix(timeline_id, job_id, stage)):
        filename = entry["key"].rsplit("/", 1)[-1]
        if not filename.startswith(wanted_prefix):
            continue
        # Strip "{seq}_" and ".png" to recover the overlay key.
        overlay_name = filename[len(wanted_prefix):].replace(".png", "")
        local_path = download_to_temp(BUCKET, entry["key"])
        try:
            with open(local_path, "rb") as fh:
                result[overlay_name] = base64.b64encode(fh.read()).decode()
        finally:
            os.unlink(local_path)
    return result or None
def list_overlay_frames(
    timeline_id: str,
    job_id: str,
    stage: str,
) -> list[int]:
    """List frame sequences that have cached overlays."""
    from core.storage.s3 import list_objects

    sequences: set[int] = set()
    for entry in list_objects(BUCKET, _overlay_prefix(timeline_id, job_id, stage)):
        # Keys look like ".../{seq}_{overlay_name}.png"; the token before the
        # first underscore is the frame sequence number.
        leading = entry["key"].rsplit("/", 1)[-1].split("_")[0]
        try:
            sequences.add(int(leading))
        except ValueError:
            # Non-numeric leading token — not an overlay file we recognize.
            continue
    return sorted(sequences)
def clear_cache(timeline_id: str):
"""Delete the frame cache for a timeline."""
from core.storage.s3 import delete_objects

View File

@@ -1,32 +1,88 @@
"""
Pipeline replay — re-run from any stage with different config.
Loads a checkpoint, applies config overrides, builds a subgraph
starting from the target stage, and invokes it.
Loads stage outputs from DB, frames from timeline cache,
reconstitutes state, and runs from a target stage onward.
Creates a new Job (run_type=REPLAY) for each replay invocation.
"""
from __future__ import annotations
import logging
import os
import uuid
from core.detect import emit
# TODO: migrate to Timeline/Branch/Checkpoint model
# These old functions no longer exist — replay needs rework
def _not_migrated(*args, **kwargs):
raise NotImplementedError("Replay not yet migrated to Timeline/Branch/Checkpoint model")
load_checkpoint = _not_migrated
list_checkpoints = _not_migrated
from core.detect.graph import NODES, build_graph
from core.detect.graph import NODES, get_pipeline
from core.detect.graph.runner import PipelineRunner
logger = logging.getLogger(__name__)
def _build_state_for_replay(
job_id: str,
up_to_stage: str,
) -> dict:
"""
Reconstitute pipeline state from a completed job's stage outputs,
up to (but not including) the target stage.
# OverrideProfile removed — config overrides are now handled by dict merging
# in _load_profile() (nodes.py) and replay_single_stage (below).
Loads frames from timeline cache + stage outputs from DB.
"""
from .storage import load_stage_outputs_for_job, get_checkpoints_for_job
from .frames import load_cached_frames
from core.db.connection import get_session
from core.db.job import get_job
# Load the job to get timeline_id and profile
with get_session() as session:
job = get_job(session, uuid.UUID(job_id))
if not job:
raise ValueError(f"Job not found: {job_id}")
timeline_id = str(job.timeline_id) if job.timeline_id else ""
if not timeline_id:
raise ValueError(f"Job {job_id} has no timeline")
# Load frames from timeline cache
frames = load_cached_frames(timeline_id)
if not frames:
raise ValueError(f"No cached frames for timeline {timeline_id}. Run the pipeline first.")
# Load all stage outputs for this job
all_outputs = load_stage_outputs_for_job(job_id)
# Build state with envelope + frames
state = {
"job_id": job_id,
"timeline_id": timeline_id,
"video_path": job.video_path,
"profile_name": job.profile_name,
"source_asset_id": str(job.source_asset_id),
"frames": frames,
"config_overrides": {},
}
# Apply stage outputs in pipeline order, up to the target stage
target_idx = NODES.index(up_to_stage)
for stage_name in NODES[:target_idx]:
output = all_outputs.get(stage_name)
if output:
# Stage outputs contain serialized data — merge into state
# The stage registry's deserialize_fn can reconstitute if needed
for key, value in output.items():
state[key] = value
# Filtered frames: reconstruct from sequence list if present
filtered_seqs = state.get("filtered_frame_sequences")
if filtered_seqs:
seq_set = set(filtered_seqs)
state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
elif "filtered_frames" not in state:
state["filtered_frames"] = frames
return state
def replay_from(
@@ -38,49 +94,60 @@ def replay_from(
"""
Replay the pipeline from a specific stage.
Loads the checkpoint from the stage immediately before start_stage,
applies config overrides, and runs the subgraph from start_stage onward.
Loads state from the original job's stage outputs up to start_stage,
applies config overrides, and runs from start_stage onward.
Creates a new Job (run_type=REPLAY).
Returns the final state dict.
"""
if start_stage not in NODES:
raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")
start_idx = NODES.index(start_stage)
# Load checkpoint from the stage before start_stage
if start_idx == 0:
raise ValueError("Cannot replay from the first stage — just run the full pipeline")
previous_stage = NODES[start_idx - 1]
logger.info("Replaying job %s from %s", job_id, start_stage)
available = list_checkpoints(job_id)
if previous_stage not in available:
raise ValueError(
f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
f"Available: {available}"
)
logger.info("Replaying job %s from %s (loading checkpoint: %s)",
job_id, start_stage, previous_stage)
state = load_checkpoint(job_id, previous_stage)
state = _build_state_for_replay(job_id, start_stage)
# Apply config overrides
if config_overrides:
state["config_overrides"] = config_overrides
# Create replay job
from core.db.connection import get_session
from core.db.job import create_job, get_job
with get_session() as session:
original = get_job(session, uuid.UUID(job_id))
replay_job = create_job(
session,
source_asset_id=original.source_asset_id,
video_path=original.video_path,
timeline_id=original.timeline_id,
profile_name=original.profile_name,
run_type="replay",
parent_id=original.id,
config_overrides=config_overrides,
)
replay_job_id = str(replay_job.id)
# Update state with new job ID
state["job_id"] = replay_job_id
# Set run context for SSE events
run_id = str(uuid.uuid4())[:8]
emit.set_run_context(
run_id=run_id,
run_id=replay_job_id,
parent_job_id=job_id,
run_type="replay",
)
# Build subgraph starting from start_stage
graph = build_graph(checkpoint=checkpoint, start_from=start_stage)
pipeline = graph.compile()
# Run from start_stage onward
pipeline = get_pipeline(
checkpoint=checkpoint,
profile_name=state["profile_name"],
start_from=start_stage,
)
try:
result = pipeline.invoke(state)
@@ -102,12 +169,6 @@ def replay_single_stage(
Fast path for interactive parameter tuning — runs only the target stage
function, not the full pipeline tail. Returns the stage output directly.
When debug=True and stage is detect_edges, returns additional overlay
data (Canny edges, Hough lines) for visual feedback in the editor.
For detect_edges: returns {"edge_regions_by_frame": {seq: [box, ...]}}
With debug=True, also returns {"debug": {seq: {edge_overlay_b64, lines_overlay_b64, ...}}}
"""
if stage not in NODES:
raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}")
@@ -116,19 +177,9 @@ def replay_single_stage(
if stage_idx == 0:
raise ValueError("Cannot replay the first stage — just run the full pipeline")
previous_stage = NODES[stage_idx - 1]
logger.info("Single-stage replay: job %s, stage %s (debug=%s)", job_id, stage, debug)
available = list_checkpoints(job_id)
if previous_stage not in available:
raise ValueError(
f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
f"Available: {available}"
)
logger.info("Single-stage replay: job %s, stage %s (loading checkpoint: %s, debug=%s)",
job_id, stage, previous_stage, debug)
state = load_checkpoint(job_id, previous_stage)
state = _build_state_for_replay(job_id, stage)
# Build profile with overrides
from core.detect.profile import get_profile, get_stage_config
@@ -142,9 +193,17 @@ def replay_single_stage(
merged_configs[sname] = soverrides
profile = {**profile, "configs": merged_configs}
# Run the stage function directly (not through the graph)
# Subset frames if requested
frames = state.get("filtered_frames", state.get("frames", []))
if frame_refs:
ref_set = set(frame_refs)
frames = [f for f in frames if f.sequence in ref_set]
# Run the specific stage
if stage == "detect_edges":
return _replay_detect_edges(state, profile, frame_refs, job_id, debug)
return _replay_detect_edges(state, profile, frames, job_id, debug)
elif stage == "field_segmentation":
return _replay_field_segmentation(state, profile, frames, job_id, debug)
else:
raise ValueError(
f"Single-stage replay not yet implemented for {stage!r}. "
@@ -155,35 +214,28 @@ def replay_single_stage(
def _replay_detect_edges(
state: dict,
profile,
frame_refs: list[int] | None,
frames: list,
job_id: str,
debug: bool,
) -> dict:
"""Run edge detection on checkpoint frames, optionally with debug overlays."""
import os
from core.detect.stages.edge_detector import detect_edge_regions
from core.detect.profile import get_stage_config
from core.detect.stages.models import RegionAnalysisConfig
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
frames = state.get("filtered_frames", [])
if frame_refs:
ref_set = set(frame_refs)
frames = [f for f in frames if f.sequence in ref_set]
inference_url = os.environ.get("INFERENCE_URL")
field_masks = state.get("field_masks", {})
# Normal run — always needed for the boxes
result = detect_edge_regions(
frames=frames,
config=config,
inference_url=inference_url,
job_id=job_id,
field_masks=field_masks,
)
output = {"edge_regions_by_frame": result}
# Debug overlays — call debug endpoint (remote) or local debug function
if debug and frames:
debug_data = {}
if inference_url:
@@ -207,7 +259,6 @@ def _replay_detect_edges(
"pair_count": dr.pair_count,
}
else:
# Local mode — import GPU module directly
from core.detect.stages.edge_detector import _load_cv_edges
edges_mod = _load_cv_edges()
for frame in frames:
@@ -230,3 +281,27 @@ def _replay_detect_edges(
output["debug"] = debug_data
return output
def _replay_field_segmentation(
    state: dict,
    profile,
    frames: list,
    job_id: str,
    debug: bool,
) -> dict:
    """Run field segmentation on checkpoint frames."""
    from core.detect.profile import get_stage_config
    from core.detect.stages.field_segmentation import run_field_segmentation
    from core.detect.stages.models import FieldSegmentationConfig

    stage_config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))
    # INFERENCE_URL selects the remote GPU path; unset means local execution.
    return run_field_segmentation(
        frames=frames,
        config=stage_config,
        inference_url=os.environ.get("INFERENCE_URL"),
        job_id=job_id,
    )

View File

@@ -39,13 +39,21 @@ def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: di
return
from .storage import save_checkpoint, save_stage_output
from core.detect.stages.base import _REGISTRY
from core.detect.stages.base import _REGISTRY, _LEGACY_REGISTRY
merged = {**state, **result}
# Serialize stage output using the stage's serialize_fn if available
# Check new-style registry first, then legacy (some stages are in both)
serialize_fn = None
stage_cls = _REGISTRY.get(stage_name)
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
if stage_cls:
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
if not serialize_fn:
legacy = _LEGACY_REGISTRY.get(stage_name)
if legacy:
serialize_fn = legacy.serialize_fn
if serialize_fn:
output_json = serialize_fn(merged, job_id)
else:

View File

@@ -146,6 +146,7 @@ def node_field_segmentation(state: DetectState) -> dict:
_emit(state, "field_segmentation", "done")
return {
"field_masks": result["field_masks"],
"field_mask_overlays": result.get("field_mask_overlays", {}),
"field_boundaries": result["field_boundaries"],
"field_coverage": result["field_coverage"],
}

View File

@@ -24,7 +24,7 @@ from PIL import Image
from core.detect import emit
from core.detect.models import BoundingBox, Frame
from core.detect.stages.base import Stage
from core.detect.stages.models import StageDefinition, StageConfigField, StageIO
from core.detect.stages.models import StageDefinition, StageConfigField, StageIO, StageOutputHint
logger = logging.getLogger(__name__)
@@ -42,7 +42,6 @@ class EdgeDetectionStage(Stage):
writes=["edge_regions_by_frame"],
),
config_fields=[
StageConfigField(name="enabled", type="bool", default=True, description="Enable edge detection"),
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
@@ -51,6 +50,11 @@ class EdgeDetectionStage(Stage):
StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
],
output_hints=[
StageOutputHint(key="edge_regions_by_frame", type="boxes_by_frame", label="Edge regions"),
StageOutputHint(key="edge_overlay_b64", type="overlay", label="Canny edges", default_opacity=0.25),
StageOutputHint(key="lines_overlay_b64", type="overlay", label="Hough lines", default_opacity=0.25),
],
tracks_element="edge_region",
)

View File

@@ -24,6 +24,7 @@ from core.detect.stages.models import (
StageConfigField,
StageDefinition,
StageIO,
StageOutputHint,
)
logger = logging.getLogger(__name__)
@@ -41,7 +42,6 @@ class FieldSegmentationStage(Stage):
writes=["field_mask"],
),
config_fields=[
StageConfigField(name="enabled", type="bool", default=True, description="Enable field segmentation"),
StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180),
StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180),
StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255),
@@ -51,6 +51,9 @@ class FieldSegmentationStage(Stage):
StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51),
StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area as fraction of frame", min=0.01, max=0.5),
],
output_hints=[
StageOutputHint(key="mask_overlay_b64", type="overlay", label="Field mask", default_opacity=0.5, src_format="png"),
],
)
@@ -102,6 +105,7 @@ def run_field_segmentation(
client = InferenceClient(base_url=url, job_id=job_id or "", log_level=_run_log_level)
field_masks = {}
field_mask_overlays = {}
field_boundaries = {}
field_coverage = {}
@@ -126,6 +130,7 @@ def run_field_segmentation(
mask_b64 = resp.get("mask_b64", "")
if mask_b64:
field_masks[frame.sequence] = _decode_mask_b64(mask_b64)
field_mask_overlays[frame.sequence] = mask_b64
field_boundaries[frame.sequence] = resp.get("boundary", [])
field_coverage[frame.sequence] = resp.get("coverage", 0.0)
@@ -136,6 +141,7 @@ def run_field_segmentation(
return {
"field_masks": field_masks,
"field_mask_overlays": field_mask_overlays,
"field_boundaries": field_boundaries,
"field_coverage": field_coverage,
}

View File

@@ -27,6 +27,14 @@ class StageIO(BaseModel):
writes: List[str] = Field(default_factory=list)
optional_reads: List[str] = Field(default_factory=list)
class StageOutputHint(BaseModel):
    """How to render a stage output in the compare/editor views."""
    key: str                      # key in the stage output JSON
    type: str                     # render type, e.g. "boxes_by_frame" or "overlay"
    label: str = ""               # human-readable name shown in the viewer
    default_opacity: float = 0.5  # presumably initial layer opacity — confirm against viewer
    src_format: str = "png"       # image format for overlay payloads
class StageDefinition(BaseModel):
"""Complete metadata for a pipeline stage."""
name: str
@@ -35,6 +43,7 @@ class StageDefinition(BaseModel):
category: str = "detection"
io: StageIO
config_fields: List[StageConfigField] = Field(default_factory=list)
output_hints: List[StageOutputHint] = Field(default_factory=list)
tracks_element: Optional[str] = None
class FrameExtractionConfig(BaseModel):

View File

@@ -1,4 +1,4 @@
"""Registration for CV analysis stages: edge detection."""
"""Registration for CV analysis stages: edge detection, field segmentation."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
@@ -20,6 +20,24 @@ def _deser_regions(data: dict, job_id: str) -> dict:
return {"edge_regions_by_frame": regions}
def _ser_field_seg(state: dict, job_id: str) -> dict:
"""Serialize field segmentation — boundaries + coverage + mask overlays."""
boundaries = state.get("field_boundaries", {})
coverage = state.get("field_coverage", {})
mask_overlays = state.get("field_mask_overlays", {})
return {
"field_boundaries": {str(k): v for k, v in boundaries.items()},
"field_coverage": {str(k): v for k, v in coverage.items()},
"mask_overlays_by_frame": {str(k): v for k, v in mask_overlays.items()},
}
def _deser_field_seg(data: dict, job_id: str) -> dict:
boundaries = {int(k): v for k, v in data.get("field_boundaries", {}).items()}
coverage = {int(k): v for k, v in data.get("field_coverage", {}).items()}
return {"field_boundaries": boundaries, "field_coverage": coverage}
def register():
edge_detection = StageDefinition(
name="detect_edges",
@@ -31,7 +49,6 @@ def register():
writes=["edge_regions_by_frame"],
),
config_fields=[
StageConfigField(name="enabled", type="bool", default=True, description="Enable region analysis"),
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
@@ -42,3 +59,25 @@ def register():
],
)
register_stage(edge_detection, serialize_fn=_ser_regions, deserialize_fn=_deser_regions)
field_seg = StageDefinition(
name="field_segmentation",
label="Field Segmentation",
description="HSV green mask — detect pitch boundaries",
category="cv_analysis",
io=StageIO(
reads=["filtered_frames"],
writes=["field_mask"],
),
config_fields=[
StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180),
StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180),
StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255),
StageConfigField(name="sat_high", type="int", default=255, description="HSV saturation upper bound", min=0, max=255),
StageConfigField(name="val_low", type="int", default=30, description="HSV value lower bound", min=0, max=255),
StageConfigField(name="val_high", type="int", default=255, description="HSV value upper bound", min=0, max=255),
StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51),
StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area", min=0.01, max=0.5),
],
)
register_stage(field_seg, serialize_fn=_ser_field_seg, deserialize_fn=_deser_field_seg)

View File

@@ -60,7 +60,6 @@ def register():
io=StageIO(reads=["frames"], writes=["filtered_frames"]),
config_fields=[
StageConfigField(name="hamming_threshold", type="int", default=8, description="Hamming distance threshold", min=0, max=64),
StageConfigField(name="enabled", type="bool", default=True, description="Enable scene filtering"),
],
)
register_stage(scene_filter, serialize_fn=_ser_filter, deserialize_fn=_deser_filter)

View File

@@ -35,6 +35,16 @@ class StageIO:
optional_reads: List[str] = field(default_factory=list)
@dataclass
class StageOutputHint:
    """How to render a stage output in the compare/editor views."""
    key: str  # key in the stage output JSON, e.g. "edge_regions_by_frame"
    type: str  # "boxes_by_frame" | "overlay"
    label: str = ""  # human-readable name shown in the viewer
    default_opacity: float = 0.5  # presumably initial layer opacity — confirm against viewer
    src_format: str = "png"  # for overlays
@dataclass
class StageDefinition:
"""Complete metadata for a pipeline stage."""
@@ -44,6 +54,7 @@ class StageDefinition:
category: str = "detection"
io: StageIO = field(default_factory=StageIO)
config_fields: List[StageConfigField] = field(default_factory=list)
output_hints: List[StageOutputHint] = field(default_factory=list)
tracks_element: Optional[str] = None
@@ -139,6 +150,7 @@ class PipelineConfig:
STAGE_VIEWS = [
StageConfigField,
StageIO,
StageOutputHint,
StageDefinition,
FrameExtractionConfig,
SceneFilterConfig,