compare view
This commit is contained in:
@@ -147,6 +147,108 @@ def load_cached_frames_b64(timeline_id: str) -> list[dict]:
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Debug overlay storage — per job/stage/frame
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _overlay_prefix(timeline_id: str, job_id: str, stage: str) -> str:
|
||||
return f"{CACHE_PREFIX}/{timeline_id}/overlays/{job_id}/{stage}/"
|
||||
|
||||
|
||||
def _overlay_key(timeline_id: str, job_id: str, stage: str, seq: int, name: str) -> str:
|
||||
return f"{CACHE_PREFIX}/{timeline_id}/overlays/{job_id}/{stage}/{seq}_{name}.png"
|
||||
|
||||
|
||||
def save_overlays(
|
||||
timeline_id: str,
|
||||
job_id: str,
|
||||
stage: str,
|
||||
seq: int,
|
||||
overlays: dict[str, str],
|
||||
):
|
||||
"""
|
||||
Save debug overlay images (base64 PNG) to blob storage.
|
||||
|
||||
overlays: {overlay_key: base64_png_string}
|
||||
e.g. {"edge_overlay_b64": "iVBOR...", "lines_overlay_b64": "iVBOR..."}
|
||||
"""
|
||||
from core.storage.s3 import upload_file
|
||||
import tempfile
|
||||
|
||||
for name, b64_data in overlays.items():
|
||||
key = _overlay_key(timeline_id, job_id, stage, seq, name)
|
||||
raw = base64.b64decode(b64_data)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
|
||||
tmp.write(raw)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
upload_file(tmp_path, BUCKET, key)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
logger.info("Saved %d overlays for timeline %s job %s stage %s frame %d",
|
||||
len(overlays), timeline_id, job_id, stage, seq)
|
||||
|
||||
|
||||
def load_overlays(
|
||||
timeline_id: str,
|
||||
job_id: str,
|
||||
stage: str,
|
||||
seq: int,
|
||||
) -> dict[str, str] | None:
|
||||
"""
|
||||
Load debug overlay images from blob storage as base64 strings.
|
||||
|
||||
Returns {overlay_key: base64_png_string} or None if no overlays cached.
|
||||
"""
|
||||
from core.storage.s3 import list_objects, download_to_temp
|
||||
|
||||
prefix = _overlay_prefix(timeline_id, job_id, stage)
|
||||
seq_prefix = f"{seq}_"
|
||||
objects = list_objects(BUCKET, prefix)
|
||||
|
||||
overlays = {}
|
||||
for obj in objects:
|
||||
filename = obj["key"].rsplit("/", 1)[-1]
|
||||
if not filename.startswith(seq_prefix):
|
||||
continue
|
||||
name = filename[len(seq_prefix):].replace(".png", "")
|
||||
|
||||
tmp_path = download_to_temp(BUCKET, obj["key"])
|
||||
try:
|
||||
with open(tmp_path, "rb") as f:
|
||||
overlays[name] = base64.b64encode(f.read()).decode()
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
return overlays if overlays else None
|
||||
|
||||
|
||||
def list_overlay_frames(
|
||||
timeline_id: str,
|
||||
job_id: str,
|
||||
stage: str,
|
||||
) -> list[int]:
|
||||
"""List frame sequences that have cached overlays."""
|
||||
from core.storage.s3 import list_objects
|
||||
|
||||
prefix = _overlay_prefix(timeline_id, job_id, stage)
|
||||
objects = list_objects(BUCKET, prefix)
|
||||
|
||||
seqs = set()
|
||||
for obj in objects:
|
||||
filename = obj["key"].rsplit("/", 1)[-1]
|
||||
seq_str = filename.split("_")[0]
|
||||
try:
|
||||
seqs.add(int(seq_str))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return sorted(seqs)
|
||||
|
||||
|
||||
def clear_cache(timeline_id: str):
|
||||
"""Delete the frame cache for a timeline."""
|
||||
from core.storage.s3 import delete_objects
|
||||
|
||||
@@ -1,32 +1,88 @@
|
||||
"""
|
||||
Pipeline replay — re-run from any stage with different config.
|
||||
|
||||
Loads a checkpoint, applies config overrides, builds a subgraph
|
||||
starting from the target stage, and invokes it.
|
||||
Loads stage outputs from DB, frames from timeline cache,
|
||||
reconstitutes state, and runs from a target stage onward.
|
||||
|
||||
Creates a new Job (run_type=REPLAY) for each replay invocation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from core.detect import emit
|
||||
# TODO: migrate to Timeline/Branch/Checkpoint model
|
||||
# These old functions no longer exist — replay needs rework
|
||||
def _not_migrated(*args, **kwargs):
|
||||
raise NotImplementedError("Replay not yet migrated to Timeline/Branch/Checkpoint model")
|
||||
|
||||
load_checkpoint = _not_migrated
|
||||
list_checkpoints = _not_migrated
|
||||
from core.detect.graph import NODES, build_graph
|
||||
from core.detect.graph import NODES, get_pipeline
|
||||
from core.detect.graph.runner import PipelineRunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_state_for_replay(
|
||||
job_id: str,
|
||||
up_to_stage: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Reconstitute pipeline state from a completed job's stage outputs,
|
||||
up to (but not including) the target stage.
|
||||
|
||||
# OverrideProfile removed — config overrides are now handled by dict merging
|
||||
# in _load_profile() (nodes.py) and replay_single_stage (below).
|
||||
Loads frames from timeline cache + stage outputs from DB.
|
||||
"""
|
||||
from .storage import load_stage_outputs_for_job, get_checkpoints_for_job
|
||||
from .frames import load_cached_frames
|
||||
from core.db.connection import get_session
|
||||
from core.db.job import get_job
|
||||
|
||||
# Load the job to get timeline_id and profile
|
||||
with get_session() as session:
|
||||
job = get_job(session, uuid.UUID(job_id))
|
||||
if not job:
|
||||
raise ValueError(f"Job not found: {job_id}")
|
||||
|
||||
timeline_id = str(job.timeline_id) if job.timeline_id else ""
|
||||
if not timeline_id:
|
||||
raise ValueError(f"Job {job_id} has no timeline")
|
||||
|
||||
# Load frames from timeline cache
|
||||
frames = load_cached_frames(timeline_id)
|
||||
if not frames:
|
||||
raise ValueError(f"No cached frames for timeline {timeline_id}. Run the pipeline first.")
|
||||
|
||||
# Load all stage outputs for this job
|
||||
all_outputs = load_stage_outputs_for_job(job_id)
|
||||
|
||||
# Build state with envelope + frames
|
||||
state = {
|
||||
"job_id": job_id,
|
||||
"timeline_id": timeline_id,
|
||||
"video_path": job.video_path,
|
||||
"profile_name": job.profile_name,
|
||||
"source_asset_id": str(job.source_asset_id),
|
||||
"frames": frames,
|
||||
"config_overrides": {},
|
||||
}
|
||||
|
||||
# Apply stage outputs in pipeline order, up to the target stage
|
||||
target_idx = NODES.index(up_to_stage)
|
||||
for stage_name in NODES[:target_idx]:
|
||||
output = all_outputs.get(stage_name)
|
||||
if output:
|
||||
# Stage outputs contain serialized data — merge into state
|
||||
# The stage registry's deserialize_fn can reconstitute if needed
|
||||
for key, value in output.items():
|
||||
state[key] = value
|
||||
|
||||
# Filtered frames: reconstruct from sequence list if present
|
||||
filtered_seqs = state.get("filtered_frame_sequences")
|
||||
if filtered_seqs:
|
||||
seq_set = set(filtered_seqs)
|
||||
state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
|
||||
elif "filtered_frames" not in state:
|
||||
state["filtered_frames"] = frames
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def replay_from(
|
||||
@@ -38,49 +94,60 @@ def replay_from(
|
||||
"""
|
||||
Replay the pipeline from a specific stage.
|
||||
|
||||
Loads the checkpoint from the stage immediately before start_stage,
|
||||
applies config overrides, and runs the subgraph from start_stage onward.
|
||||
Loads state from the original job's stage outputs up to start_stage,
|
||||
applies config overrides, and runs from start_stage onward.
|
||||
|
||||
Creates a new Job (run_type=REPLAY).
|
||||
Returns the final state dict.
|
||||
"""
|
||||
if start_stage not in NODES:
|
||||
raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")
|
||||
|
||||
start_idx = NODES.index(start_stage)
|
||||
|
||||
# Load checkpoint from the stage before start_stage
|
||||
if start_idx == 0:
|
||||
raise ValueError("Cannot replay from the first stage — just run the full pipeline")
|
||||
|
||||
previous_stage = NODES[start_idx - 1]
|
||||
logger.info("Replaying job %s from %s", job_id, start_stage)
|
||||
|
||||
available = list_checkpoints(job_id)
|
||||
if previous_stage not in available:
|
||||
raise ValueError(
|
||||
f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
|
||||
f"Available: {available}"
|
||||
)
|
||||
|
||||
logger.info("Replaying job %s from %s (loading checkpoint: %s)",
|
||||
job_id, start_stage, previous_stage)
|
||||
|
||||
state = load_checkpoint(job_id, previous_stage)
|
||||
state = _build_state_for_replay(job_id, start_stage)
|
||||
|
||||
# Apply config overrides
|
||||
if config_overrides:
|
||||
state["config_overrides"] = config_overrides
|
||||
|
||||
# Create replay job
|
||||
from core.db.connection import get_session
|
||||
from core.db.job import create_job, get_job
|
||||
with get_session() as session:
|
||||
original = get_job(session, uuid.UUID(job_id))
|
||||
replay_job = create_job(
|
||||
session,
|
||||
source_asset_id=original.source_asset_id,
|
||||
video_path=original.video_path,
|
||||
timeline_id=original.timeline_id,
|
||||
profile_name=original.profile_name,
|
||||
run_type="replay",
|
||||
parent_id=original.id,
|
||||
config_overrides=config_overrides,
|
||||
)
|
||||
replay_job_id = str(replay_job.id)
|
||||
|
||||
# Update state with new job ID
|
||||
state["job_id"] = replay_job_id
|
||||
|
||||
# Set run context for SSE events
|
||||
run_id = str(uuid.uuid4())[:8]
|
||||
emit.set_run_context(
|
||||
run_id=run_id,
|
||||
run_id=replay_job_id,
|
||||
parent_job_id=job_id,
|
||||
run_type="replay",
|
||||
)
|
||||
|
||||
# Build subgraph starting from start_stage
|
||||
graph = build_graph(checkpoint=checkpoint, start_from=start_stage)
|
||||
pipeline = graph.compile()
|
||||
# Run from start_stage onward
|
||||
pipeline = get_pipeline(
|
||||
checkpoint=checkpoint,
|
||||
profile_name=state["profile_name"],
|
||||
start_from=start_stage,
|
||||
)
|
||||
|
||||
try:
|
||||
result = pipeline.invoke(state)
|
||||
@@ -102,12 +169,6 @@ def replay_single_stage(
|
||||
|
||||
Fast path for interactive parameter tuning — runs only the target stage
|
||||
function, not the full pipeline tail. Returns the stage output directly.
|
||||
|
||||
When debug=True and stage is detect_edges, returns additional overlay
|
||||
data (Canny edges, Hough lines) for visual feedback in the editor.
|
||||
|
||||
For detect_edges: returns {"edge_regions_by_frame": {seq: [box, ...]}}
|
||||
With debug=True, also returns {"debug": {seq: {edge_overlay_b64, lines_overlay_b64, ...}}}
|
||||
"""
|
||||
if stage not in NODES:
|
||||
raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}")
|
||||
@@ -116,19 +177,9 @@ def replay_single_stage(
|
||||
if stage_idx == 0:
|
||||
raise ValueError("Cannot replay the first stage — just run the full pipeline")
|
||||
|
||||
previous_stage = NODES[stage_idx - 1]
|
||||
logger.info("Single-stage replay: job %s, stage %s (debug=%s)", job_id, stage, debug)
|
||||
|
||||
available = list_checkpoints(job_id)
|
||||
if previous_stage not in available:
|
||||
raise ValueError(
|
||||
f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
|
||||
f"Available: {available}"
|
||||
)
|
||||
|
||||
logger.info("Single-stage replay: job %s, stage %s (loading checkpoint: %s, debug=%s)",
|
||||
job_id, stage, previous_stage, debug)
|
||||
|
||||
state = load_checkpoint(job_id, previous_stage)
|
||||
state = _build_state_for_replay(job_id, stage)
|
||||
|
||||
# Build profile with overrides
|
||||
from core.detect.profile import get_profile, get_stage_config
|
||||
@@ -142,9 +193,17 @@ def replay_single_stage(
|
||||
merged_configs[sname] = soverrides
|
||||
profile = {**profile, "configs": merged_configs}
|
||||
|
||||
# Run the stage function directly (not through the graph)
|
||||
# Subset frames if requested
|
||||
frames = state.get("filtered_frames", state.get("frames", []))
|
||||
if frame_refs:
|
||||
ref_set = set(frame_refs)
|
||||
frames = [f for f in frames if f.sequence in ref_set]
|
||||
|
||||
# Run the specific stage
|
||||
if stage == "detect_edges":
|
||||
return _replay_detect_edges(state, profile, frame_refs, job_id, debug)
|
||||
return _replay_detect_edges(state, profile, frames, job_id, debug)
|
||||
elif stage == "field_segmentation":
|
||||
return _replay_field_segmentation(state, profile, frames, job_id, debug)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Single-stage replay not yet implemented for {stage!r}. "
|
||||
@@ -155,35 +214,28 @@ def replay_single_stage(
|
||||
def _replay_detect_edges(
|
||||
state: dict,
|
||||
profile,
|
||||
frame_refs: list[int] | None,
|
||||
frames: list,
|
||||
job_id: str,
|
||||
debug: bool,
|
||||
) -> dict:
|
||||
"""Run edge detection on checkpoint frames, optionally with debug overlays."""
|
||||
import os
|
||||
from core.detect.stages.edge_detector import detect_edge_regions
|
||||
|
||||
from core.detect.profile import get_stage_config
|
||||
from core.detect.stages.models import RegionAnalysisConfig
|
||||
|
||||
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
|
||||
frames = state.get("filtered_frames", [])
|
||||
|
||||
if frame_refs:
|
||||
ref_set = set(frame_refs)
|
||||
frames = [f for f in frames if f.sequence in ref_set]
|
||||
|
||||
inference_url = os.environ.get("INFERENCE_URL")
|
||||
field_masks = state.get("field_masks", {})
|
||||
|
||||
# Normal run — always needed for the boxes
|
||||
result = detect_edge_regions(
|
||||
frames=frames,
|
||||
config=config,
|
||||
inference_url=inference_url,
|
||||
job_id=job_id,
|
||||
field_masks=field_masks,
|
||||
)
|
||||
output = {"edge_regions_by_frame": result}
|
||||
|
||||
# Debug overlays — call debug endpoint (remote) or local debug function
|
||||
if debug and frames:
|
||||
debug_data = {}
|
||||
if inference_url:
|
||||
@@ -207,7 +259,6 @@ def _replay_detect_edges(
|
||||
"pair_count": dr.pair_count,
|
||||
}
|
||||
else:
|
||||
# Local mode — import GPU module directly
|
||||
from core.detect.stages.edge_detector import _load_cv_edges
|
||||
edges_mod = _load_cv_edges()
|
||||
for frame in frames:
|
||||
@@ -230,3 +281,27 @@ def _replay_detect_edges(
|
||||
output["debug"] = debug_data
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _replay_field_segmentation(
|
||||
state: dict,
|
||||
profile,
|
||||
frames: list,
|
||||
job_id: str,
|
||||
debug: bool,
|
||||
) -> dict:
|
||||
"""Run field segmentation on checkpoint frames."""
|
||||
from core.detect.stages.field_segmentation import run_field_segmentation
|
||||
from core.detect.profile import get_stage_config
|
||||
from core.detect.stages.models import FieldSegmentationConfig
|
||||
|
||||
config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))
|
||||
inference_url = os.environ.get("INFERENCE_URL")
|
||||
|
||||
result = run_field_segmentation(
|
||||
frames=frames,
|
||||
config=config,
|
||||
inference_url=inference_url,
|
||||
job_id=job_id,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -39,13 +39,21 @@ def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: di
|
||||
return
|
||||
|
||||
from .storage import save_checkpoint, save_stage_output
|
||||
from core.detect.stages.base import _REGISTRY
|
||||
from core.detect.stages.base import _REGISTRY, _LEGACY_REGISTRY
|
||||
|
||||
merged = {**state, **result}
|
||||
|
||||
# Serialize stage output using the stage's serialize_fn if available
|
||||
# Check new-style registry first, then legacy (some stages are in both)
|
||||
serialize_fn = None
|
||||
stage_cls = _REGISTRY.get(stage_name)
|
||||
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
|
||||
if stage_cls:
|
||||
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
|
||||
if not serialize_fn:
|
||||
legacy = _LEGACY_REGISTRY.get(stage_name)
|
||||
if legacy:
|
||||
serialize_fn = legacy.serialize_fn
|
||||
|
||||
if serialize_fn:
|
||||
output_json = serialize_fn(merged, job_id)
|
||||
else:
|
||||
|
||||
@@ -146,6 +146,7 @@ def node_field_segmentation(state: DetectState) -> dict:
|
||||
_emit(state, "field_segmentation", "done")
|
||||
return {
|
||||
"field_masks": result["field_masks"],
|
||||
"field_mask_overlays": result.get("field_mask_overlays", {}),
|
||||
"field_boundaries": result["field_boundaries"],
|
||||
"field_coverage": result["field_coverage"],
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ from PIL import Image
|
||||
from core.detect import emit
|
||||
from core.detect.models import BoundingBox, Frame
|
||||
from core.detect.stages.base import Stage
|
||||
from core.detect.stages.models import StageDefinition, StageConfigField, StageIO
|
||||
from core.detect.stages.models import StageDefinition, StageConfigField, StageIO, StageOutputHint
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -42,7 +42,6 @@ class EdgeDetectionStage(Stage):
|
||||
writes=["edge_regions_by_frame"],
|
||||
),
|
||||
config_fields=[
|
||||
StageConfigField(name="enabled", type="bool", default=True, description="Enable edge detection"),
|
||||
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
|
||||
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
|
||||
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
|
||||
@@ -51,6 +50,11 @@ class EdgeDetectionStage(Stage):
|
||||
StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
|
||||
StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
|
||||
],
|
||||
output_hints=[
|
||||
StageOutputHint(key="edge_regions_by_frame", type="boxes_by_frame", label="Edge regions"),
|
||||
StageOutputHint(key="edge_overlay_b64", type="overlay", label="Canny edges", default_opacity=0.25),
|
||||
StageOutputHint(key="lines_overlay_b64", type="overlay", label="Hough lines", default_opacity=0.25),
|
||||
],
|
||||
tracks_element="edge_region",
|
||||
)
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from core.detect.stages.models import (
|
||||
StageConfigField,
|
||||
StageDefinition,
|
||||
StageIO,
|
||||
StageOutputHint,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -41,7 +42,6 @@ class FieldSegmentationStage(Stage):
|
||||
writes=["field_mask"],
|
||||
),
|
||||
config_fields=[
|
||||
StageConfigField(name="enabled", type="bool", default=True, description="Enable field segmentation"),
|
||||
StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180),
|
||||
StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180),
|
||||
StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255),
|
||||
@@ -51,6 +51,9 @@ class FieldSegmentationStage(Stage):
|
||||
StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51),
|
||||
StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area as fraction of frame", min=0.01, max=0.5),
|
||||
],
|
||||
output_hints=[
|
||||
StageOutputHint(key="mask_overlay_b64", type="overlay", label="Field mask", default_opacity=0.5, src_format="png"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -102,6 +105,7 @@ def run_field_segmentation(
|
||||
client = InferenceClient(base_url=url, job_id=job_id or "", log_level=_run_log_level)
|
||||
|
||||
field_masks = {}
|
||||
field_mask_overlays = {}
|
||||
field_boundaries = {}
|
||||
field_coverage = {}
|
||||
|
||||
@@ -126,6 +130,7 @@ def run_field_segmentation(
|
||||
mask_b64 = resp.get("mask_b64", "")
|
||||
if mask_b64:
|
||||
field_masks[frame.sequence] = _decode_mask_b64(mask_b64)
|
||||
field_mask_overlays[frame.sequence] = mask_b64
|
||||
|
||||
field_boundaries[frame.sequence] = resp.get("boundary", [])
|
||||
field_coverage[frame.sequence] = resp.get("coverage", 0.0)
|
||||
@@ -136,6 +141,7 @@ def run_field_segmentation(
|
||||
|
||||
return {
|
||||
"field_masks": field_masks,
|
||||
"field_mask_overlays": field_mask_overlays,
|
||||
"field_boundaries": field_boundaries,
|
||||
"field_coverage": field_coverage,
|
||||
}
|
||||
|
||||
@@ -27,6 +27,14 @@ class StageIO(BaseModel):
|
||||
writes: List[str] = Field(default_factory=list)
|
||||
optional_reads: List[str] = Field(default_factory=list)
|
||||
|
||||
class StageOutputHint(BaseModel):
|
||||
"""How to render a stage output in the compare/editor views."""
|
||||
key: str
|
||||
type: str
|
||||
label: str = ""
|
||||
default_opacity: float = 0.5
|
||||
src_format: str = "png"
|
||||
|
||||
class StageDefinition(BaseModel):
|
||||
"""Complete metadata for a pipeline stage."""
|
||||
name: str
|
||||
@@ -35,6 +43,7 @@ class StageDefinition(BaseModel):
|
||||
category: str = "detection"
|
||||
io: StageIO
|
||||
config_fields: List[StageConfigField] = Field(default_factory=list)
|
||||
output_hints: List[StageOutputHint] = Field(default_factory=list)
|
||||
tracks_element: Optional[str] = None
|
||||
|
||||
class FrameExtractionConfig(BaseModel):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Registration for CV analysis stages: edge detection."""
|
||||
"""Registration for CV analysis stages: edge detection, field segmentation."""
|
||||
|
||||
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||
from core.detect.stages.base import register_stage
|
||||
@@ -20,6 +20,24 @@ def _deser_regions(data: dict, job_id: str) -> dict:
|
||||
return {"edge_regions_by_frame": regions}
|
||||
|
||||
|
||||
def _ser_field_seg(state: dict, job_id: str) -> dict:
|
||||
"""Serialize field segmentation — boundaries + coverage + mask overlays."""
|
||||
boundaries = state.get("field_boundaries", {})
|
||||
coverage = state.get("field_coverage", {})
|
||||
mask_overlays = state.get("field_mask_overlays", {})
|
||||
return {
|
||||
"field_boundaries": {str(k): v for k, v in boundaries.items()},
|
||||
"field_coverage": {str(k): v for k, v in coverage.items()},
|
||||
"mask_overlays_by_frame": {str(k): v for k, v in mask_overlays.items()},
|
||||
}
|
||||
|
||||
|
||||
def _deser_field_seg(data: dict, job_id: str) -> dict:
|
||||
boundaries = {int(k): v for k, v in data.get("field_boundaries", {}).items()}
|
||||
coverage = {int(k): v for k, v in data.get("field_coverage", {}).items()}
|
||||
return {"field_boundaries": boundaries, "field_coverage": coverage}
|
||||
|
||||
|
||||
def register():
|
||||
edge_detection = StageDefinition(
|
||||
name="detect_edges",
|
||||
@@ -31,7 +49,6 @@ def register():
|
||||
writes=["edge_regions_by_frame"],
|
||||
),
|
||||
config_fields=[
|
||||
StageConfigField(name="enabled", type="bool", default=True, description="Enable region analysis"),
|
||||
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
|
||||
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
|
||||
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
|
||||
@@ -42,3 +59,25 @@ def register():
|
||||
],
|
||||
)
|
||||
register_stage(edge_detection, serialize_fn=_ser_regions, deserialize_fn=_deser_regions)
|
||||
|
||||
field_seg = StageDefinition(
|
||||
name="field_segmentation",
|
||||
label="Field Segmentation",
|
||||
description="HSV green mask — detect pitch boundaries",
|
||||
category="cv_analysis",
|
||||
io=StageIO(
|
||||
reads=["filtered_frames"],
|
||||
writes=["field_mask"],
|
||||
),
|
||||
config_fields=[
|
||||
StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180),
|
||||
StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180),
|
||||
StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255),
|
||||
StageConfigField(name="sat_high", type="int", default=255, description="HSV saturation upper bound", min=0, max=255),
|
||||
StageConfigField(name="val_low", type="int", default=30, description="HSV value lower bound", min=0, max=255),
|
||||
StageConfigField(name="val_high", type="int", default=255, description="HSV value upper bound", min=0, max=255),
|
||||
StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51),
|
||||
StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area", min=0.01, max=0.5),
|
||||
],
|
||||
)
|
||||
register_stage(field_seg, serialize_fn=_ser_field_seg, deserialize_fn=_deser_field_seg)
|
||||
|
||||
@@ -60,7 +60,6 @@ def register():
|
||||
io=StageIO(reads=["frames"], writes=["filtered_frames"]),
|
||||
config_fields=[
|
||||
StageConfigField(name="hamming_threshold", type="int", default=8, description="Hamming distance threshold", min=0, max=64),
|
||||
StageConfigField(name="enabled", type="bool", default=True, description="Enable scene filtering"),
|
||||
],
|
||||
)
|
||||
register_stage(scene_filter, serialize_fn=_ser_filter, deserialize_fn=_deser_filter)
|
||||
|
||||
Reference in New Issue
Block a user