phase 5: edge transforms, soleprint-ui rename, infra fixes
- pipeline edge transforms: stages can declare accepted_transforms, edges carry a transform dict, runner injects per-stage and nodes apply (e.g. invert_mask before edge detection); editable from UI via PUT /config/edge-transform - rename mpr-ui-framework -> soleprint-ui (now an external package synced via .spr from /home/mariano/wdir/spr); add @vue-flow/core and uplot to detection-app so linked package resolves them - Tiltfile guards kubectl context, k8s commands pin --context kind-mpr - kind-config: gateway on hostPort 30080 (Caddy fronts mpr.local.ar) - modelgen: pyproject.toml, .spr marker, dict default_factory support
This commit is contained in:
@@ -38,6 +38,14 @@ class StageOutputHintInfo(BaseModel):
|
||||
src_format: str = "png"
|
||||
|
||||
|
||||
class TransformOptionInfo(BaseModel):
|
||||
key: str
|
||||
type: str
|
||||
default: object = False
|
||||
label: str = ""
|
||||
description: str = ""
|
||||
|
||||
|
||||
class StageConfigInfo(BaseModel):
|
||||
name: str
|
||||
label: str
|
||||
@@ -45,6 +53,7 @@ class StageConfigInfo(BaseModel):
|
||||
category: str
|
||||
config_fields: list[dict]
|
||||
output_hints: list[StageOutputHintInfo] = []
|
||||
accepted_transforms: list[TransformOptionInfo] = []
|
||||
reads: list[str]
|
||||
writes: list[str]
|
||||
|
||||
@@ -87,6 +96,51 @@ def get_pipeline_config(profile_name: str):
|
||||
return profile["pipeline"]
|
||||
|
||||
|
||||
class UpdateEdgeTransformRequest(BaseModel):
|
||||
profile_name: str = "soccer_broadcast"
|
||||
source_stage: str
|
||||
target_stage: str
|
||||
transform: dict
|
||||
|
||||
|
||||
@router.put("/config/edge-transform")
|
||||
def update_edge_transform(req: UpdateEdgeTransformRequest):
|
||||
"""Update the transform on an edge in a profile's pipeline config."""
|
||||
from uuid import UUID
|
||||
from core.db.models import Profile
|
||||
from core.db.connection import get_session
|
||||
from sqlmodel import select
|
||||
from fastapi import HTTPException
|
||||
|
||||
with get_session() as session:
|
||||
stmt = select(Profile).where(Profile.name == req.profile_name)
|
||||
profile = session.exec(stmt).first()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail=f"Profile not found: {req.profile_name}")
|
||||
|
||||
pipeline = dict(profile.pipeline)
|
||||
edges = pipeline.get("edges", [])
|
||||
|
||||
found = False
|
||||
for edge in edges:
|
||||
if edge.get("source") == req.source_stage and edge.get("target") == req.target_stage:
|
||||
edge["transform"] = req.transform
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Edge not found: {req.source_stage} → {req.target_stage}",
|
||||
)
|
||||
|
||||
pipeline["edges"] = edges
|
||||
profile.pipeline = pipeline
|
||||
session.commit()
|
||||
|
||||
return {"status": "updated", "edge": f"{req.source_stage} → {req.target_stage}", "transform": req.transform}
|
||||
|
||||
|
||||
@router.get("/config/stages", response_model=list[StageConfigInfo])
|
||||
def list_stage_configs():
|
||||
"""Return the stage palette with config field metadata for the editor."""
|
||||
@@ -137,6 +191,13 @@ def _stage_to_info(stage) -> StageConfigInfo:
|
||||
)
|
||||
for h in getattr(stage, "output_hints", [])
|
||||
],
|
||||
accepted_transforms=[
|
||||
TransformOptionInfo(
|
||||
key=t.key, type=t.type, default=t.default,
|
||||
label=t.label, description=t.description,
|
||||
)
|
||||
for t in getattr(stage, "accepted_transforms", [])
|
||||
],
|
||||
reads=stage.io.reads,
|
||||
writes=stage.io.writes,
|
||||
)
|
||||
|
||||
@@ -54,7 +54,8 @@
|
||||
},
|
||||
{
|
||||
"source": "field_segmentation",
|
||||
"target": "detect_edges"
|
||||
"target": "detect_edges",
|
||||
"transform": {"invert_mask": true}
|
||||
},
|
||||
{
|
||||
"source": "field_segmentation",
|
||||
|
||||
@@ -162,6 +162,16 @@ def node_detect_edges(state: DetectState) -> dict:
|
||||
field_masks = state.get("field_masks", {})
|
||||
job_id = state.get("job_id")
|
||||
|
||||
# Apply edge transforms from upstream connections
|
||||
edge_transforms = state.get("_edge_transforms", {})
|
||||
for source_stage, transform in edge_transforms.items():
|
||||
if transform.get("invert_mask") and field_masks:
|
||||
import numpy as np
|
||||
field_masks = {
|
||||
seq: np.bitwise_not(mask) if mask is not None else None
|
||||
for seq, mask in field_masks.items()
|
||||
}
|
||||
|
||||
regions = detect_edge_regions(
|
||||
frames, config, inference_url=INFERENCE_URL, job_id=job_id,
|
||||
field_masks=field_masks,
|
||||
|
||||
@@ -213,6 +213,13 @@ class PipelineRunner:
|
||||
self.config = config
|
||||
self.do_checkpoint = checkpoint
|
||||
self.stage_sequence = _flatten_config(config, start_from)
|
||||
# Build edge transform lookup: {target_stage: {source_stage: transform_dict}}
|
||||
self._edge_transforms: dict[str, dict[str, dict]] = {}
|
||||
for edge in config.edges:
|
||||
if edge.transform:
|
||||
if edge.target not in self._edge_transforms:
|
||||
self._edge_transforms[edge.target] = {}
|
||||
self._edge_transforms[edge.target][edge.source] = edge.transform
|
||||
|
||||
def invoke(self, state: DetectState) -> DetectState:
|
||||
"""Run the pipeline on the given state. Returns final state."""
|
||||
@@ -224,6 +231,14 @@ class PipelineRunner:
|
||||
if check and check():
|
||||
raise PipelineCancelled(f"Cancelled before {stage_name}")
|
||||
|
||||
# Inject edge transforms into state so the stage can read them.
|
||||
# Compatible with LangGraph — just a state dict key.
|
||||
transforms = self._edge_transforms.get(stage_name, {})
|
||||
if transforms:
|
||||
state["_edge_transforms"] = transforms
|
||||
elif "_edge_transforms" in state:
|
||||
del state["_edge_transforms"]
|
||||
|
||||
# 2. Run node function
|
||||
node_fn = _NODE_FN_MAP.get(stage_name)
|
||||
if node_fn is None:
|
||||
|
||||
@@ -25,6 +25,7 @@ from core.detect.stages.models import (
|
||||
StageDefinition,
|
||||
StageIO,
|
||||
StageOutputHint,
|
||||
TransformOption,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -54,6 +55,9 @@ class FieldSegmentationStage(Stage):
|
||||
output_hints=[
|
||||
StageOutputHint(key="mask_overlay_b64", type="overlay", label="Field mask", default_opacity=0.5, src_format="png"),
|
||||
],
|
||||
accepted_transforms=[
|
||||
TransformOption(key="invert_mask", type="bool", default=False, label="Invert selection", description="Invert the mask so downstream stages look outside the detected area"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -35,6 +35,14 @@ class StageOutputHint(BaseModel):
|
||||
default_opacity: float = 0.5
|
||||
src_format: str = "png"
|
||||
|
||||
class TransformOption(BaseModel):
|
||||
"""A transform the stage accepts on its incoming edges."""
|
||||
key: str
|
||||
type: str
|
||||
default: Any = False
|
||||
label: str = ""
|
||||
description: str = ""
|
||||
|
||||
class StageDefinition(BaseModel):
|
||||
"""Complete metadata for a pipeline stage."""
|
||||
name: str
|
||||
@@ -44,6 +52,7 @@ class StageDefinition(BaseModel):
|
||||
io: StageIO
|
||||
config_fields: List[StageConfigField] = Field(default_factory=list)
|
||||
output_hints: List[StageOutputHint] = Field(default_factory=list)
|
||||
accepted_transforms: List[TransformOption] = Field(default_factory=list)
|
||||
tracks_element: Optional[str] = None
|
||||
|
||||
class FrameExtractionConfig(BaseModel):
|
||||
@@ -105,6 +114,7 @@ class Edge(BaseModel):
|
||||
source: str
|
||||
target: str
|
||||
condition: str = ""
|
||||
transform: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class PipelineConfig(BaseModel):
|
||||
"""Pipeline graph topology + routing rules."""
|
||||
@@ -112,4 +122,4 @@ class PipelineConfig(BaseModel):
|
||||
profile_name: str
|
||||
stages: List[StageRef] = Field(default_factory=list)
|
||||
edges: List[Edge] = Field(default_factory=list)
|
||||
routing_rules: Dict[str, Any]
|
||||
routing_rules: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@@ -35,6 +35,16 @@ class StageIO:
|
||||
optional_reads: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformOption:
|
||||
"""A transform the stage accepts on its incoming edges."""
|
||||
key: str
|
||||
type: str # "bool", "float", "int", "str"
|
||||
default: Any = False
|
||||
label: str = ""
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StageOutputHint:
|
||||
"""How to render a stage output in the compare/editor views."""
|
||||
@@ -55,6 +65,7 @@ class StageDefinition:
|
||||
io: StageIO = field(default_factory=StageIO)
|
||||
config_fields: List[StageConfigField] = field(default_factory=list)
|
||||
output_hints: List[StageOutputHint] = field(default_factory=list)
|
||||
accepted_transforms: List[TransformOption] = field(default_factory=list)
|
||||
tracks_element: Optional[str] = None
|
||||
|
||||
|
||||
@@ -129,10 +140,16 @@ class StageRef:
|
||||
|
||||
@dataclass
|
||||
class Edge:
|
||||
"""Connection between stages in the graph."""
|
||||
"""Connection between stages in the graph.
|
||||
|
||||
transform: per-edge data transformation spec. Flexible JSONB blob
|
||||
type-checked by the consuming stage. E.g. {"invert_mask": true}
|
||||
tells edge detection to invert the field segmentation mask.
|
||||
"""
|
||||
source: str
|
||||
target: str
|
||||
condition: str = ""
|
||||
transform: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -151,6 +168,7 @@ STAGE_VIEWS = [
|
||||
StageConfigField,
|
||||
StageIO,
|
||||
StageOutputHint,
|
||||
TransformOption,
|
||||
StageDefinition,
|
||||
FrameExtractionConfig,
|
||||
SceneFilterConfig,
|
||||
|
||||
Reference in New Issue
Block a user