phase 2
This commit is contained in:
23
detect/profiles/__init__.py
Normal file
23
detect/profiles/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Public surface of the detect.profiles package.

Re-exports the ContentTypeProfile protocol, its per-stage config
dataclasses, and the concrete SoccerBroadcastProfile.
(Stub profiles in .stubs are intentionally not exported yet.)
"""

from .base import (
    ContentTypeProfile,
    BrandDictionary,
    CropContext,
    DetectionConfig,
    FrameExtractionConfig,
    OCRConfig,
    ResolverConfig,
    SceneFilterConfig,
)
from .soccer import SoccerBroadcastProfile

__all__ = [
    "ContentTypeProfile",
    "BrandDictionary",
    "CropContext",
    "DetectionConfig",
    "FrameExtractionConfig",
    "OCRConfig",
    "ResolverConfig",
    "SceneFilterConfig",
    "SoccerBroadcastProfile",
]
|
||||
71
detect/profiles/base.py
Normal file
71
detect/profiles/base.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
ContentTypeProfile protocol and config dataclasses.
|
||||
|
||||
The pipeline graph is fixed — what varies per content type is configuration
|
||||
and hooks. Each profile provides stage configs, a brand dictionary,
|
||||
VLM prompt templates, and an aggregation strategy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Protocol
|
||||
|
||||
from detect.models import BrandDetection, DetectionReport
|
||||
|
||||
|
||||
@dataclass
class FrameExtractionConfig:
    """Settings for sampling frames from the source video."""

    fps: float = 2.0  # frames sampled per second of footage
    max_frames: int = 500  # hard cap on total frames extracted
|
||||
|
||||
|
||||
@dataclass
class SceneFilterConfig:
    """Settings for the near-duplicate frame filter."""

    # Presumably the max perceptual-hash hamming distance for two frames to
    # count as duplicates — confirm against the filter implementation.
    hamming_threshold: int = 8
    enabled: bool = False if False else True  # set False to pass every frame through
|
||||
|
||||
|
||||
@dataclass
class DetectionConfig:
    """Settings for the object-detection stage."""

    model_name: str = "yolov8n.pt"  # detector weights file
    confidence_threshold: float = 0.3  # presumably detections below this are dropped — verify in detector
    # Class labels the detector should keep; default_factory avoids a shared mutable default.
    target_classes: list[str] = field(default_factory=lambda: ["logo", "text"])
|
||||
|
||||
|
||||
@dataclass
class OCRConfig:
    """Settings for the OCR stage."""

    languages: list[str] = field(default_factory=lambda: ["en"])  # OCR language codes
    min_confidence: float = 0.5  # minimum OCR confidence to keep a reading
|
||||
|
||||
|
||||
@dataclass
class ResolverConfig:
    """Settings for fuzzy alias -> canonical-brand resolution."""

    fuzzy_threshold: int = 75  # minimum fuzzy-match score to accept (0-100 scale, presumably — confirm)
|
||||
|
||||
|
||||
@dataclass
class BrandDictionary:
    """Maps canonical brand name → list of known aliases/spellings."""

    # e.g. {"Nike": ["nike", "NIKE", "swoosh"]}
    brands: dict[str, list[str]] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class CropContext:
    """A cropped image region plus the context around it, used for VLM prompting."""

    image: bytes  # encoded crop bytes (image format not specified here — confirm with producer)
    surrounding_text: str = ""  # text found near the crop, if any
    position_hint: str = ""  # rough on-screen location description, if known
|
||||
|
||||
|
||||
class ContentTypeProfile(Protocol):
    """Structural interface every content-type profile must satisfy.

    The pipeline graph is fixed; profiles vary only in configuration and
    hooks: per-stage configs, a brand dictionary, VLM prompt text, and the
    aggregation strategy.
    """

    # Unique identifier for the profile, e.g. "soccer_broadcast".
    name: str

    def frame_extraction_config(self) -> FrameExtractionConfig:
        """Return config for the frame-sampling stage."""
        ...

    def scene_filter_config(self) -> SceneFilterConfig:
        """Return config for near-duplicate frame filtering."""
        ...

    def detection_config(self) -> DetectionConfig:
        """Return config for the object-detection stage."""
        ...

    def ocr_config(self) -> OCRConfig:
        """Return config for the OCR stage."""
        ...

    def brand_dictionary(self) -> BrandDictionary:
        """Return the canonical-brand -> aliases dictionary."""
        ...

    def resolver_config(self) -> ResolverConfig:
        """Return config for fuzzy alias resolution."""
        ...

    def vlm_prompt(self, crop_context: CropContext) -> str:
        """Build the VLM prompt text for one cropped region."""
        ...

    def aggregate(self, detections: list[BrandDetection]) -> DetectionReport:
        """Fold per-frame detections into the final DetectionReport."""
        ...

    def auxiliary_detections(self, source: str) -> list[BrandDetection]:
        """Return additional detections for *source*; may be empty."""
        ...
|
||||
92
detect/profiles/soccer.py
Normal file
92
detect/profiles/soccer.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Soccer broadcast profile — pitch hoardings, kits, scoreboards."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
|
||||
|
||||
from .base import (
|
||||
BrandDictionary,
|
||||
CropContext,
|
||||
DetectionConfig,
|
||||
FrameExtractionConfig,
|
||||
OCRConfig,
|
||||
ResolverConfig,
|
||||
SceneFilterConfig,
|
||||
)
|
||||
|
||||
|
||||
class SoccerBroadcastProfile:
    """Profile for soccer broadcasts — pitch hoardings, kits, scoreboards.

    Implements the ContentTypeProfile protocol: soccer-tuned stage configs,
    a sponsor-alias dictionary, a crop-level VLM prompt, and per-brand
    aggregation of the detection stream.
    """

    name = "soccer_broadcast"

    def frame_extraction_config(self) -> FrameExtractionConfig:
        """Sample at 2 fps, capped at 500 frames."""
        return FrameExtractionConfig(fps=2.0, max_frames=500)

    def scene_filter_config(self) -> SceneFilterConfig:
        """Drop near-duplicate frames (hamming threshold 8)."""
        return SceneFilterConfig(hamming_threshold=8, enabled=True)

    def detection_config(self) -> DetectionConfig:
        """Detector tuned to pick up logos, text, hoardings and scoreboards."""
        return DetectionConfig(
            model_name="yolov8n.pt",
            confidence_threshold=0.3,
            target_classes=["logo", "text", "banner", "scoreboard"],
        )

    def ocr_config(self) -> OCRConfig:
        """English + Spanish OCR at 0.5 minimum confidence."""
        return OCRConfig(languages=["en", "es"], min_confidence=0.5)

    def brand_dictionary(self) -> BrandDictionary:
        """Common soccer sponsors mapped to known alias spellings."""
        return BrandDictionary(brands={
            "Nike": ["nike", "NIKE", "swoosh"],
            "Adidas": ["adidas", "ADIDAS", "adi"],
            "Puma": ["puma", "PUMA"],
            "Emirates": ["emirates", "fly emirates", "EMIRATES"],
            "Coca-Cola": ["coca-cola", "coca cola", "coke", "COCA-COLA"],
            "Pepsi": ["pepsi", "PEPSI"],
            "Mastercard": ["mastercard", "MASTERCARD"],
            "Heineken": ["heineken", "HEINEKEN"],
            "Santander": ["santander", "SANTANDER"],
            "Gazprom": ["gazprom", "GAZPROM"],
            "Qatar Airways": ["qatar airways", "QATAR AIRWAYS"],
            "Lay's": ["lays", "lay's", "LAYS", "LAY'S"],
        })

    def resolver_config(self) -> ResolverConfig:
        """Fuzzy-match threshold for alias -> canonical brand resolution."""
        return ResolverConfig(fuzzy_threshold=75)

    def vlm_prompt(self, crop_context: CropContext) -> str:
        """Build the VLM prompt, folding in position / nearby-text hints when present."""
        hint = f" Position: {crop_context.position_hint}." if crop_context.position_hint else ""
        text = f" Nearby text: '{crop_context.surrounding_text}'." if crop_context.surrounding_text else ""
        return (
            f"Identify the brand or sponsor visible in this cropped region "
            f"from a soccer broadcast.{hint}{text} "
            f"Respond with: brand, confidence (0-1), reasoning."
        )

    def aggregate(self, detections: list[BrandDetection]) -> DetectionReport:
        """Fold detections into per-brand stats plus a time-sorted timeline.

        Fix: first_seen previously treated 0.0 as an "unset" sentinel
        (`first_seen == 0.0 or ...`), so a brand whose first detection was
        at exactly t=0.0 had first_seen overwritten by the next, later,
        detection. First/last timestamps are now seeded from the brand's
        first detection and narrowed with min/max, which needs no sentinel.
        """
        brands: dict[str, BrandStats] = {}
        for d in detections:
            s = brands.get(d.brand)
            if s is None:
                # First sighting of this brand: seed the timestamp range.
                s = BrandStats()
                brands[d.brand] = s
                s.first_seen = d.timestamp
                s.last_seen = d.timestamp
            else:
                s.first_seen = min(s.first_seen, d.timestamp)
                s.last_seen = max(s.last_seen, d.timestamp)
            s.total_appearances += 1
            s.total_screen_time += d.duration
            # Incremental running mean of detection confidence.
            s.avg_confidence = (
                (s.avg_confidence * (s.total_appearances - 1) + d.confidence)
                / s.total_appearances
            )

        # NOTE(review): video_source and duration_seconds are placeholders
        # here — presumably filled in by the caller; confirm upstream.
        return DetectionReport(
            video_source="",
            content_type=self.name,
            duration_seconds=0.0,
            brands=brands,
            timeline=sorted(detections, key=lambda d: d.timestamp),
            pipeline_stats=PipelineStats(),
        )

    def auxiliary_detections(self, source: str) -> list[BrandDetection]:
        """No auxiliary (non-frame) detection source for this profile."""
        return []
|
||||
108
detect/profiles/stubs.py
Normal file
108
detect/profiles/stubs.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Stub profiles — interfaces defined, not yet implemented."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from detect.models import BrandDetection, DetectionReport
|
||||
|
||||
from .base import (
|
||||
BrandDictionary,
|
||||
CropContext,
|
||||
DetectionConfig,
|
||||
FrameExtractionConfig,
|
||||
OCRConfig,
|
||||
ResolverConfig,
|
||||
SceneFilterConfig,
|
||||
)
|
||||
|
||||
|
||||
class NewsBroadcastProfile:
    """Stub profile for news-broadcast content.

    Interface placeholder only: every ContentTypeProfile hook currently
    raises NotImplementedError.
    """

    name = "news_broadcast"

    def frame_extraction_config(self) -> FrameExtractionConfig:
        raise NotImplementedError

    def scene_filter_config(self) -> SceneFilterConfig:
        raise NotImplementedError

    def detection_config(self) -> DetectionConfig:
        raise NotImplementedError

    def ocr_config(self) -> OCRConfig:
        raise NotImplementedError

    def brand_dictionary(self) -> BrandDictionary:
        raise NotImplementedError

    def resolver_config(self) -> ResolverConfig:
        raise NotImplementedError

    def vlm_prompt(self, crop_context: CropContext) -> str:
        raise NotImplementedError

    def aggregate(self, detections: list[BrandDetection]) -> DetectionReport:
        raise NotImplementedError

    def auxiliary_detections(self, source: str) -> list[BrandDetection]:
        raise NotImplementedError
|
||||
|
||||
|
||||
class AdvertisingProfile:
    """Stub profile for advertising content.

    Interface placeholder only: every ContentTypeProfile hook currently
    raises NotImplementedError.
    """

    name = "advertising"

    def frame_extraction_config(self) -> FrameExtractionConfig:
        raise NotImplementedError

    def scene_filter_config(self) -> SceneFilterConfig:
        raise NotImplementedError

    def detection_config(self) -> DetectionConfig:
        raise NotImplementedError

    def ocr_config(self) -> OCRConfig:
        raise NotImplementedError

    def brand_dictionary(self) -> BrandDictionary:
        raise NotImplementedError

    def resolver_config(self) -> ResolverConfig:
        raise NotImplementedError

    def vlm_prompt(self, crop_context: CropContext) -> str:
        raise NotImplementedError

    def aggregate(self, detections: list[BrandDetection]) -> DetectionReport:
        raise NotImplementedError

    def auxiliary_detections(self, source: str) -> list[BrandDetection]:
        raise NotImplementedError
|
||||
|
||||
|
||||
class TranscriptProfile:
    """Stub profile for transcript-based content.

    Interface placeholder only: every ContentTypeProfile hook currently
    raises NotImplementedError.
    """

    name = "transcript"

    def frame_extraction_config(self) -> FrameExtractionConfig:
        raise NotImplementedError

    def scene_filter_config(self) -> SceneFilterConfig:
        raise NotImplementedError

    def detection_config(self) -> DetectionConfig:
        raise NotImplementedError

    def ocr_config(self) -> OCRConfig:
        raise NotImplementedError

    def brand_dictionary(self) -> BrandDictionary:
        raise NotImplementedError

    def resolver_config(self) -> ResolverConfig:
        raise NotImplementedError

    def vlm_prompt(self, crop_context: CropContext) -> str:
        raise NotImplementedError

    def aggregate(self, detections: list[BrandDetection]) -> DetectionReport:
        raise NotImplementedError

    def auxiliary_detections(self, source: str) -> list[BrandDetection]:
        raise NotImplementedError
|
||||
Reference in New Issue
Block a user