add whisperx support

This commit is contained in:
Mariano Gabriel
2025-12-03 06:48:45 -03:00
parent 7b919beda6
commit 7d7ec15ff7
4 changed files with 87 additions and 16 deletions

View File

@@ -73,7 +73,7 @@ class CacheManager:
if not self.use_cache or self.skip_cache_frames or not self.frames_dir.exists(): if not self.use_cache or self.skip_cache_frames or not self.frames_dir.exists():
return None return None
existing_frames = list(self.frames_dir.glob("frame_*.jpg")) existing_frames = list(self.frames_dir.glob("*.jpg"))
if not existing_frames: if not existing_frames:
return None return None

View File

@@ -70,8 +70,10 @@ class TranscriptMerger:
for seg in data for seg in data
] ]
# Group by interval if requested # Group by interval if requested, but skip if we have speaker diarization
if group_interval and segments: # (merge_transcripts will group by speaker instead)
has_speakers = any(seg.get('speaker') for seg in segments)
if group_interval and segments and not has_speakers:
segments = self.group_audio_by_intervals(segments, group_interval) segments = self.group_audio_by_intervals(segments, group_interval)
return segments return segments
@@ -164,13 +166,14 @@ class TranscriptMerger:
) -> List[Dict]: ) -> List[Dict]:
""" """
Merge audio and screen transcripts by timestamp. Merge audio and screen transcripts by timestamp.
Groups consecutive audio from same speaker until a screen frame interrupts.
Args: Args:
audio_segments: List of audio transcript segments audio_segments: List of audio transcript segments
screen_segments: List of screen OCR segments screen_segments: List of screen OCR segments
Returns: Returns:
Merged list sorted by timestamp Merged list sorted by timestamp, with audio grouped by speaker
""" """
# Mark segment types # Mark segment types
for seg in audio_segments: for seg in audio_segments:
@@ -182,7 +185,46 @@ class TranscriptMerger:
all_segments = audio_segments + screen_segments all_segments = audio_segments + screen_segments
all_segments.sort(key=lambda x: x['timestamp']) all_segments.sort(key=lambda x: x['timestamp'])
return all_segments # Group consecutive audio segments by speaker (screen frames break groups)
grouped = []
current_group = None
for seg in all_segments:
if seg['type'] == 'screen':
# Screen frame: flush current group and add frame
if current_group:
grouped.append(current_group)
current_group = None
grouped.append(seg)
else:
# Audio segment
speaker = seg.get('speaker')
if current_group is None:
# Start new group
current_group = {
'timestamp': seg['timestamp'],
'text': seg['text'],
'speaker': speaker,
'type': 'audio'
}
elif speaker == current_group.get('speaker'):
# Same speaker, append text
current_group['text'] += ' ' + seg['text']
else:
# Speaker changed, flush and start new group
grouped.append(current_group)
current_group = {
'timestamp': seg['timestamp'],
'text': seg['text'],
'speaker': speaker,
'type': 'audio'
}
# Don't forget last group
if current_group:
grouped.append(current_group)
return grouped
def format_for_claude( def format_for_claude(
self, self,

View File

@@ -4,6 +4,7 @@ Coordinates frame extraction, analysis, and transcript merging.
""" """
from pathlib import Path from pathlib import Path
import logging import logging
import os
import subprocess import subprocess
import shutil import shutil
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
@@ -32,6 +33,7 @@ class WorkflowConfig:
# Whisper options # Whisper options
self.run_whisper = kwargs.get('run_whisper', False) self.run_whisper = kwargs.get('run_whisper', False)
self.whisper_model = kwargs.get('whisper_model', 'medium') self.whisper_model = kwargs.get('whisper_model', 'medium')
self.diarize = kwargs.get('diarize', False)
# Frame extraction # Frame extraction
self.scene_detection = kwargs.get('scene_detection', False) self.scene_detection = kwargs.get('scene_detection', False)
@@ -176,18 +178,27 @@ class ProcessingWorkflow:
if cached: if cached:
return str(cached) return str(cached)
# If no cache and not running whisper, use provided transcript path (if any) # If no cache and not running whisper/diarize, use provided transcript path (if any)
if not self.config.run_whisper: if not self.config.run_whisper and not self.config.diarize:
return self.config.transcript_path return self.config.transcript_path
logger.info("=" * 80) logger.info("=" * 80)
logger.info("STEP 0: Running Whisper Transcription") logger.info("STEP 0: Running Whisper Transcription")
logger.info("=" * 80) logger.info("=" * 80)
# Check if whisperx is installed # Determine which transcription tool to use
if not shutil.which("whisperx"): use_diarize = getattr(self.config, 'diarize', False)
logger.error("WhisperX is not installed. Install it with: pip install whisperx")
raise RuntimeError("WhisperX not installed") if use_diarize:
if not shutil.which("whisperx"):
logger.error("WhisperX is not installed. Install it with: pip install whisperx")
raise RuntimeError("WhisperX not installed (required for --diarize)")
transcribe_cmd = "whisperx"
else:
if not shutil.which("whisper"):
logger.error("Whisper is not installed. Install it with: pip install openai-whisper")
raise RuntimeError("Whisper not installed")
transcribe_cmd = "whisper"
# Unload Ollama model to free GPU memory for Whisper (if using vision) # Unload Ollama model to free GPU memory for Whisper (if using vision)
if self.config.use_vision: if self.config.use_vision:
@@ -199,21 +210,34 @@ class ProcessingWorkflow:
except Exception as e: except Exception as e:
logger.warning(f"Could not unload Ollama model: {e}") logger.warning(f"Could not unload Ollama model: {e}")
logger.info(f"Running WhisperX transcription with diarization (model: {self.config.whisper_model})...") if use_diarize:
logger.info(f"Running WhisperX transcription with diarization (model: {self.config.whisper_model})...")
else:
logger.info(f"Running Whisper transcription (model: {self.config.whisper_model})...")
logger.info("This may take a few minutes depending on video length...") logger.info("This may take a few minutes depending on video length...")
# Run whisperx command with diarization # Build command
cmd = [ cmd = [
"whisperx", transcribe_cmd,
str(self.config.video_path), str(self.config.video_path),
"--model", self.config.whisper_model, "--model", self.config.whisper_model,
"--output_format", "json", "--output_format", "json",
"--output_dir", str(self.output_mgr.output_dir), "--output_dir", str(self.output_mgr.output_dir),
"--diarize",
] ]
if use_diarize:
cmd.append("--diarize")
try: try:
subprocess.run(cmd, check=True, capture_output=True, text=True) # Set up environment with cuDNN library path for whisperx
env = os.environ.copy()
if use_diarize:
import site
site_packages = site.getsitepackages()[0]
cudnn_path = Path(site_packages) / "nvidia" / "cudnn" / "lib"
if cudnn_path.exists():
env["LD_LIBRARY_PATH"] = str(cudnn_path) + ":" + env.get("LD_LIBRARY_PATH", "")
subprocess.run(cmd, check=True, capture_output=True, text=True, env=env)
transcript_path = self.output_mgr.get_path(f"{self.config.video_path.stem}.json") transcript_path = self.output_mgr.get_path(f"{self.config.video_path.stem}.json")

View File

@@ -72,6 +72,11 @@ Examples:
help='Whisper model to use (default: medium)', help='Whisper model to use (default: medium)',
default='medium' default='medium'
) )
parser.add_argument(
'--diarize',
action='store_true',
help='Use WhisperX with speaker diarization (requires whisperx and HuggingFace token)'
)
# Output options # Output options
parser.add_argument( parser.add_argument(