phase 5: edge transforms, soleprint-ui rename, infra fixes
- pipeline edge transforms: stages can declare accepted_transforms, edges carry a transform dict, runner injects per-stage and nodes apply (e.g. invert_mask before edge detection); editable from UI via PUT /config/edge-transform - rename mpr-ui-framework -> soleprint-ui (now an external package synced via .spr from /home/mariano/wdir/spr); add @vue-flow/core and uplot to detection-app so linked package resolves them - Tiltfile guards kubectl context, k8s commands pin --context kind-mpr - kind-config: gateway on hostPort 30080 (Caddy fronts mpr.local.ar) - modelgen: pyproject.toml, .spr marker, dict default_factory support
This commit is contained in:
@@ -162,6 +162,16 @@ def node_detect_edges(state: DetectState) -> dict:
|
||||
field_masks = state.get("field_masks", {})
|
||||
job_id = state.get("job_id")
|
||||
|
||||
# Apply edge transforms from upstream connections
|
||||
edge_transforms = state.get("_edge_transforms", {})
|
||||
for source_stage, transform in edge_transforms.items():
|
||||
if transform.get("invert_mask") and field_masks:
|
||||
import numpy as np
|
||||
field_masks = {
|
||||
seq: np.bitwise_not(mask) if mask is not None else None
|
||||
for seq, mask in field_masks.items()
|
||||
}
|
||||
|
||||
regions = detect_edge_regions(
|
||||
frames, config, inference_url=INFERENCE_URL, job_id=job_id,
|
||||
field_masks=field_masks,
|
||||
|
||||
@@ -213,6 +213,13 @@ class PipelineRunner:
|
||||
self.config = config
|
||||
self.do_checkpoint = checkpoint
|
||||
self.stage_sequence = _flatten_config(config, start_from)
|
||||
# Build edge transform lookup: {target_stage: {source_stage: transform_dict}}
|
||||
self._edge_transforms: dict[str, dict[str, dict]] = {}
|
||||
for edge in config.edges:
|
||||
if edge.transform:
|
||||
if edge.target not in self._edge_transforms:
|
||||
self._edge_transforms[edge.target] = {}
|
||||
self._edge_transforms[edge.target][edge.source] = edge.transform
|
||||
|
||||
def invoke(self, state: DetectState) -> DetectState:
|
||||
"""Run the pipeline on the given state. Returns final state."""
|
||||
@@ -224,6 +231,14 @@ class PipelineRunner:
|
||||
if check and check():
|
||||
raise PipelineCancelled(f"Cancelled before {stage_name}")
|
||||
|
||||
# Inject edge transforms into state so the stage can read them.
|
||||
# Compatible with LangGraph — just a state dict key.
|
||||
transforms = self._edge_transforms.get(stage_name, {})
|
||||
if transforms:
|
||||
state["_edge_transforms"] = transforms
|
||||
elif "_edge_transforms" in state:
|
||||
del state["_edge_transforms"]
|
||||
|
||||
# 2. Run node function
|
||||
node_fn = _NODE_FN_MAP.get(stage_name)
|
||||
if node_fn is None:
|
||||
|
||||
@@ -25,6 +25,7 @@ from core.detect.stages.models import (
|
||||
StageDefinition,
|
||||
StageIO,
|
||||
StageOutputHint,
|
||||
TransformOption,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -54,6 +55,9 @@ class FieldSegmentationStage(Stage):
|
||||
output_hints=[
|
||||
StageOutputHint(key="mask_overlay_b64", type="overlay", label="Field mask", default_opacity=0.5, src_format="png"),
|
||||
],
|
||||
accepted_transforms=[
|
||||
TransformOption(key="invert_mask", type="bool", default=False, label="Invert selection", description="Invert the mask so downstream stages look outside the detected area"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -35,6 +35,14 @@ class StageOutputHint(BaseModel):
|
||||
default_opacity: float = 0.5
|
||||
src_format: str = "png"
|
||||
|
||||
class TransformOption(BaseModel):
|
||||
"""A transform the stage accepts on its incoming edges."""
|
||||
key: str
|
||||
type: str
|
||||
default: Any = False
|
||||
label: str = ""
|
||||
description: str = ""
|
||||
|
||||
class StageDefinition(BaseModel):
|
||||
"""Complete metadata for a pipeline stage."""
|
||||
name: str
|
||||
@@ -44,6 +52,7 @@ class StageDefinition(BaseModel):
|
||||
io: StageIO
|
||||
config_fields: List[StageConfigField] = Field(default_factory=list)
|
||||
output_hints: List[StageOutputHint] = Field(default_factory=list)
|
||||
accepted_transforms: List[TransformOption] = Field(default_factory=list)
|
||||
tracks_element: Optional[str] = None
|
||||
|
||||
class FrameExtractionConfig(BaseModel):
|
||||
@@ -105,6 +114,7 @@ class Edge(BaseModel):
|
||||
source: str
|
||||
target: str
|
||||
condition: str = ""
|
||||
transform: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class PipelineConfig(BaseModel):
|
||||
"""Pipeline graph topology + routing rules."""
|
||||
@@ -112,4 +122,4 @@ class PipelineConfig(BaseModel):
|
||||
profile_name: str
|
||||
stages: List[StageRef] = Field(default_factory=list)
|
||||
edges: List[Edge] = Field(default_factory=list)
|
||||
routing_rules: Dict[str, Any]
|
||||
routing_rules: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
Reference in New Issue
Block a user