163 lines
5.1 KiB
Python
163 lines
5.1 KiB
Python
"""
|
|
Transcription engine using faster-whisper.
|
|
|
|
Processes WAV chunks incrementally, assigns sequential IDs (T0001, T0002, ...),
|
|
and persists to transcript/index.json in the session directory.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import threading
|
|
from dataclasses import dataclass, asdict
|
|
from pathlib import Path
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
LANGUAGES = {
|
|
"Auto": None,
|
|
"English": "en",
|
|
"Spanish": "es",
|
|
}
|
|
|
|
|
|
@dataclass
class TranscriptSegment:
    """A single grouped transcript entry.

    Instances are converted with ``dataclasses.asdict`` when the index is
    saved and rebuilt via keyword arguments when it is loaded, so the field
    names double as the JSON keys.
    """

    # Sequential identifier, e.g. "T0001".
    id: str
    # Start time in seconds, measured from the beginning of the recording.
    start: float
    # End time in seconds, measured from the beginning of the recording.
    end: float
    # Text recognised for this entry.
    text: str
|
|
|
|
|
|
class TranscriberEngine:
    """Incremental transcription via faster-whisper with GPU acceleration.

    Each transcribe_chunk() call transcribes one WAV file, groups whisper's
    output lines into entries and numbers them sequentially ("T0001",
    "T0002", ...). The accumulated transcript can be written with
    save_index() and restored with load_index().

    Transcript state (_segments, _next_id, _stopped) is mutated under _lock.
    reset() sets the stopped flag; within this class only a successful
    load_index() clears it again.
    """

    # A pause longer than this many seconds between consecutive whisper lines
    # starts a new transcript entry. Class attribute so it can be overridden.
    SILENCE_GAP_S = 1.0

    def __init__(self, model_size="small", device="cuda"):
        """Configure the engine; the whisper model itself is loaded lazily.

        Args:
            model_size: model name passed to faster_whisper.WhisperModel.
            device: device passed to WhisperModel; "cuda" selects float16
                compute, anything else int8.
        """
        self._model = None  # created on first use by _ensure_model()
        self._model_size = model_size
        self._device = device
        self._segments: list[TranscriptSegment] = []
        self._next_id = 1  # numeric part of the next "T####" ID
        self._lock = threading.Lock()
        self._stopped = False
        self.language = None  # None = auto-detect, "en", "es", etc.

    def _ensure_model(self):
        """Load the WhisperModel on first call; later calls are no-ops."""
        if self._model is not None:
            return
        log.info("Loading whisper model: %s (device=%s)", self._model_size, self._device)
        # Function-scope import: faster_whisper is only imported once a chunk
        # is actually transcribed.
        from faster_whisper import WhisperModel

        self._model = WhisperModel(
            self._model_size,
            device=self._device,
            compute_type="float16" if self._device == "cuda" else "int8",
        )
        log.info("Whisper model loaded")

    def transcribe_chunk(self, wav_path, time_offset=0.0) -> list[TranscriptSegment]:
        """Transcribe one WAV chunk and append the result to the transcript.

        Args:
            wav_path: path of the WAV file (anything ``str()`` accepts).
            time_offset: seconds from the start of the recording to the start
                of this chunk; added to whisper's chunk-relative timestamps.

        Returns:
            The newly created segments with absolute timestamps, or ``[]``
            when the engine is stopped, whisper fails, or nothing was heard.

        Raises:
            Whatever model loading raises (e.g. ImportError if faster_whisper
            is missing); only transcription errors are caught and logged.
        """
        if self._stopped:
            return []
        self._ensure_model()

        with self._lock:
            # Snapshot under the lock: reset() may clear the list concurrently.
            prompt = self._segments[-1].text if self._segments else None

        kwargs = {
            "beam_size": 5,
            "vad_filter": True,
            "condition_on_previous_text": True,
        }
        if self.language:
            kwargs["language"] = self.language
        # Feed last transcript text as context for better continuity
        if prompt is not None:
            kwargs["initial_prompt"] = prompt

        try:
            segments_iter, _info = self._model.transcribe(str(wav_path), **kwargs)
            # faster-whisper returns a lazy generator: decoding happens while
            # iterating, so the loop must sit inside the try as well.
            raw_segs = []
            for seg in segments_iter:
                text = seg.text.strip()
                if text:
                    raw_segs.append((time_offset + seg.start, time_offset + seg.end, text))
        except Exception as e:
            log.exception("Whisper transcription failed: %s", e)
            return []

        # Group whisper lines: new T-ID every N lines or on a silence gap.
        from cht.config import TRANSCRIBE_LINES_PER_GROUP

        groups = self._group_lines(raw_segs, TRANSCRIBE_LINES_PER_GROUP, self.SILENCE_GAP_S)

        new_segments = []
        with self._lock:
            # reset() may have run while whisper was busy; drop the result.
            if self._stopped:
                return []
            for start, end, text in groups:
                entry = TranscriptSegment(
                    id=f"T{self._next_id:04d}",
                    start=start,
                    end=end,
                    text=text,
                )
                self._next_id += 1
                self._segments.append(entry)
                new_segments.append(entry)
        return new_segments

    @staticmethod
    def _group_lines(raw_segs, lines_per_group, silence_gap):
        """Merge consecutive (start, end, text) lines into grouped entries.

        A group is closed when it reaches ``lines_per_group`` lines, or before
        a line that starts more than ``silence_gap`` seconds after the
        previous line ended. Returns a list of (start, end, text) tuples where
        start/end span the group and text is the lines joined by spaces.
        """
        groups = []
        current = []  # (start, end, text) tuples of the open group
        prev_end = None

        def close():
            # Emit the open group, if any, and start a fresh one.
            if current:
                joined = " ".join(text for _, _, text in current)
                groups.append((current[0][0], current[-1][1], joined))
                current.clear()

        for start, end, text in raw_segs:
            if prev_end is not None and start - prev_end > silence_gap:
                close()
            current.append((start, end, text))
            prev_end = end
            if len(current) >= lines_per_group:
                close()
        close()
        return groups

    def all_segments(self) -> list[TranscriptSegment]:
        """Return a snapshot copy of every transcript segment."""
        with self._lock:
            return list(self._segments)

    def save_index(self, path: Path):
        """Write all segments to ``path`` as a JSON list (no-op after reset()).

        The JSON is first written to a sibling ``<name>.tmp`` file and then
        moved over ``path``, so a crash mid-write cannot leave a truncated
        index that load_index() would reject.

        Args:
            path: destination file; ``str`` is accepted as well as ``Path``.
        """
        path = Path(path)
        with self._lock:
            if self._stopped:
                return
            data = [asdict(s) for s in self._segments]
            tmp = path.with_name(path.name + ".tmp")
            # json.dumps escapes non-ASCII by default, so the bytes are ASCII.
            tmp.write_text(json.dumps(data, indent=2), encoding="utf-8")
            tmp.replace(path)

    def load_index(self, path: Path):
        """Replace the transcript with the segments stored in ``path``.

        On success, numbering continues after the highest loaded ID (or
        restarts at 1 for an empty index) and the stopped flag set by
        reset() is cleared. If the file cannot be read or parsed, or holds
        malformed records, a warning is logged and state is left unchanged.

        NOTE(review): a failed load (e.g. no index yet) leaves a reset()
        engine stopped — confirm callers expect this.

        Args:
            path: index file; ``str`` is accepted as well as ``Path``.
        """
        try:
            data = json.loads(Path(path).read_text(encoding="utf-8"))
            segments = [TranscriptSegment(**record) for record in data]
            numbers = [int(s.id.lstrip("T")) for s in segments]
        except (OSError, ValueError, TypeError, AttributeError) as e:
            log.warning("Failed to load transcript index: %s", e)
            return
        with self._lock:
            self._segments = segments
            self._next_id = max(numbers, default=0) + 1
            self._stopped = False
        log.info("Loaded %d transcript segments", len(segments))

    def reset(self):
        """Clear the transcript and mark the engine stopped.

        While stopped, transcribe_chunk() returns [] (including for chunks
        already in flight) and save_index() does nothing.
        """
        with self._lock:
            self._stopped = True
            self._segments.clear()
            self._next_id = 1
|