tweaks
This commit is contained in:
@@ -7,11 +7,18 @@ and persists to transcript/index.json in the session directory.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
LANGUAGES = {
|
||||
"Auto": None,
|
||||
"English": "en",
|
||||
"Spanish": "es",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptSegment:
|
||||
@@ -30,6 +37,9 @@ class TranscriberEngine:
|
||||
self._device = device
|
||||
self._segments: list[TranscriptSegment] = []
|
||||
self._next_id = 1
|
||||
self._lock = threading.Lock()
|
||||
self._stopped = False
|
||||
self.language = None # None = auto-detect, "en", "es", etc.
|
||||
|
||||
def _ensure_model(self):
|
||||
if self._model is not None:
|
||||
@@ -45,32 +55,36 @@ class TranscriberEngine:
|
||||
|
||||
def transcribe_chunk(self, wav_path, time_offset=0.0) -> list[TranscriptSegment]:
|
||||
"""Transcribe a WAV chunk. Returns new segments with absolute timestamps."""
|
||||
if self._stopped:
|
||||
return []
|
||||
self._ensure_model()
|
||||
try:
|
||||
segments_iter, _info = self._model.transcribe(
|
||||
str(wav_path),
|
||||
beam_size=5,
|
||||
vad_filter=True,
|
||||
)
|
||||
kwargs = {"beam_size": 5, "vad_filter": True}
|
||||
if self.language:
|
||||
kwargs["language"] = self.language
|
||||
segments_iter, info = self._model.transcribe(str(wav_path), **kwargs)
|
||||
except Exception as e:
|
||||
log.error("Whisper transcription failed: %s", e)
|
||||
return []
|
||||
|
||||
new_segments = []
|
||||
for seg in segments_iter:
|
||||
text = seg.text.strip()
|
||||
if not text:
|
||||
continue
|
||||
tid = f"T{self._next_id:04d}"
|
||||
self._next_id += 1
|
||||
entry = TranscriptSegment(
|
||||
id=tid,
|
||||
start=time_offset + seg.start,
|
||||
end=time_offset + seg.end,
|
||||
text=text,
|
||||
)
|
||||
self._segments.append(entry)
|
||||
new_segments.append(entry)
|
||||
with self._lock:
|
||||
if self._stopped:
|
||||
return []
|
||||
for seg in segments_iter:
|
||||
text = seg.text.strip()
|
||||
if not text:
|
||||
continue
|
||||
tid = f"T{self._next_id:04d}"
|
||||
self._next_id += 1
|
||||
entry = TranscriptSegment(
|
||||
id=tid,
|
||||
start=time_offset + seg.start,
|
||||
end=time_offset + seg.end,
|
||||
text=text,
|
||||
)
|
||||
self._segments.append(entry)
|
||||
new_segments.append(entry)
|
||||
|
||||
return new_segments
|
||||
|
||||
@@ -78,7 +92,10 @@ class TranscriberEngine:
|
||||
return list(self._segments)
|
||||
|
||||
def save_index(self, path: Path):
|
||||
data = [asdict(s) for s in self._segments]
|
||||
with self._lock:
|
||||
if self._stopped:
|
||||
return
|
||||
data = [asdict(s) for s in self._segments]
|
||||
path.write_text(json.dumps(data, indent=2))
|
||||
|
||||
def load_index(self, path: Path):
|
||||
@@ -87,12 +104,16 @@ class TranscriberEngine:
|
||||
except Exception as e:
|
||||
log.warning("Failed to load transcript index: %s", e)
|
||||
return
|
||||
self._segments = [TranscriptSegment(**e) for e in data]
|
||||
if self._segments:
|
||||
last_num = max(int(s.id.lstrip("T")) for s in self._segments)
|
||||
self._next_id = last_num + 1
|
||||
with self._lock:
|
||||
self._segments = [TranscriptSegment(**e) for e in data]
|
||||
if self._segments:
|
||||
last_num = max(int(s.id.lstrip("T")) for s in self._segments)
|
||||
self._next_id = last_num + 1
|
||||
self._stopped = False
|
||||
log.info("Loaded %d transcript segments", len(self._segments))
|
||||
|
||||
def reset(self):
|
||||
self._segments.clear()
|
||||
self._next_id = 1
|
||||
with self._lock:
|
||||
self._stopped = True
|
||||
self._segments.clear()
|
||||
self._next_id = 1
|
||||
|
||||
Reference in New Issue
Block a user