add whisperx support
This commit is contained in:
@@ -73,7 +73,7 @@ class CacheManager:
|
|||||||
if not self.use_cache or self.skip_cache_frames or not self.frames_dir.exists():
|
if not self.use_cache or self.skip_cache_frames or not self.frames_dir.exists():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
existing_frames = list(self.frames_dir.glob("frame_*.jpg"))
|
existing_frames = list(self.frames_dir.glob("*.jpg"))
|
||||||
|
|
||||||
if not existing_frames:
|
if not existing_frames:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -70,8 +70,10 @@ class TranscriptMerger:
|
|||||||
for seg in data
|
for seg in data
|
||||||
]
|
]
|
||||||
|
|
||||||
# Group by interval if requested
|
# Group by interval if requested, but skip if we have speaker diarization
|
||||||
if group_interval and segments:
|
# (merge_transcripts will group by speaker instead)
|
||||||
|
has_speakers = any(seg.get('speaker') for seg in segments)
|
||||||
|
if group_interval and segments and not has_speakers:
|
||||||
segments = self.group_audio_by_intervals(segments, group_interval)
|
segments = self.group_audio_by_intervals(segments, group_interval)
|
||||||
|
|
||||||
return segments
|
return segments
|
||||||
@@ -164,13 +166,14 @@ class TranscriptMerger:
|
|||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
Merge audio and screen transcripts by timestamp.
|
Merge audio and screen transcripts by timestamp.
|
||||||
|
Groups consecutive audio from same speaker until a screen frame interrupts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio_segments: List of audio transcript segments
|
audio_segments: List of audio transcript segments
|
||||||
screen_segments: List of screen OCR segments
|
screen_segments: List of screen OCR segments
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Merged list sorted by timestamp
|
Merged list sorted by timestamp, with audio grouped by speaker
|
||||||
"""
|
"""
|
||||||
# Mark segment types
|
# Mark segment types
|
||||||
for seg in audio_segments:
|
for seg in audio_segments:
|
||||||
@@ -182,7 +185,46 @@ class TranscriptMerger:
|
|||||||
all_segments = audio_segments + screen_segments
|
all_segments = audio_segments + screen_segments
|
||||||
all_segments.sort(key=lambda x: x['timestamp'])
|
all_segments.sort(key=lambda x: x['timestamp'])
|
||||||
|
|
||||||
return all_segments
|
# Group consecutive audio segments by speaker (screen frames break groups)
|
||||||
|
grouped = []
|
||||||
|
current_group = None
|
||||||
|
|
||||||
|
for seg in all_segments:
|
||||||
|
if seg['type'] == 'screen':
|
||||||
|
# Screen frame: flush current group and add frame
|
||||||
|
if current_group:
|
||||||
|
grouped.append(current_group)
|
||||||
|
current_group = None
|
||||||
|
grouped.append(seg)
|
||||||
|
else:
|
||||||
|
# Audio segment
|
||||||
|
speaker = seg.get('speaker')
|
||||||
|
if current_group is None:
|
||||||
|
# Start new group
|
||||||
|
current_group = {
|
||||||
|
'timestamp': seg['timestamp'],
|
||||||
|
'text': seg['text'],
|
||||||
|
'speaker': speaker,
|
||||||
|
'type': 'audio'
|
||||||
|
}
|
||||||
|
elif speaker == current_group.get('speaker'):
|
||||||
|
# Same speaker, append text
|
||||||
|
current_group['text'] += ' ' + seg['text']
|
||||||
|
else:
|
||||||
|
# Speaker changed, flush and start new group
|
||||||
|
grouped.append(current_group)
|
||||||
|
current_group = {
|
||||||
|
'timestamp': seg['timestamp'],
|
||||||
|
'text': seg['text'],
|
||||||
|
'speaker': speaker,
|
||||||
|
'type': 'audio'
|
||||||
|
}
|
||||||
|
|
||||||
|
# Don't forget last group
|
||||||
|
if current_group:
|
||||||
|
grouped.append(current_group)
|
||||||
|
|
||||||
|
return grouped
|
||||||
|
|
||||||
def format_for_claude(
|
def format_for_claude(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ Coordinates frame extraction, analysis, and transcript merging.
|
|||||||
"""
|
"""
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import shutil
|
import shutil
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
@@ -32,6 +33,7 @@ class WorkflowConfig:
|
|||||||
# Whisper options
|
# Whisper options
|
||||||
self.run_whisper = kwargs.get('run_whisper', False)
|
self.run_whisper = kwargs.get('run_whisper', False)
|
||||||
self.whisper_model = kwargs.get('whisper_model', 'medium')
|
self.whisper_model = kwargs.get('whisper_model', 'medium')
|
||||||
|
self.diarize = kwargs.get('diarize', False)
|
||||||
|
|
||||||
# Frame extraction
|
# Frame extraction
|
||||||
self.scene_detection = kwargs.get('scene_detection', False)
|
self.scene_detection = kwargs.get('scene_detection', False)
|
||||||
@@ -176,18 +178,27 @@ class ProcessingWorkflow:
|
|||||||
if cached:
|
if cached:
|
||||||
return str(cached)
|
return str(cached)
|
||||||
|
|
||||||
# If no cache and not running whisper, use provided transcript path (if any)
|
# If no cache and not running whisper/diarize, use provided transcript path (if any)
|
||||||
if not self.config.run_whisper:
|
if not self.config.run_whisper and not self.config.diarize:
|
||||||
return self.config.transcript_path
|
return self.config.transcript_path
|
||||||
|
|
||||||
logger.info("=" * 80)
|
logger.info("=" * 80)
|
||||||
logger.info("STEP 0: Running Whisper Transcription")
|
logger.info("STEP 0: Running Whisper Transcription")
|
||||||
logger.info("=" * 80)
|
logger.info("=" * 80)
|
||||||
|
|
||||||
# Check if whisperx is installed
|
# Determine which transcription tool to use
|
||||||
|
use_diarize = getattr(self.config, 'diarize', False)
|
||||||
|
|
||||||
|
if use_diarize:
|
||||||
if not shutil.which("whisperx"):
|
if not shutil.which("whisperx"):
|
||||||
logger.error("WhisperX is not installed. Install it with: pip install whisperx")
|
logger.error("WhisperX is not installed. Install it with: pip install whisperx")
|
||||||
raise RuntimeError("WhisperX not installed")
|
raise RuntimeError("WhisperX not installed (required for --diarize)")
|
||||||
|
transcribe_cmd = "whisperx"
|
||||||
|
else:
|
||||||
|
if not shutil.which("whisper"):
|
||||||
|
logger.error("Whisper is not installed. Install it with: pip install openai-whisper")
|
||||||
|
raise RuntimeError("Whisper not installed")
|
||||||
|
transcribe_cmd = "whisper"
|
||||||
|
|
||||||
# Unload Ollama model to free GPU memory for Whisper (if using vision)
|
# Unload Ollama model to free GPU memory for Whisper (if using vision)
|
||||||
if self.config.use_vision:
|
if self.config.use_vision:
|
||||||
@@ -199,21 +210,34 @@ class ProcessingWorkflow:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not unload Ollama model: {e}")
|
logger.warning(f"Could not unload Ollama model: {e}")
|
||||||
|
|
||||||
|
if use_diarize:
|
||||||
logger.info(f"Running WhisperX transcription with diarization (model: {self.config.whisper_model})...")
|
logger.info(f"Running WhisperX transcription with diarization (model: {self.config.whisper_model})...")
|
||||||
|
else:
|
||||||
|
logger.info(f"Running Whisper transcription (model: {self.config.whisper_model})...")
|
||||||
logger.info("This may take a few minutes depending on video length...")
|
logger.info("This may take a few minutes depending on video length...")
|
||||||
|
|
||||||
# Run whisperx command with diarization
|
# Build command
|
||||||
cmd = [
|
cmd = [
|
||||||
"whisperx",
|
transcribe_cmd,
|
||||||
str(self.config.video_path),
|
str(self.config.video_path),
|
||||||
"--model", self.config.whisper_model,
|
"--model", self.config.whisper_model,
|
||||||
"--output_format", "json",
|
"--output_format", "json",
|
||||||
"--output_dir", str(self.output_mgr.output_dir),
|
"--output_dir", str(self.output_mgr.output_dir),
|
||||||
"--diarize",
|
|
||||||
]
|
]
|
||||||
|
if use_diarize:
|
||||||
|
cmd.append("--diarize")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
subprocess.run(cmd, check=True, capture_output=True, text=True)
|
# Set up environment with cuDNN library path for whisperx
|
||||||
|
env = os.environ.copy()
|
||||||
|
if use_diarize:
|
||||||
|
import site
|
||||||
|
site_packages = site.getsitepackages()[0]
|
||||||
|
cudnn_path = Path(site_packages) / "nvidia" / "cudnn" / "lib"
|
||||||
|
if cudnn_path.exists():
|
||||||
|
env["LD_LIBRARY_PATH"] = str(cudnn_path) + ":" + env.get("LD_LIBRARY_PATH", "")
|
||||||
|
|
||||||
|
subprocess.run(cmd, check=True, capture_output=True, text=True, env=env)
|
||||||
|
|
||||||
transcript_path = self.output_mgr.get_path(f"{self.config.video_path.stem}.json")
|
transcript_path = self.output_mgr.get_path(f"{self.config.video_path.stem}.json")
|
||||||
|
|
||||||
|
|||||||
@@ -72,6 +72,11 @@ Examples:
|
|||||||
help='Whisper model to use (default: medium)',
|
help='Whisper model to use (default: medium)',
|
||||||
default='medium'
|
default='medium'
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--diarize',
|
||||||
|
action='store_true',
|
||||||
|
help='Use WhisperX with speaker diarization (requires whisperx and HuggingFace token)'
|
||||||
|
)
|
||||||
|
|
||||||
# Output options
|
# Output options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user