phase 12
This commit is contained in:
128
detect/stages/preprocess.py
Normal file
128
detect/stages/preprocess.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Stage 3.5 — Preprocessing
|
||||
|
||||
Runs between YOLO detection and OCR. Applies configurable image
|
||||
preprocessing to each detected region crop: contrast enhancement,
|
||||
deskewing, binarization.
|
||||
|
||||
Operates on the crops derived from boxes_by_frame, produces
|
||||
preprocessed_crops keyed by (frame_sequence, box_index).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from detect import emit
|
||||
from detect.models import BoundingBox, Frame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray:
|
||||
h, w = frame.image.shape[:2]
|
||||
x1 = max(0, box.x)
|
||||
y1 = max(0, box.y)
|
||||
x2 = min(w, box.x + box.w)
|
||||
y2 = min(h, box.y + box.h)
|
||||
return frame.image[y1:y2, x1:x2]
|
||||
|
||||
|
||||
def preprocess_regions(
|
||||
frames: list[Frame],
|
||||
boxes_by_frame: dict[int, list[BoundingBox]],
|
||||
do_contrast: bool = True,
|
||||
do_deskew: bool = False,
|
||||
do_binarize: bool = False,
|
||||
inference_url: str | None = None,
|
||||
job_id: str | None = None,
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""
|
||||
Preprocess cropped regions from YOLO detections.
|
||||
|
||||
Returns dict keyed by "{frame_seq}_{box_idx}" → preprocessed crop.
|
||||
These are passed to the OCR stage instead of raw crops.
|
||||
"""
|
||||
total_regions = sum(len(boxes) for boxes in boxes_by_frame.values())
|
||||
any_active = do_contrast or do_deskew or do_binarize
|
||||
|
||||
if not any_active:
|
||||
emit.log(job_id, "Preprocess", "INFO",
|
||||
f"Preprocessing disabled, passing {total_regions} regions through")
|
||||
return {}
|
||||
|
||||
mode = "remote" if inference_url else "local"
|
||||
emit.log(job_id, "Preprocess", "INFO",
|
||||
f"Preprocessing {total_regions} regions (mode={mode}, "
|
||||
f"contrast={do_contrast}, deskew={do_deskew}, binarize={do_binarize})")
|
||||
|
||||
frame_map = {f.sequence: f for f in frames}
|
||||
preprocessed: dict[str, np.ndarray] = {}
|
||||
processed_count = 0
|
||||
|
||||
for seq, boxes in boxes_by_frame.items():
|
||||
frame = frame_map.get(seq)
|
||||
if not frame:
|
||||
continue
|
||||
|
||||
for idx, box in enumerate(boxes):
|
||||
crop = _crop_region(frame, box)
|
||||
if crop.size == 0:
|
||||
continue
|
||||
|
||||
key = f"{seq}_{idx}"
|
||||
|
||||
if inference_url:
|
||||
result = _preprocess_remote(crop, inference_url,
|
||||
do_contrast, do_deskew, do_binarize)
|
||||
else:
|
||||
result = _preprocess_local(crop, do_contrast, do_deskew, do_binarize)
|
||||
|
||||
preprocessed[key] = result
|
||||
processed_count += 1
|
||||
|
||||
emit.log(job_id, "Preprocess", "INFO",
|
||||
f"Preprocessed {processed_count} regions")
|
||||
|
||||
return preprocessed
|
||||
|
||||
|
||||
def _preprocess_remote(crop: np.ndarray, inference_url: str,
|
||||
do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
|
||||
"""Call GPU server /preprocess endpoint."""
|
||||
import base64
|
||||
import io
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
img = Image.fromarray(crop)
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="JPEG", quality=85)
|
||||
image_b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
resp = requests.post(
|
||||
f"{inference_url.rstrip('/')}/preprocess",
|
||||
json={
|
||||
"image": image_b64,
|
||||
"contrast": do_contrast,
|
||||
"deskew": do_deskew,
|
||||
"binarize": do_binarize,
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
result_bytes = base64.b64decode(data["image"])
|
||||
result_img = Image.open(io.BytesIO(result_bytes)).convert("RGB")
|
||||
return np.array(result_img)
|
||||
|
||||
|
||||
def _preprocess_local(crop: np.ndarray,
|
||||
do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
|
||||
"""Run preprocessing in-process (requires opencv-python-headless)."""
|
||||
from gpu.models.preprocess import preprocess
|
||||
return preprocess(crop, do_binarize=do_binarize, do_deskew=do_deskew, do_contrast=do_contrast)
|
||||
Reference in New Issue
Block a user