123 lines
4.5 KiB
Python
123 lines
4.5 KiB
Python
"""Soccer broadcast profile — pitch hoardings, kits, scoreboards."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from core.schema.models.pipeline_config import PipelineConfig
|
|
from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
|
|
|
|
from .base import (
|
|
CropContext,
|
|
DetectionConfig,
|
|
FrameExtractionConfig,
|
|
OCRConfig,
|
|
RegionAnalysisConfig,
|
|
ResolverConfig,
|
|
SceneFilterConfig,
|
|
pipeline_config_from_dict,
|
|
)
|
|
|
|
|
|
class SoccerBroadcastProfile:
    """Detection profile for soccer broadcasts.

    Bundles the pipeline topology, the per-stage configuration values, the
    VLM prompt template and the aggregation of brand detections into a
    report for the ``soccer_broadcast`` content type.
    """

    name = "soccer_broadcast"

    # Pipeline topology as JSONB — will be a DB field when profiles are persisted
    pipeline = {
        "name": "soccer_broadcast",
        "profile_name": "soccer_broadcast",
        "stages": [
            {"name": "extract_frames", "branch": "trunk"},
            {"name": "filter_scenes", "branch": "trunk"},
            {"name": "detect_edges", "branch": "hoarding"},
            {"name": "detect_objects", "branch": "objects"},
            {"name": "preprocess"},
            {"name": "run_ocr"},
            {"name": "match_brands"},
            {"name": "escalate_vlm"},
            {"name": "escalate_cloud"},
            {"name": "compile_report"},
        ],
        "edges": [
            {"source": "extract_frames", "target": "filter_scenes"},
            # filter_scenes fans out to both detectors, which rejoin at preprocess.
            {"source": "filter_scenes", "target": "detect_edges"},
            {"source": "filter_scenes", "target": "detect_objects"},
            {"source": "detect_edges", "target": "preprocess"},
            {"source": "detect_objects", "target": "preprocess"},
            {"source": "preprocess", "target": "run_ocr"},
            {"source": "run_ocr", "target": "match_brands"},
            {"source": "match_brands", "target": "escalate_vlm"},
            {"source": "escalate_vlm", "target": "escalate_cloud"},
            {"source": "escalate_cloud", "target": "compile_report"},
        ],
    }

    def pipeline_config(self) -> PipelineConfig:
        """Build the pipeline config from this profile's ``pipeline`` dict."""
        return pipeline_config_from_dict(self.pipeline)

    def frame_extraction_config(self) -> FrameExtractionConfig:
        """Return the frame extraction settings (fps=2.0, max_frames=500)."""
        return FrameExtractionConfig(fps=2.0, max_frames=500)

    def scene_filter_config(self) -> SceneFilterConfig:
        """Return the scene filter settings (enabled, hamming_threshold=8)."""
        return SceneFilterConfig(hamming_threshold=8, enabled=True)

    def region_analysis_config(self) -> RegionAnalysisConfig:
        """Return the edge-based region analysis parameters for this profile."""
        return RegionAnalysisConfig(
            edge_canny_low=50,
            edge_canny_high=150,
            edge_hough_threshold=80,
            edge_hough_min_length=100,
            edge_hough_max_gap=10,
            edge_pair_max_distance=200,
            edge_pair_min_distance=15,
        )

    def detection_config(self) -> DetectionConfig:
        """Return the object detection settings for this profile."""
        return DetectionConfig(
            model_name="yolov8n.pt",
            confidence_threshold=0.3,
            target_classes=[],  # empty = accept all COCO classes (until custom model)
        )

    def ocr_config(self) -> OCRConfig:
        """Return the OCR settings (English and Spanish, min_confidence=0.5)."""
        return OCRConfig(languages=["en", "es"], min_confidence=0.5)

    def resolver_config(self) -> ResolverConfig:
        """Return the brand resolver settings (fuzzy_threshold=75)."""
        return ResolverConfig(fuzzy_threshold=75)

    def vlm_prompt(self, crop_context: CropContext) -> str:
        """Compose the VLM prompt for a single cropped region.

        The position hint and the surrounding text are each appended as an
        extra sentence only when the corresponding ``crop_context``
        attribute is truthy.
        """
        hint = f" Position: {crop_context.position_hint}." if crop_context.position_hint else ""
        text = f" Nearby text: '{crop_context.surrounding_text}'." if crop_context.surrounding_text else ""
        return (
            "Identify the brand or sponsor visible in this cropped region "
            f"from a soccer broadcast.{hint}{text} "
            "Respond with: brand, confidence (0-1), reasoning."
        )

    def aggregate(self, detections: list[BrandDetection]) -> DetectionReport:
        """Fold individual detections into per-brand statistics and a report.

        For every brand this accumulates the appearance count, the summed
        ``duration``, a running mean of ``confidence`` and the earliest and
        latest ``timestamp``. The returned report carries the detections
        sorted by timestamp as its timeline; ``video_source`` and
        ``duration_seconds`` are left as ``""`` and ``0.0`` here.
        """
        brands: dict[str, BrandStats] = {}
        for d in detections:
            stats = brands.get(d.brand)
            if stats is None:
                stats = brands[d.brand] = BrandStats()
                # Seed first_seen from the brand's first detection instead of
                # treating 0.0 as "unset": a genuine detection at t=0.0 would
                # otherwise be overwritten by the next, later detection.
                stats.first_seen = d.timestamp
            elif d.timestamp < stats.first_seen:
                stats.first_seen = d.timestamp

            stats.total_appearances += 1
            stats.total_screen_time += d.duration
            # Incremental mean: re-weight the previous average by the old
            # count, add the new sample, divide by the new count.
            stats.avg_confidence = (
                (stats.avg_confidence * (stats.total_appearances - 1) + d.confidence)
                / stats.total_appearances
            )
            if d.timestamp > stats.last_seen:
                stats.last_seen = d.timestamp

        return DetectionReport(
            video_source="",
            content_type=self.name,
            duration_seconds=0.0,
            brands=brands,
            timeline=sorted(detections, key=lambda d: d.timestamp),
            pipeline_stats=PipelineStats(),
        )

    def auxiliary_detections(self, source: str) -> list[BrandDetection]:
        """Return extra detections for ``source``; this profile has none."""
        return []
|