phase 12
@@ -18,6 +18,7 @@ from detect.state import DetectState
from detect.stages.frame_extractor import extract_frames
from detect.stages.scene_filter import scene_filter
from detect.stages.yolo_detector import detect_objects
from detect.stages.preprocess import preprocess_regions
from detect.stages.ocr_stage import run_ocr
from detect.stages.brand_resolver import resolve_brands
from detect.stages.vlm_local import escalate_vlm
@@ -31,6 +32,7 @@ NODES = [
    "extract_frames",
    "filter_scenes",
    "detect_objects",
    "preprocess",
    "run_ocr",
    "match_brands",
    "escalate_vlm",
@@ -137,6 +139,36 @@ def node_detect_objects(state: DetectState) -> dict:
    return {"boxes_by_frame": all_boxes, "stats": stats}


def node_preprocess(state: DetectState) -> dict:
    _emit_transition(state, "preprocess", "running")

    with trace_node(state, "preprocess") as span:
        profile = _get_profile(state)
        frames = state.get("filtered_frames", [])
        boxes = state.get("boxes_by_frame", {})
        job_id = state.get("job_id")

        # Get preprocessing config from profile overrides or defaults
        overrides = state.get("config_overrides", {})
        prep_config = overrides.get("preprocessing", {})
        do_contrast = prep_config.get("contrast", True)
        do_deskew = prep_config.get("deskew", False)
        do_binarize = prep_config.get("binarize", False)

        preprocessed = preprocess_regions(
            frames, boxes,
            do_contrast=do_contrast,
            do_deskew=do_deskew,
            do_binarize=do_binarize,
            inference_url=INFERENCE_URL,
            job_id=job_id,
        )
        span.set_output({"regions_preprocessed": len(preprocessed)})

    _emit_transition(state, "preprocess", "done")
    return {"preprocessed_crops": preprocessed}


def node_run_ocr(state: DetectState) -> dict:
    _emit_transition(state, "run_ocr", "running")

@@ -304,6 +336,7 @@ NODE_FUNCTIONS = [
    ("extract_frames", node_extract_frames),
    ("filter_scenes", node_filter_scenes),
    ("detect_objects", node_detect_objects),
    ("preprocess", node_preprocess),
    ("run_ocr", node_run_ocr),
    ("match_brands", node_match_brands),
    ("escalate_vlm", node_escalate_vlm),
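For illustration, this is roughly how a caller could shape config_overrides so that node_preprocess enables deskewing. The "preprocessing" key and the flag defaults mirror the lookups in node_preprocess above; the state dict itself is a hypothetical minimal example, not part of this commit.

# Hypothetical sketch: how config_overrides reaches node_preprocess.
state = {
    "job_id": "job-123",          # hypothetical id
    "filtered_frames": [],        # normally written by the filter_scenes node
    "boxes_by_frame": {},         # normally written by the detect_objects node
    "config_overrides": {
        "preprocessing": {"contrast": True, "deskew": True, "binarize": False},
    },
}
# node_preprocess would resolve do_contrast=True, do_deskew=True, do_binarize=False
# and return {"preprocessed_crops": {...}} keyed by "{frame_seq}_{box_idx}".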
@@ -101,3 +101,40 @@ class JobComplete(BaseModel):
    """Final report when pipeline finishes. SSE event: job_complete"""
    job_id: str
    report: Optional[DetectionReportSummary] = None


class RunContext(BaseModel):
    """Run context injected into all SSE events for grouping."""
    run_id: str
    parent_job_id: str
    run_type: str = "initial"


class CheckpointInfo(BaseModel):
    """Available checkpoint for a stage."""
    stage: str


class ReplayRequest(BaseModel):
    """Request to replay pipeline from a specific stage."""
    job_id: str
    start_stage: str
    config_overrides: Optional[Dict[str, Any]] = None


class ReplayResponse(BaseModel):
    """Result of a replay invocation."""
    status: str
    job_id: str
    start_stage: str
    detections: int = 0
    brands_found: int = 0


class RetryRequest(BaseModel):
    """Request to queue async retry with different config."""
    job_id: str
    config_overrides: Optional[Dict[str, Any]] = None
    start_stage: str = "escalate_vlm"
    schedule_seconds: Optional[float] = None


class RetryResponse(BaseModel):
    """Result of queueing a retry task."""
    status: str
    task_id: str
    job_id: str
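To make the new request models concrete, here is a hedged client-side sketch. The field names come from ReplayRequest and RetryRequest above; the /replay and /retry endpoint paths and the requests calls are assumptions, not part of this commit.

# Hypothetical payloads; only the field names are taken from the models above.
import requests

replay_payload = {                                    # ReplayRequest
    "job_id": "job-123",                              # hypothetical id
    "start_stage": "preprocess",
    "config_overrides": {"preprocessing": {"deskew": True}},
}
# requests.post("http://localhost:8000/replay", json=replay_payload, timeout=60)
# would return a ReplayResponse-shaped body: status, detections, brands_found.

retry_payload = {                                     # RetryRequest
    "job_id": "job-123",
    "start_stage": "escalate_vlm",                    # the model's default
    "schedule_seconds": 30.0,                         # queue the retry ~30 s out
}
# requests.post("http://localhost:8000/retry", json=retry_payload, timeout=10)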
detect/stages/preprocess.py (new file, 128 lines)
@@ -0,0 +1,128 @@
"""
Stage 3.5 — Preprocessing

Runs between YOLO detection and OCR. Applies configurable image
preprocessing to each detected region crop: contrast enhancement,
deskewing, binarization.

Operates on the crops derived from boxes_by_frame, produces
preprocessed_crops keyed by (frame_sequence, box_index).
"""

from __future__ import annotations

import logging

import numpy as np

from detect import emit
from detect.models import BoundingBox, Frame

logger = logging.getLogger(__name__)


def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray:
    h, w = frame.image.shape[:2]
    x1 = max(0, box.x)
    y1 = max(0, box.y)
    x2 = min(w, box.x + box.w)
    y2 = min(h, box.y + box.h)
    return frame.image[y1:y2, x1:x2]


def preprocess_regions(
    frames: list[Frame],
    boxes_by_frame: dict[int, list[BoundingBox]],
    do_contrast: bool = True,
    do_deskew: bool = False,
    do_binarize: bool = False,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> dict[str, np.ndarray]:
    """
    Preprocess cropped regions from YOLO detections.

    Returns dict keyed by "{frame_seq}_{box_idx}" → preprocessed crop.
    These are passed to the OCR stage instead of raw crops.
    """
    total_regions = sum(len(boxes) for boxes in boxes_by_frame.values())
    any_active = do_contrast or do_deskew or do_binarize

    if not any_active:
        emit.log(job_id, "Preprocess", "INFO",
                 f"Preprocessing disabled, passing {total_regions} regions through")
        return {}

    mode = "remote" if inference_url else "local"
    emit.log(job_id, "Preprocess", "INFO",
             f"Preprocessing {total_regions} regions (mode={mode}, "
             f"contrast={do_contrast}, deskew={do_deskew}, binarize={do_binarize})")

    frame_map = {f.sequence: f for f in frames}
    preprocessed: dict[str, np.ndarray] = {}
    processed_count = 0

    for seq, boxes in boxes_by_frame.items():
        frame = frame_map.get(seq)
        if not frame:
            continue

        for idx, box in enumerate(boxes):
            crop = _crop_region(frame, box)
            if crop.size == 0:
                continue

            key = f"{seq}_{idx}"

            if inference_url:
                result = _preprocess_remote(crop, inference_url,
                                            do_contrast, do_deskew, do_binarize)
            else:
                result = _preprocess_local(crop, do_contrast, do_deskew, do_binarize)

            preprocessed[key] = result
            processed_count += 1

    emit.log(job_id, "Preprocess", "INFO",
             f"Preprocessed {processed_count} regions")

    return preprocessed


def _preprocess_remote(crop: np.ndarray, inference_url: str,
                       do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
    """Call GPU server /preprocess endpoint."""
    import base64
    import io

    import requests
    from PIL import Image

    img = Image.fromarray(crop)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    image_b64 = base64.b64encode(buf.getvalue()).decode()

    resp = requests.post(
        f"{inference_url.rstrip('/')}/preprocess",
        json={
            "image": image_b64,
            "contrast": do_contrast,
            "deskew": do_deskew,
            "binarize": do_binarize,
        },
        timeout=30,
    )
    resp.raise_for_status()
    data = resp.json()

    result_bytes = base64.b64decode(data["image"])
    result_img = Image.open(io.BytesIO(result_bytes)).convert("RGB")
    return np.array(result_img)


def _preprocess_local(crop: np.ndarray,
                      do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
    """Run preprocessing in-process (requires opencv-python-headless)."""
    from gpu.models.preprocess import preprocess
    return preprocess(crop, do_binarize=do_binarize, do_deskew=do_deskew, do_contrast=do_contrast)
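A minimal local-mode usage sketch follows. Only preprocess_regions and its signature come from the file above; the Frame and BoundingBox constructor arguments are assumptions based on the attributes _crop_region reads, and the local path additionally requires gpu.models.preprocess to be importable.

# Hypothetical usage sketch for the local path (inference_url=None).
import numpy as np

from detect.models import BoundingBox, Frame
from detect.stages.preprocess import preprocess_regions

frame = Frame(sequence=0, image=np.zeros((720, 1280, 3), dtype=np.uint8))  # assumed fields
box = BoundingBox(x=100, y=50, w=200, h=80)                                # assumed fields

crops = preprocess_regions(
    frames=[frame],
    boxes_by_frame={0: [box]},
    do_contrast=True,       # CLAHE-style contrast enhancement
    do_deskew=False,
    do_binarize=False,
    inference_url=None,     # None selects _preprocess_local
    job_id=None,
)
# crops == {"0_0": <preprocessed ndarray>} on success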
@@ -1,4 +1,4 @@
-"""Registration for preprocessing stages: frame extraction, scene filter."""
+"""Registration for preprocessing stages: frame extraction, scene filter, image preprocessing."""

from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from ._serializers import serialize_frames, deserialize_frames
@@ -25,6 +25,17 @@ def _deser_filter(data: dict, job_id: str) -> dict:
    return {"_filtered_sequences": data["filtered_frame_sequences"]}


def _ser_preprocess(state: dict, job_id: str) -> dict:
    # Preprocessed crops are numpy arrays — regenerable from frames + boxes + config
    crops = state.get("preprocessed_crops", {})
    return {"crop_keys": list(crops.keys()), "count": len(crops)}


def _deser_preprocess(data: dict, job_id: str) -> dict:
    # Crops are regenerable — no need to restore from checkpoint
    return {"preprocessed_crops": {}}


def register():
    extract = StageDefinition(
        name="extract_frames",
@@ -55,3 +66,22 @@ def register():
        deserialize_fn=_deser_filter,
    )
    register_stage(scene_filter)

    preprocess = StageDefinition(
        name="preprocess",
        label="Preprocess",
        description="Image preprocessing on detected regions before OCR",
        category="preprocessing",
        io=StageIO(
            reads=["filtered_frames", "boxes_by_frame"],
            writes=["preprocessed_crops"],
        ),
        config_fields=[
            StageConfigField("contrast", "bool", True, "CLAHE contrast enhancement"),
            StageConfigField("deskew", "bool", False, "Correct slight rotation"),
            StageConfigField("binarize", "bool", False, "Otsu binarization"),
        ],
        serialize_fn=_ser_preprocess,
        deserialize_fn=_deser_preprocess,
    )
    register_stage(preprocess)
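For context, the checkpoint behaviour these hooks imply can be sketched as follows. Only _ser_preprocess and _deser_preprocess come from this file; the surrounding calls and values are hypothetical.

# Hypothetical round trip through the preprocess checkpoint hooks.
import numpy as np

state = {"preprocessed_crops": {"0_0": np.zeros((80, 200, 3), dtype=np.uint8)}}

saved = _ser_preprocess(state, job_id="job-123")
# saved == {"crop_keys": ["0_0"], "count": 1}   (only lightweight metadata is stored)

restored = _deser_preprocess(saved, job_id="job-123")
# restored == {"preprocessed_crops": {}}        (crops are regenerated from
# frames + boxes + config rather than restored from the checkpoint)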
@@ -23,6 +23,7 @@ class DetectState(TypedDict, total=False):
    frames: list[Frame]
    filtered_frames: list[Frame]
    boxes_by_frame: dict[int, list[BoundingBox]]
    preprocessed_crops: dict  # "{frame_seq}_{box_idx}" → np.ndarray
    text_candidates: list[TextCandidate]
    unresolved_candidates: list[TextCandidate]
    detections: list[BrandDetection]
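To illustrate the key format noted in the comment above, a hypothetical lookup for the first box in the frame with sequence 12:

# Hypothetical lookup against a populated DetectState ("state" assumed to exist).
key = f"{12}_{0}"                               # -> "12_0"
crop = state["preprocessed_crops"].get(key)     # np.ndarray, or None if that box was skipped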