import spaces
import gradio as gr
import torch
import torchaudio
import numpy as np
import pandas as pd
import time
import datetime
import re
import subprocess
import os
import tempfile
from transformers import pipeline
from pyannote.audio import Pipeline

# Install flash attention for acceleration (currently disabled)
'''
try:
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=True
    )
except subprocess.CalledProcessError:
    print("Warning: Could not install flash-attn, falling back to default attention")
'''

# Create global pipeline (similar to working HuggingFace example)
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device="cuda",
    model_kwargs={"attn_implementation": "flash_attention_2"},
    return_timestamps=True,
)


def comprehensive_flash_attention_verification():
    """Comprehensive verification of the flash attention setup."""
    print("šŸ” Running Flash Attention Verification...")
    print("=" * 50)

    verification_results = {}

    # Check 1: Package installation
    print("šŸ” Checking Python packages...")
    try:
        import flash_attn
        print(f"āœ… flash-attn: {flash_attn.__version__}")
        verification_results["flash_attn_installed"] = True
    except ImportError:
        print("āŒ flash-attn: Not installed")
        verification_results["flash_attn_installed"] = False

    try:
        import transformers
        print(f"āœ… transformers: {transformers.__version__}")
        verification_results["transformers_available"] = True
    except ImportError:
        print("āŒ transformers: Not installed")
        verification_results["transformers_available"] = False

    # Check 2: CUDA availability
    print("\nšŸ” Checking CUDA availability...")
    cuda_available = torch.cuda.is_available()
    print(f"āœ… CUDA available: {cuda_available}")
    if cuda_available:
        print(f"āœ… CUDA version: {torch.version.cuda}")
        print(f"āœ… GPU count: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            print(f"āœ… GPU {i}: {torch.cuda.get_device_name(i)}")
    verification_results["cuda_available"] = cuda_available

    # Check 3: Flash attention import
    print("\nšŸ” Testing flash attention imports...")
    try:
        from flash_attn import flash_attn_func
        print("āœ… flash_attn_func imported successfully")
        if flash_attn_func is None:
            print("āŒ flash_attn_func is None")
            verification_results["flash_attn_import"] = False
        else:
            print("āœ… flash_attn_func is callable")
            verification_results["flash_attn_import"] = True
    except ImportError as e:
        print(f"āŒ Import error: {e}")
        verification_results["flash_attn_import"] = False
    except Exception as e:
        print(f"āŒ Unexpected error: {e}")
        verification_results["flash_attn_import"] = False

    # Check 4: Flash attention functionality test
    print("\nšŸ” Testing flash attention functionality...")
    if not cuda_available:
        print("āš ļø Skipping functionality test - CUDA not available")
        verification_results["flash_attn_functional"] = False
    elif not verification_results.get("flash_attn_import", False):
        print("āš ļø Skipping functionality test - Import failed")
        verification_results["flash_attn_functional"] = False
    else:
        try:
            from flash_attn import flash_attn_func

            # Create small dummy tensors
            batch_size, seq_len, num_heads, head_dim = 1, 16, 4, 32
            device = "cuda:0"
            dtype = torch.float16

            print(f"Creating tensors: batch={batch_size}, seq_len={seq_len}, heads={num_heads}, dim={head_dim}")
            q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
            k = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
            v = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
            print("āœ… Tensors created successfully")

            # Test flash attention
            output = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
            print(f"āœ… Flash attention output shape: {output.shape}")
            print("āœ… Flash attention test passed!")
            verification_results["flash_attn_functional"] = True
        except Exception as e:
            print(f"āŒ Flash attention test failed: {e}")
            import traceback
            traceback.print_exc()
            verification_results["flash_attn_functional"] = False

    # Summary
    print("\n" + "=" * 50)
    print("šŸ“Š VERIFICATION SUMMARY")
    print("=" * 50)

    all_passed = True
    for check_name, result in verification_results.items():
        status = "āœ… PASS" if result else "āŒ FAIL"
        print(f"{check_name}: {status}")
        if not result:
            all_passed = False

    if all_passed:
        print("\nšŸŽ‰ All checks passed! Flash attention should work.")
        return True
    else:
        print("\nāš ļø Some checks failed. Flash attention may not work properly.")
        print("\nRecommendations:")
        print("1. Try reinstalling flash-attn: pip uninstall flash-attn && pip install flash-attn --no-build-isolation")
        print("2. Check CUDA compatibility with your PyTorch version")
        print("3. Consider using default attention as fallback")
        return False


class WhisperTranscriber:
    def __init__(self):
        self.pipe = pipe  # Use the global pipeline
        self.diarization_model = None

    # @spaces.GPU
    def setup_models(self):
        """Initialize models with GPU acceleration."""
        if self.pipe is None:
            print("Loading Whisper model...")
            self.pipe = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-large-v3-turbo",
                torch_dtype=torch.float16,
                device="cuda:0",
                model_kwargs={"attn_implementation": "flash_attention_2"},
                return_timestamps=True,
            )

        if self.diarization_model is None:
            print("Loading diarization model...")
            # Note: You'll need to set up authentication for pyannote models.
            # For demo purposes, we handle the case where it's not available.
            try:
                self.diarization_model = Pipeline.from_pretrained(
                    "pyannote/speaker-diarization-3.1",
                    use_auth_token=os.getenv("HF_TOKEN")
                ).to(torch.device("cuda"))
            except Exception as e:
                print(f"Could not load diarization model: {e}")
                self.diarization_model = None

    def convert_audio_format(self, audio_path):
        """Convert audio to 16 kHz mono WAV format."""
        temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        temp_wav_path = temp_wav.name
        temp_wav.close()

        try:
            subprocess.run([
                "ffmpeg", "-i", audio_path,
                "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le",
                temp_wav_path, "-y"
            ], check=True, capture_output=True)
            return temp_wav_path
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Audio conversion failed: {e}")

    @spaces.GPU
    def transcribe_audio(self, audio_path, language=None, translate=False, prompt=None):
        """Transcribe audio using Whisper with flash attention."""
        # Optionally run the comprehensive flash attention verification first:
        # flash_attention_working = comprehensive_flash_attention_verification()
        # if not flash_attention_working:
        #     print("āš ļø Flash attention verification failed, but proceeding with transcription...")
        #     print("You may encounter the TypeError: 'NoneType' object is not callable error")

        '''
        # if self.pipe is None:
        #     self.setup_models()
        if next(self.pipe.model.parameters()).device.type != "cuda":
            self.pipe.model.to("cuda")
        '''

        print("Starting transcription...")
        start_time = time.time()

        # Prepare generation kwargs
        generate_kwargs = {}
        if language:
            generate_kwargs["language"] = language
        if translate:
            generate_kwargs["task"] = "translate"
        if prompt:
            generate_kwargs["prompt_ids"] = self.pipe.tokenizer.encode(prompt)

        # Transcribe with timestamps
        result = self.pipe(
            audio_path,
            return_timestamps=True,
            generate_kwargs=generate_kwargs,
            chunk_length_s=30,
            batch_size=128,
        )

        transcription_time = time.time() - start_time
        print(f"Transcription completed in {transcription_time:.2f} seconds")

        # Extract segments and detected language
        segments = []
        if "chunks" in result:
            for chunk in result["chunks"]:
                segment = {
                    "start": float(chunk["timestamp"][0] or 0),
                    "end": float(chunk["timestamp"][1] or 0),
                    "text": chunk["text"].strip(),
                }
                segments.append(segment)
        else:
            # Fallback for results without chunk-level timestamps
            segments = [{
                "start": 0.0,
                "end": 0.0,
                "text": result["text"]
            }]

        # The pipeline returns a dict, so use .get() rather than getattr()
        detected_language = result.get("language", language or "unknown")

        transcription_time = time.time() - start_time
        print(f"Transcription parse completed in {transcription_time:.2f} seconds")

        return segments, detected_language

    def perform_diarization(self, audio_path, num_speakers=None):
        """Perform speaker diarization."""
        if self.diarization_model is None:
            print("Diarization model not available, assigning single speaker")
            return [], 1

        print("Starting diarization...")
        start_time = time.time()

        # Load audio for diarization
        waveform, sample_rate = torchaudio.load(audio_path)

        # Perform diarization
        diarization = self.diarization_model(
            {"waveform": waveform, "sample_rate": sample_rate},
            num_speakers=num_speakers,
        )

        # Convert to list format
        diarize_segments = []
        diarization_list = list(diarization.itertracks(yield_label=True))

        for turn, _, speaker in diarization_list:
            diarize_segments.append({
                "start": turn.start,
                "end": turn.end,
                "speaker": speaker
            })

        unique_speakers = {speaker for _, _, speaker in diarization_list}
        detected_num_speakers = len(unique_speakers)

        diarization_time = time.time() - start_time
        print(f"Diarization completed in {diarization_time:.2f} seconds")

        return diarize_segments, detected_num_speakers

    def merge_transcription_and_diarization(self, transcription_segments, diarization_segments):
        """Merge transcription segments with speaker information."""
        if not diarization_segments:
            # No diarization available, assign a single speaker
            for segment in transcription_segments:
                segment["speaker"] = "SPEAKER_00"
            return transcription_segments

        print("Merging transcription and diarization...")
        diarize_df = pd.DataFrame(diarization_segments)
        final_segments = []

        for segment in transcription_segments:
            # Calculate intersection with diarization segments
            diarize_df["intersection"] = np.maximum(
                0,
                np.minimum(diarize_df["end"], segment["end"])
                - np.maximum(diarize_df["start"], segment["start"])
            )

            # Find the speaker with the maximum total intersection
            dia_tmp = diarize_df[diarize_df["intersection"] > 0]
            if len(dia_tmp) > 0:
                speaker = (
                    dia_tmp.groupby("speaker")["intersection"]
                    .sum()
                    .sort_values(ascending=False)
                    .index[0]
                )
            else:
                speaker = "SPEAKER_00"

            segment["speaker"] = speaker
            segment["duration"] = segment["end"] - segment["start"]
            final_segments.append(segment)

        return final_segments

    def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
        """Group consecutive segments from the same speaker."""
        if not segments:
            return segments

        grouped_segments = []
        current_group = segments[0].copy()
        sentence_end_pattern = r"[.!?]+\s*$"

        for segment in segments[1:]:
            time_gap = segment["start"] - current_group["end"]
            current_duration = current_group["end"] - current_group["start"]

            # Conditions for combining segments
            can_combine = (
                segment["speaker"] == current_group["speaker"]
                and time_gap <= max_gap
                and current_duration < max_duration
                and not re.search(sentence_end_pattern, current_group["text"])
            )

            if can_combine:
                # Merge segments
                current_group["end"] = segment["end"]
                current_group["text"] += " " + segment["text"]
                current_group["duration"] = current_group["end"] - current_group["start"]
            else:
                # Start a new group
                grouped_segments.append(current_group)
                current_group = segment.copy()

        grouped_segments.append(current_group)

        # Clean up text
        for segment in grouped_segments:
            segment["text"] = re.sub(r"\s+", " ", segment["text"]).strip()
            segment["text"] = re.sub(r"\s+([.,!?])", r"\1", segment["text"])

        return grouped_segments

    @spaces.GPU
    def process_audio(self, audio_file, num_speakers=None, language=None,
                      translate=False, prompt=None, group_segments=True):
        """Main processing function."""
        if audio_file is None:
            return {"error": "No audio file provided"}

        try:
            # Setup models if not already done
            # self.setup_models()

            # Convert audio format
            # wav_path = self.convert_audio_format(audio_file)

            try:
                # Transcribe audio
                transcription_segments, detected_language = self.transcribe_audio(
                    audio_file, language, translate, prompt
                )

                # Perform diarization
                diarization_segments, detected_num_speakers = self.perform_diarization(
                    audio_file, num_speakers
                )

                # Merge transcription and diarization
                final_segments = self.merge_transcription_and_diarization(
                    transcription_segments, diarization_segments
                )

                # Group segments if requested
                if group_segments:
                    final_segments = self.group_segments_by_speaker(final_segments)

                return {
                    "segments": final_segments,
                    "language": detected_language,
                    "num_speakers": detected_num_speakers or 1,
                    "total_segments": len(final_segments)
                }
            finally:
                # Clean up temporary file
                if os.path.exists(audio_file):
                    os.unlink(audio_file)

        except Exception as e:
            import traceback
            traceback.print_exc()
            return {"error": f"Processing failed: {str(e)}"}


# Initialize transcriber
transcriber = WhisperTranscriber()


def format_segments_for_display(result):
    """Format segments for display in Gradio."""
    if "error" in result:
        return f"āŒ Error: {result['error']}"

    segments = result.get("segments", [])
    language = result.get("language", "unknown")
    num_speakers = result.get("num_speakers", 1)

    output = "šŸŽÆ **Detection Results:**\n"
    output += f"- Language: {language}\n"
    output += f"- Speakers: {num_speakers}\n"
    output += f"- Segments: {len(segments)}\n\n"
    output += "šŸ“ **Transcription:**\n\n"

    for segment in segments:
        start_time = str(datetime.timedelta(seconds=int(segment["start"])))
        end_time = str(datetime.timedelta(seconds=int(segment["end"])))
        speaker = segment.get("speaker", "SPEAKER_00")
        text = segment["text"]

        output += f"**{speaker}** ({start_time} → {end_time})\n"
        output += f"{text}\n\n"

    return output


@spaces.GPU
def process_audio_gradio(audio_file, num_speakers, language, translate, prompt, group_segments):
    """Gradio interface function."""
    result = transcriber.process_audio(
        audio_file=audio_file,
        num_speakers=num_speakers if num_speakers > 0 else None,
        language=language if language != "auto" else None,
        translate=translate,
        prompt=prompt if prompt and prompt.strip() else None,
        group_segments=group_segments
    )

    formatted_output = format_segments_for_display(result)
    return formatted_output, result


# Create Gradio interface
demo = gr.Blocks(
    title="šŸŽ™ļø Whisper Transcription with Speaker Diarization",
    theme="default"
)

with demo:
    gr.Markdown("""
    # šŸŽ™ļø Advanced Audio Transcription & Speaker Diarization

    Upload an audio file to get accurate transcription with speaker identification, powered by:
    - **Whisper Large V3 Turbo** with Flash Attention for fast transcription
    - **Pyannote 3.1** for speaker diarization
    - **ZeroGPU** acceleration for optimal performance
    """)

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="šŸŽµ Upload Audio File",
                type="filepath",
                # source="upload"
            )

            with gr.Accordion("āš™ļø Advanced Settings", open=False):
                num_speakers = gr.Slider(
                    minimum=0,
                    maximum=20,
                    value=0,
                    step=1,
                    label="Number of Speakers (0 = auto-detect)"
                )

                language = gr.Dropdown(
                    choices=["auto", "en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
                    value="auto",
                    label="Language"
                )

                translate = gr.Checkbox(
                    label="Translate to English",
                    value=False
                )

                prompt = gr.Textbox(
                    label="Vocabulary Prompt (names, acronyms, etc.)",
                    placeholder="Enter names, technical terms, or context...",
                    lines=2
                )

                group_segments = gr.Checkbox(
                    label="Group segments by speaker",
                    value=True
                )

            process_btn = gr.Button("šŸš€ Transcribe Audio", variant="primary")

        with gr.Column():
            output_text = gr.Markdown(
                label="šŸ“ Transcription Results",
                value="Upload an audio file and click 'Transcribe Audio' to get started!"
            )

            output_json = gr.JSON(
                label="šŸ”§ Raw Output (JSON)",
                visible=False
            )

    # Event handlers
    process_btn.click(
        fn=process_audio_gradio,
        inputs=[
            audio_input,
            num_speakers,
            language,
            translate,
            prompt,
            group_segments
        ],
        outputs=[output_text, output_json]
    )

    # Usage tips
    gr.Markdown("### šŸ“‹ Usage Tips:")
    gr.Markdown("""
    - **Supported formats**: MP3, WAV, M4A, FLAC, OGG, and more
    - **Max duration**: Recommended under 10 minutes for optimal performance
    - **Speaker detection**: Works best with clear, distinct voices
    - **Languages**: Supports 100+ languages with auto-detection
    - **Vocabulary**: Add names and technical terms in the prompt for better accuracy
    """)

if __name__ == "__main__":
    demo.launch(debug=True)
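
# ---------------------------------------------------------------------------
# Programmatic usage sketch (kept as comments so it never runs on import).
# It shows how the transcriber defined above can be driven without the Gradio
# UI. The path "sample.wav", the language, and the prompt text are illustrative
# assumptions, not part of the app. Note that process_audio() deletes its input
# file in a finally block, so pass a disposable copy if you want to keep the
# original audio.
#
#   result = transcriber.process_audio(
#       audio_file="sample.wav",      # hypothetical local file
#       num_speakers=None,            # None = auto-detect
#       language="en",
#       translate=False,
#       prompt="Acme Corp, Dr. Smith",
#       group_segments=True,
#   )
#   print(format_segments_for_display(result))
# ---------------------------------------------------------------------------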