liuyang committed on
Commit 726a091 · 1 Parent(s): 54da026

restore to whisper

Files changed (2):
  1. app.py +55 -178
  2. requirements.txt +2 -2
app.py CHANGED
@@ -37,7 +37,6 @@ import tempfile
 import spaces
 from faster_whisper import WhisperModel, BatchedInferencePipeline
 from faster_whisper.vad import VadOptions
-import whisperx
 import requests
 import base64
 from pyannote.audio import Pipeline, Inference, Model
@@ -133,17 +132,14 @@ MODELS = {
 }
 DEFAULT_MODEL = "large-v3-turbo"
 
-# Supported languages for alignment models (whisperX)
-ALIGN_LANGUAGES = ["en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh", "ar", "nl", "tr", "pl", "cs", "sv", "da", "fi", "no", "uk"]
-
 def _download_model(model_name: str):
-    """Downloads a faster-whisper model from the hub if not already present."""
+    """Downloads a model from the hub if not already present."""
     if model_name not in MODELS:
         raise ValueError(f"Model '{model_name}' not found in MODELS registry.")
 
     model_info = MODELS[model_name]
     if not os.path.exists(model_info["local_dir"]):
-        print(f"Downloading faster-whisper model '{model_name}' from {model_info['repo_id']}...")
+        print(f"Downloading model '{model_name}' from {model_info['repo_id']}...")
         snapshot_download(
             repo_id=model_info["repo_id"],
             local_dir=model_info["local_dir"],
@@ -152,11 +148,9 @@ def _download_model(model_name: str):
         )
     return model_info["local_dir"]
 
-# Download all faster-whisper models on startup
-print("Downloading all faster-whisper models...")
+# Download the default model on startup
 for model in MODELS:
     _download_model(model)
-print("All faster-whisper models downloaded!")
 
 
 # -----------------------------------------------------------------------------
@@ -384,32 +378,13 @@ def _process_single_chunk(task: dict, out_dir: str) -> dict:
 # Lazy global holder ----------------------------------------------------------
 _whisper_models = {}
 _batched_whisper_models = {}
+_whipser_x_transcribe_models = {}
 _whipser_x_align_models = {}
 
 _diarizer = None
 _embedder = None
 
-# Preload WhisperX alignment models at startup (no GPU decorator needed)
-print("Preloading all WhisperX alignment models...")
-for lang in ALIGN_LANGUAGES:
-    try:
-        print(f"Loading alignment model for language '{lang}'...")
-        device = "cuda"
-
-        align_model, align_metadata = whisperx.load_align_model(
-            language_code=lang,
-            device=device,
-            model_dir=CACHE_ROOT
-        )
-        _whipser_x_align_models[lang] = {
-            "model": align_model,
-            "metadata": align_metadata
-        }
-        print(f"Alignment model for '{lang}' loaded successfully")
-    except Exception as e:
-        print(f"Could not load alignment model for '{lang}': {e}")
-
-# Create global diarization pipeline at startup
+# Create global diarization pipeline
 try:
     print("Loading diarization model...")
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -428,24 +403,15 @@ except Exception as e:
     print(f"Could not load diarization model: {e}")
     _diarizer = None
 
-print("WhisperX alignment and diarization models preloaded successfully!")
-
 @spaces.GPU  # GPU is guaranteed to exist *inside* this function
-def _load_faster_whisper_model(model_name: str):
-    """Load a specific faster-whisper model on GPU (lazy loading)"""
-    global _whisper_models, _batched_whisper_models
-
-    if model_name in _whisper_models:
-        print(f"Faster-whisper model '{model_name}' already loaded")
-        return _whisper_models[model_name], _batched_whisper_models[model_name]
-
-    if model_name not in MODELS:
-        raise ValueError(f"Model '{model_name}' not found in MODELS registry. Available: {list(MODELS.keys())}")
+def _load_models(model_name: str = DEFAULT_MODEL):
+    global _whisper_models, _batched_whisper_models, _diarizer
 
-    print(f"Loading faster-whisper model '{model_name}' on GPU...")
-    model_cache_path = _download_model(model_name)
-
-    try:
+    if model_name not in _whisper_models:
+        print(f"Loading Whisper model '{model_name}'...")
+
+        model_cache_path = _download_model(model_name)
+
         model = WhisperModel(
             model_cache_path,
             device="cuda",
@@ -458,24 +424,12 @@ def _load_faster_whisper_model(model_name: str):
         _whisper_models[model_name] = model
         _batched_whisper_models[model_name] = batched_model
 
-        print(f"Faster-whisper model '{model_name}' and batched pipeline loaded successfully")
-        return model, batched_model
-    except Exception as e:
-        import traceback
-        traceback.print_exc()
-        raise RuntimeError(f"Could not load faster-whisper model '{model_name}': {e}")
-
-# Optional: Preload all faster-whisper models explicitly
-@spaces.GPU
-def preload_all_whisper_models():
-    """Preload all faster-whisper models - optional, for faster first-time use"""
-    print("Preloading all faster-whisper models...")
-    for model_name in MODELS.keys():
-        try:
-            _load_faster_whisper_model(model_name)
-        except Exception as e:
-            print(f"Failed to preload model '{model_name}': {e}")
-    print("All faster-whisper models preloaded!")
+        print(f"Whisper model '{model_name}' and batched pipeline loaded successfully")
+
+    whisper = _whisper_models[model_name]
+    batched_whisper = _batched_whisper_models[model_name]
+
+    return whisper, batched_whisper, _diarizer
 
 # -----------------------------------------------------------------------------
 class WhisperTranscriber:
@@ -504,18 +458,10 @@ class WhisperTranscriber:
 
     @spaces.GPU  # each call gets a GPU slice
     def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16, base_offset_s: float = 0.0, clip_timestamps=None, model_name: str = DEFAULT_MODEL, transcribe_options: dict = None):
-        """Transcribe the entire audio file using faster-whisper, then align with WhisperX"""
-        global _whisper_models, _batched_whisper_models, _whipser_x_align_models
+        """Transcribe the entire audio file without speaker diarization using batched inference"""
+        whisper, batched_whisper, _ = _load_models(model_name)  # models live on the GPU
 
-        # Get preloaded faster-whisper model, or load it if not available
-        if model_name not in _whisper_models:
-            print(f"Faster-whisper model '{model_name}' not preloaded, loading now...")
-            _load_faster_whisper_model(model_name)
-
-        whisper = _whisper_models[model_name]
-        batched_whisper = _batched_whisper_models[model_name]
-
-        print(f"Transcribing full audio with faster-whisper '{model_name}' and batch size {batch_size}...")
+        print(f"Transcribing full audio with '{model_name}' and batch size {batch_size}...")
         start_time = time.time()
 
         # Prepare options for batched inference
@@ -528,125 +474,65 @@ class WhisperTranscriber:
             language_detection_segments=1,
             task="translate" if translate else "transcribe",
         )
-
         if clip_timestamps:
             options["vad_filter"] = False
             options["clip_timestamps"] = clip_timestamps
         else:
-            vad_options = transcribe_options.get("vad_parameters", None) if transcribe_options else None
+            vad_options = transcribe_options.get("vad_parameters", None)
             options["vad_filter"] = True  # VAD is enabled by default for batched transcription
             options["vad_parameters"] = VadOptions(**vad_options) if vad_options else VadOptions(
                 max_speech_duration_s=whisper.feature_extractor.chunk_length,
                 min_speech_duration_ms=180,  # ignore ultra-short blips
-                min_silence_duration_ms=120,  # split on short Mandarin pauses (if supported)
+                min_silence_duration_ms=120,  # split on short Mandarin pauses (if supported)
                 speech_pad_ms=120,
                 threshold=0.35,
                 neg_threshold=0.2,
             )
-
         if batch_size > 1:
             # Use batched inference for better performance
             segments, transcript_info = batched_whisper.transcribe(
-                audio_path,
-                batch_size=batch_size,
+                audio_path,
+                batch_size=batch_size,
                 **options
             )
         else:
             segments, transcript_info = whisper.transcribe(
-                audio_path,
+                audio_path,
                 **options
             )
         segments = list(segments)
 
         detected_language = transcript_info.language
-        print(f"Detected language: {detected_language}, segments: {len(segments)}, transcribing done in {time.time() - start_time:.2f} seconds")
-
-        # Align with WhisperX if alignment model is available
-        aligned_segments = segments
-        if detected_language in _whipser_x_align_models:
-            print(f"Performing WhisperX alignment for language '{detected_language}'...")
-            align_start = time.time()
-            try:
-                # Load audio for whisperx alignment
-                audio = whisperx.load_audio(audio_path)
-
-                # Convert faster-whisper segments to whisperx format
-                whisperx_segments = []
-                for seg in segments:
-                    whisperx_segments.append({
-                        "start": seg.start,
-                        "end": seg.end,
-                        "text": seg.text
-                    })
-
-                align_info = _whipser_x_align_models[detected_language]
-                result = whisperx.align(
-                    whisperx_segments,
-                    align_info["model"],
-                    align_info["metadata"],
-                    audio,
-                    "cuda",
-                    return_char_alignments=False
-                )
-                aligned_segments = result.get("segments", segments)
-                print(f"WhisperX alignment completed in {time.time() - align_start:.2f} seconds")
-            except Exception as e:
-                print(f"WhisperX alignment failed: {e}, using original timestamps")
-                aligned_segments = segments
-        else:
-            print(f"No WhisperX alignment model available for language '{detected_language}', using faster-whisper timestamps")
+        print("Detected language: ", detected_language, "segments: ", len(segments))
 
-        # Process segments into the expected format
+        # Process segments
         results = []
-        for i, seg in enumerate(aligned_segments):
-            # Check if this is a whisperx aligned segment (dict) or faster-whisper segment (object)
-            if isinstance(seg, dict):
-                # WhisperX aligned segment
-                words_list = []
-                if "words" in seg:
-                    for word in seg["words"]:
-                        words_list.append({
-                            "start": float(word.get("start", 0.0)) + float(base_offset_s),
-                            "end": float(word.get("end", 0.0)) + float(base_offset_s),
-                            "word": word.get("word", ""),
-                            "probability": word.get("score", 1.0),
-                            "speaker": "SPEAKER_00"
-                        })
-
-                results.append({
-                    "start": float(seg.get("start", 0.0)) + float(base_offset_s),
-                    "end": float(seg.get("end", 0.0)) + float(base_offset_s),
-                    "text": seg.get("text", ""),
-                    "speaker": "SPEAKER_00",
-                    "avg_logprob": segments[i].avg_logprob if i < len(segments) else 0.0,
-                    "words": words_list,
-                    "duration": float(seg.get("end", 0.0)) - float(seg.get("start", 0.0))
-                })
-            else:
-                # Faster-whisper segment (not aligned)
-                words_list = []
-                if seg.words:
-                    for word in seg.words:
-                        words_list.append({
-                            "start": float(word.start) + float(base_offset_s),
-                            "end": float(word.end) + float(base_offset_s),
-                            "word": word.word,
-                            "probability": word.probability,
-                            "speaker": "SPEAKER_00"
-                        })
-
-                results.append({
-                    "start": float(seg.start) + float(base_offset_s),
-                    "end": float(seg.end) + float(base_offset_s),
-                    "text": seg.text,
-                    "speaker": "SPEAKER_00",
-                    "avg_logprob": seg.avg_logprob,
-                    "words": words_list,
-                    "duration": float(seg.end - seg.start)
-                })
+        for seg in segments:
+            # Create result entry with detailed format
+            words_list = []
+            if seg.words:
+                for word in seg.words:
+                    words_list.append({
+                        "start": float(word.start) + float(base_offset_s),
+                        "end": float(word.end) + float(base_offset_s),
+                        "word": word.word,
+                        "probability": word.probability,
+                        "speaker": "SPEAKER_00"  # No speaker identification in full transcription
+                    })
+
+            results.append({
+                "start": float(seg.start) + float(base_offset_s),
+                "end": float(seg.end) + float(base_offset_s),
+                "text": seg.text,
+                "speaker": "SPEAKER_00",  # Single speaker assumption
+                "avg_logprob": seg.avg_logprob,
+                "words": words_list,
+                "duration": float(seg.end - seg.start)
+            })
 
         transcription_time = time.time() - start_time
-        print(f"Full audio transcribed and aligned in {transcription_time:.2f} seconds using batch size {batch_size}")
+        print(f"Full audio transcribed in {transcription_time:.2f} seconds using batch size {batch_size}")
+        print(results)
         return results, detected_language
 
     # Removed audio cutting; transcription is done once on the full (preprocessed) audio
@@ -654,9 +540,9 @@ class WhisperTranscriber:
     @spaces.GPU  # each call gets a GPU slice
     def perform_diarization(self, audio_path, num_speakers=None, base_offset_s: float = 0.0):
         """Perform speaker diarization; return segments with global timestamps and per-speaker embeddings."""
-        global _diarizer
+        _, _, diarizer = _load_models()  # models live on the GPU
 
-        if _diarizer is None:
+        if diarizer is None:
             print("Diarization model not available, creating single speaker segment")
             # Load audio to get duration
             waveform, sample_rate = torchaudio.load(audio_path)
@@ -689,7 +575,7 @@ class WhisperTranscriber:
         waveform, sample_rate = torchaudio.load(audio_path)
 
         # Perform diarization
-        diarization = _diarizer(
+        diarization = diarizer(
             {"waveform": waveform, "sample_rate": sample_rate},
             num_speakers=num_speakers,
         )
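
The hunk above only swaps the global `_diarizer` for the locally returned `diarizer`; for orientation, here is a minimal, hedged sketch of how a pyannote diarization result is typically consumed (the checkpoint name, audio path, and speaker count are illustrative assumptions, not values from this commit):

# Illustrative sketch only: run a pyannote pipeline and read back (start, end, speaker) turns.
import torchaudio
from pyannote.audio import Pipeline

pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")  # assumed checkpoint; gated, needs an HF access token
waveform, sample_rate = torchaudio.load("audio.wav")                     # placeholder path
diarization = pipeline({"waveform": waveform, "sample_rate": sample_rate}, num_speakers=2)

for turn, _, speaker in diarization.itertracks(yield_label=True):
    print(f"{turn.start:.2f}s - {turn.end:.2f}s: {speaker}")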
@@ -1604,14 +1490,5 @@ with demo:
     - Vocabulary: Add names and technical terms in the prompt for better accuracy
     """)
 
-    # Preload all whisper models once at service initialization
-    print("Preloading all WhisperX transcribe models at startup...")
-    try:
-        preload_all_whisper_models()
-        print("All WhisperX transcribe models preloaded at startup!")
-    except Exception as e:
-        print(f"Warning: Could not preload whisper models at startup: {e}")
-        print("Models will be loaded on first use instead.")
-
 if __name__ == "__main__":
     demo.launch(debug=True)
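
Taken together, the app.py changes drop the WhisperX alignment pass and return to a plain faster-whisper pipeline: load the CTranslate2 model once, wrap it in a BatchedInferencePipeline, and transcribe with VAD filtering. A minimal sketch of that flow, with an illustrative model size, audio path, and parameters rather than the app's exact values:

# Hedged sketch of the restored faster-whisper path (not the app's exact code).
from faster_whisper import WhisperModel, BatchedInferencePipeline
from faster_whisper.vad import VadOptions

model = WhisperModel("large-v3-turbo", device="cuda", compute_type="float16")  # size string assumed available in this faster-whisper release
batched = BatchedInferencePipeline(model=model)

vad = VadOptions(min_speech_duration_ms=180, speech_pad_ms=120, threshold=0.35)
segments, info = batched.transcribe(
    "audio.wav",           # placeholder path
    batch_size=16,
    vad_filter=True,
    vad_parameters=vad,
    word_timestamps=True,
)
segments = list(segments)  # the generator must be consumed for transcription to actually run
print(info.language, len(segments))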
 
requirements.txt CHANGED
@@ -5,8 +5,8 @@ transformers==4.48.0
 pydantic==2.10.6
 
 # 2. Main whisper model - using whisperx instead of faster-whisper
-ctranslate2==4.4.0
-whisperx
+faster-whisper==1.1.1
+ctranslate2==4.5.0
 torch
 
 # 3. Extra libs your app really needs
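
As a quick, optional sanity check after installing the pins above (assuming both packages expose __version__, which recent releases do):

import faster_whisper
import ctranslate2

print(faster_whisper.__version__, ctranslate2.__version__)  # expected to match the pins: 1.1.1 and 4.5.0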