Files
mitus/cht/transcriber/engine.py
2026-04-03 00:25:14 -03:00

120 lines
3.7 KiB
Python

"""
Transcription engine using faster-whisper.
Processes WAV chunks incrementally, assigns sequential IDs (T0001, T0002, ...),
and persists to transcript/index.json in the session directory.
"""
import json
import logging
import threading
from dataclasses import dataclass, asdict
from pathlib import Path
log = logging.getLogger(__name__)
# UI display name -> ISO 639-1 code handed to faster-whisper via
# TranscriberEngine.language; None means "let the model auto-detect".
LANGUAGES = {
    "Auto": None,
    "English": "en",
    "Spanish": "es",
}
@dataclass
class TranscriptSegment:
    """One transcribed utterance with timestamps absolute to the recording.

    The per-chunk time offset has already been applied by
    TranscriberEngine.transcribe_chunk, so start/end are seconds from the
    beginning of the whole recording, not of the chunk.
    """
    id: str  # sequential identifier, e.g. "T0001", "T0002", ...
    start: float  # seconds into recording (chunk offset already added)
    end: float  # seconds into recording (chunk offset already added)
    text: str  # transcribed text, whitespace-stripped
class TranscriberEngine:
"""Incremental transcription via faster-whisper with GPU acceleration."""
def __init__(self, model_size="small", device="cuda"):
self._model = None
self._model_size = model_size
self._device = device
self._segments: list[TranscriptSegment] = []
self._next_id = 1
self._lock = threading.Lock()
self._stopped = False
self.language = None # None = auto-detect, "en", "es", etc.
def _ensure_model(self):
if self._model is not None:
return
log.info("Loading whisper model: %s (device=%s)", self._model_size, self._device)
from faster_whisper import WhisperModel
self._model = WhisperModel(
self._model_size,
device=self._device,
compute_type="float16" if self._device == "cuda" else "int8",
)
log.info("Whisper model loaded")
def transcribe_chunk(self, wav_path, time_offset=0.0) -> list[TranscriptSegment]:
"""Transcribe a WAV chunk. Returns new segments with absolute timestamps."""
if self._stopped:
return []
self._ensure_model()
try:
kwargs = {"beam_size": 5, "vad_filter": True}
if self.language:
kwargs["language"] = self.language
segments_iter, info = self._model.transcribe(str(wav_path), **kwargs)
except Exception as e:
log.error("Whisper transcription failed: %s", e)
return []
new_segments = []
with self._lock:
if self._stopped:
return []
for seg in segments_iter:
text = seg.text.strip()
if not text:
continue
tid = f"T{self._next_id:04d}"
self._next_id += 1
entry = TranscriptSegment(
id=tid,
start=time_offset + seg.start,
end=time_offset + seg.end,
text=text,
)
self._segments.append(entry)
new_segments.append(entry)
return new_segments
def all_segments(self) -> list[TranscriptSegment]:
return list(self._segments)
def save_index(self, path: Path):
with self._lock:
if self._stopped:
return
data = [asdict(s) for s in self._segments]
path.write_text(json.dumps(data, indent=2))
def load_index(self, path: Path):
try:
data = json.loads(path.read_text())
except Exception as e:
log.warning("Failed to load transcript index: %s", e)
return
with self._lock:
self._segments = [TranscriptSegment(**e) for e in data]
if self._segments:
last_num = max(int(s.id.lstrip("T")) for s in self._segments)
self._next_id = last_num + 1
self._stopped = False
log.info("Loaded %d transcript segments", len(self._segments))
def reset(self):
with self._lock:
self._stopped = True
self._segments.clear()
self._next_id = 1