add whisperx support
This commit is contained in:
@@ -73,7 +73,7 @@ class CacheManager:
|
||||
if not self.use_cache or self.skip_cache_frames or not self.frames_dir.exists():
|
||||
return None
|
||||
|
||||
existing_frames = list(self.frames_dir.glob("frame_*.jpg"))
|
||||
existing_frames = list(self.frames_dir.glob("*.jpg"))
|
||||
|
||||
if not existing_frames:
|
||||
return None
|
||||
|
||||
@@ -70,8 +70,10 @@ class TranscriptMerger:
|
||||
for seg in data
|
||||
]
|
||||
|
||||
# Group by interval if requested
|
||||
if group_interval and segments:
|
||||
# Group by interval if requested, but skip if we have speaker diarization
|
||||
# (merge_transcripts will group by speaker instead)
|
||||
has_speakers = any(seg.get('speaker') for seg in segments)
|
||||
if group_interval and segments and not has_speakers:
|
||||
segments = self.group_audio_by_intervals(segments, group_interval)
|
||||
|
||||
return segments
|
||||
@@ -164,13 +166,14 @@ class TranscriptMerger:
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Merge audio and screen transcripts by timestamp.
|
||||
Groups consecutive audio from same speaker until a screen frame interrupts.
|
||||
|
||||
Args:
|
||||
audio_segments: List of audio transcript segments
|
||||
screen_segments: List of screen OCR segments
|
||||
|
||||
Returns:
|
||||
Merged list sorted by timestamp
|
||||
Merged list sorted by timestamp, with audio grouped by speaker
|
||||
"""
|
||||
# Mark segment types
|
||||
for seg in audio_segments:
|
||||
@@ -182,7 +185,46 @@ class TranscriptMerger:
|
||||
all_segments = audio_segments + screen_segments
|
||||
all_segments.sort(key=lambda x: x['timestamp'])
|
||||
|
||||
return all_segments
|
||||
# Group consecutive audio segments by speaker (screen frames break groups)
|
||||
grouped = []
|
||||
current_group = None
|
||||
|
||||
for seg in all_segments:
|
||||
if seg['type'] == 'screen':
|
||||
# Screen frame: flush current group and add frame
|
||||
if current_group:
|
||||
grouped.append(current_group)
|
||||
current_group = None
|
||||
grouped.append(seg)
|
||||
else:
|
||||
# Audio segment
|
||||
speaker = seg.get('speaker')
|
||||
if current_group is None:
|
||||
# Start new group
|
||||
current_group = {
|
||||
'timestamp': seg['timestamp'],
|
||||
'text': seg['text'],
|
||||
'speaker': speaker,
|
||||
'type': 'audio'
|
||||
}
|
||||
elif speaker == current_group.get('speaker'):
|
||||
# Same speaker, append text
|
||||
current_group['text'] += ' ' + seg['text']
|
||||
else:
|
||||
# Speaker changed, flush and start new group
|
||||
grouped.append(current_group)
|
||||
current_group = {
|
||||
'timestamp': seg['timestamp'],
|
||||
'text': seg['text'],
|
||||
'speaker': speaker,
|
||||
'type': 'audio'
|
||||
}
|
||||
|
||||
# Don't forget last group
|
||||
if current_group:
|
||||
grouped.append(current_group)
|
||||
|
||||
return grouped
|
||||
|
||||
def format_for_claude(
|
||||
self,
|
||||
|
||||
@@ -4,6 +4,7 @@ Coordinates frame extraction, analysis, and transcript merging.
|
||||
"""
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import shutil
|
||||
from typing import Dict, Any, Optional
|
||||
@@ -32,6 +33,7 @@ class WorkflowConfig:
|
||||
# Whisper options
|
||||
self.run_whisper = kwargs.get('run_whisper', False)
|
||||
self.whisper_model = kwargs.get('whisper_model', 'medium')
|
||||
self.diarize = kwargs.get('diarize', False)
|
||||
|
||||
# Frame extraction
|
||||
self.scene_detection = kwargs.get('scene_detection', False)
|
||||
@@ -176,18 +178,27 @@ class ProcessingWorkflow:
|
||||
if cached:
|
||||
return str(cached)
|
||||
|
||||
# If no cache and not running whisper, use provided transcript path (if any)
|
||||
if not self.config.run_whisper:
|
||||
# If no cache and not running whisper/diarize, use provided transcript path (if any)
|
||||
if not self.config.run_whisper and not self.config.diarize:
|
||||
return self.config.transcript_path
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info("STEP 0: Running Whisper Transcription")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Check if whisperx is installed
|
||||
# Determine which transcription tool to use
|
||||
use_diarize = getattr(self.config, 'diarize', False)
|
||||
|
||||
if use_diarize:
|
||||
if not shutil.which("whisperx"):
|
||||
logger.error("WhisperX is not installed. Install it with: pip install whisperx")
|
||||
raise RuntimeError("WhisperX not installed")
|
||||
raise RuntimeError("WhisperX not installed (required for --diarize)")
|
||||
transcribe_cmd = "whisperx"
|
||||
else:
|
||||
if not shutil.which("whisper"):
|
||||
logger.error("Whisper is not installed. Install it with: pip install openai-whisper")
|
||||
raise RuntimeError("Whisper not installed")
|
||||
transcribe_cmd = "whisper"
|
||||
|
||||
# Unload Ollama model to free GPU memory for Whisper (if using vision)
|
||||
if self.config.use_vision:
|
||||
@@ -199,21 +210,34 @@ class ProcessingWorkflow:
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not unload Ollama model: {e}")
|
||||
|
||||
if use_diarize:
|
||||
logger.info(f"Running WhisperX transcription with diarization (model: {self.config.whisper_model})...")
|
||||
else:
|
||||
logger.info(f"Running Whisper transcription (model: {self.config.whisper_model})...")
|
||||
logger.info("This may take a few minutes depending on video length...")
|
||||
|
||||
# Run whisperx command with diarization
|
||||
# Build command
|
||||
cmd = [
|
||||
"whisperx",
|
||||
transcribe_cmd,
|
||||
str(self.config.video_path),
|
||||
"--model", self.config.whisper_model,
|
||||
"--output_format", "json",
|
||||
"--output_dir", str(self.output_mgr.output_dir),
|
||||
"--diarize",
|
||||
]
|
||||
if use_diarize:
|
||||
cmd.append("--diarize")
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
# Set up environment with cuDNN library path for whisperx
|
||||
env = os.environ.copy()
|
||||
if use_diarize:
|
||||
import site
|
||||
site_packages = site.getsitepackages()[0]
|
||||
cudnn_path = Path(site_packages) / "nvidia" / "cudnn" / "lib"
|
||||
if cudnn_path.exists():
|
||||
env["LD_LIBRARY_PATH"] = str(cudnn_path) + ":" + env.get("LD_LIBRARY_PATH", "")
|
||||
|
||||
subprocess.run(cmd, check=True, capture_output=True, text=True, env=env)
|
||||
|
||||
transcript_path = self.output_mgr.get_path(f"{self.config.video_path.stem}.json")
|
||||
|
||||
|
||||
@@ -72,6 +72,11 @@ Examples:
|
||||
help='Whisper model to use (default: medium)',
|
||||
default='medium'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--diarize',
|
||||
action='store_true',
|
||||
help='Use WhisperX with speaker diarization (requires whisperx and HuggingFace token)'
|
||||
)
|
||||
|
||||
# Output options
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user