liuyang committed on
Commit 8c68b8b · 1 Parent(s): d441278

Add full audio transcription functionality and update Gradio interface

Files changed (1)
  1. app.py +146 -27
app.py CHANGED
@@ -52,7 +52,6 @@ except OSError as e:
 _whisper = None
 _diarizer = None
 
-
 # Create global diarization pipeline
 try:
     print("Loading diarization model...")
@@ -63,17 +62,8 @@ try:
     _diarizer = Pipeline.from_pretrained(
         "pyannote/speaker-diarization-3.1",
         use_auth_token=os.getenv("HF_TOKEN"),
-        #torch_dtype=torch.float16,
     ).to(torch.device("cuda"))
-    '''
-    _diarizer.model.half()  # FP16
-
-    for m in _diarizer.model.modules():  # compact LSTM weights
-        if isinstance(m, torch.nn.LSTM):
-            m.flatten_parameters()
-
-    _diarizer.model = torch.compile(_diarizer.model, mode="reduce-overhead")
-    '''
+
     print("Diarization model loaded successfully")
 except Exception as e:
     import traceback
@@ -116,6 +106,68 @@ class WhisperTranscriber:
         except subprocess.CalledProcessError as e:
             raise RuntimeError(f"Audio conversion failed: {e}")
 
+    @spaces.GPU  # each call gets a GPU slice
+    def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None):
+        """Transcribe the entire audio file without speaker diarization"""
+        whisper, _ = _load_models()  # models live on the GPU
+
+        print("Transcribing full audio...")
+        start_time = time.time()
+
+        # Prepare options
+        options = dict(
+            language=language,
+            beam_size=5,
+            vad_filter=True,
+            vad_parameters=VadOptions(
+                max_speech_duration_s=whisper.feature_extractor.chunk_length,
+                min_speech_duration_ms=100,
+                speech_pad_ms=100,
+                threshold=0.25,
+                neg_threshold=0.2,
+            ),
+            word_timestamps=True,
+            initial_prompt=prompt,
+            language_detection_segments=1,
+            task="translate" if translate else "transcribe",
+        )
+
+        # Transcribe the entire audio
+        segments, transcript_info = whisper.transcribe(audio_path, **options)
+        segments = list(segments)
+
+        detected_language = transcript_info.language
+
+        # Process segments
+        results = []
+        for seg in segments:
+            # Create result entry with detailed format
+            words_list = []
+            if seg.words:
+                for word in seg.words:
+                    words_list.append({
+                        "start": float(word.start),
+                        "end": float(word.end),
+                        "word": word.word,
+                        "probability": word.probability,
+                        "speaker": "SPEAKER_00"  # No speaker identification in full transcription
+                    })
+
+            results.append({
+                "start": float(seg.start),
+                "end": float(seg.end),
+                "text": seg.text,
+                "speaker": "SPEAKER_00",  # Single speaker assumption
+                "avg_logprob": seg.avg_logprob,
+                "words": words_list,
+                "duration": float(seg.end - seg.start)
+            })
+
+        transcription_time = time.time() - start_time
+        print(f"Full audio transcribed in {transcription_time:.2f} seconds")
+
+        return results, detected_language
+
     def cut_audio_segments(self, audio_path, diarization_segments):
         """Cut audio into segments based on diarization results"""
         print("Cutting audio into segments...")
@@ -309,6 +361,47 @@ class WhisperTranscriber:
 
         return grouped_segments
 
+    @spaces.GPU  # each call gets a GPU slice
+    def process_audio_full(self, audio_file, language=None, translate=False, prompt=None, group_segments=True):
+        """Process audio with full transcription (no speaker diarization)"""
+        if audio_file is None:
+            return {"error": "No audio file provided"}
+
+        converted_audio_path = None
+        try:
+            print("Starting full transcription pipeline...")
+
+            # Step 1: Convert audio format
+            print("Converting audio format...")
+            converted_audio_path = self.convert_audio_format(audio_file)
+
+            # Step 2: Transcribe the entire audio
+            transcription_results, detected_language = self.transcribe_full_audio(
+                converted_audio_path, language, translate, prompt
+            )
+
+            # Step 3: Group segments if requested (based on time gaps and sentence endings)
+            if group_segments:
+                transcription_results = self.group_segments_by_speaker(transcription_results)
+
+            # Step 4: Return results
+            return {
+                "segments": transcription_results,
+                "language": detected_language,
+                "num_speakers": 1,  # Single speaker assumption
+                "transcription_method": "full_audio"
+            }
+
+        except Exception as e:
+            import traceback
+            traceback.print_exc()
+            return {"error": f"Processing failed: {str(e)}"}
+        finally:
+            # Clean up converted audio file
+            if converted_audio_path and os.path.exists(converted_audio_path):
+                os.unlink(converted_audio_path)
+                print("Cleaned up converted audio file")
+
     @spaces.GPU  # each call gets a GPU slice
     def process_audio(self, audio_file, num_speakers=None, language=None,
                       translate=False, prompt=None, group_segments=True):
@@ -345,7 +438,8 @@ class WhisperTranscriber:
             return {
                 "segments": transcription_results,
                 "language": detected_language,
-                "num_speakers": detected_num_speakers
+                "num_speakers": detected_num_speakers,
+                "transcription_method": "diarized_segments"
             }
 
         except Exception as e:
@@ -369,11 +463,13 @@ def format_segments_for_display(result):
     segments = result.get("segments", [])
     language = result.get("language", "unknown")
    num_speakers = result.get("num_speakers", 1)
+    method = result.get("transcription_method", "unknown")
 
     output = f"🎯 **Detection Results:**\n"
     output += f"- Language: {language}\n"
     output += f"- Speakers: {num_speakers}\n"
-    output += f"- Segments: {len(segments)}\n\n"
+    output += f"- Segments: {len(segments)}\n"
+    output += f"- Method: {method}\n\n"
 
     output += "📝 **Transcription:**\n\n"
 
@@ -389,16 +485,25 @@ def format_segments_for_display(result):
     return output
 
 @spaces.GPU
-def process_audio_gradio(audio_file, num_speakers, language, translate, prompt, group_segments):
+def process_audio_gradio(audio_file, num_speakers, language, translate, prompt, group_segments, use_diarization):
     """Gradio interface function"""
-    result = transcriber.process_audio(
-        audio_file=audio_file,
-        num_speakers=num_speakers if num_speakers > 0 else None,
-        language=language if language != "auto" else None,
-        translate=translate,
-        prompt=prompt if prompt and prompt.strip() else None,
-        group_segments=group_segments
-    )
+    if use_diarization:
+        result = transcriber.process_audio(
+            audio_file=audio_file,
+            num_speakers=num_speakers if num_speakers > 0 else None,
+            language=language if language != "auto" else None,
+            translate=translate,
+            prompt=prompt if prompt and prompt.strip() else None,
+            group_segments=group_segments
+        )
+    else:
+        result = transcriber.process_audio_full(
+            audio_file=audio_file,
+            language=language if language != "auto" else None,
+            translate=translate,
+            prompt=prompt if prompt and prompt.strip() else None,
+            group_segments=group_segments
+        )
 
     formatted_output = format_segments_for_display(result)
     return formatted_output, result
@@ -424,16 +529,22 @@ with demo:
             audio_input = gr.Audio(
                 label="🎵 Upload Audio File",
                 type="filepath",
-                #source="upload"
             )
 
             with gr.Accordion("⚙️ Advanced Settings", open=False):
+                use_diarization = gr.Checkbox(
+                    label="Enable Speaker Diarization",
+                    value=True,
+                    info="Uncheck for faster transcription without speaker identification"
+                )
+
                 num_speakers = gr.Slider(
                     minimum=0,
                     maximum=20,
                     value=0,
                     step=1,
-                    label="Number of Speakers (0 = auto-detect)"
+                    label="Number of Speakers (0 = auto-detect)",
+                    visible=True
                 )
 
                 language = gr.Dropdown(
@@ -454,7 +565,7 @@
                 )
 
                 group_segments = gr.Checkbox(
-                    label="Group segments by speaker",
+                    label="Group segments by speaker/time",
                     value=True
                 )
 
@@ -471,6 +582,13 @@
                 visible=False
             )
 
+    # Update visibility of num_speakers based on diarization toggle
+    use_diarization.change(
+        fn=lambda x: gr.update(visible=x),
+        inputs=[use_diarization],
+        outputs=[num_speakers]
+    )
+
     # Event handlers
     process_btn.click(
         fn=process_audio_gradio,
@@ -480,7 +598,8 @@
            language,
            translate,
            prompt,
-           group_segments
+           group_segments,
+           use_diarization
        ],
        outputs=[output_text, output_json]
    )
@@ -490,7 +609,7 @@
    gr.Markdown("""
    - **Supported formats**: MP3, WAV, M4A, FLAC, OGG, and more
    - **Max duration**: Recommended under 10 minutes for optimal performance
-    - **Speaker detection**: Works best with clear, distinct voices
+    - **Speaker diarization**: Enable for speaker identification (slower), disable for faster transcription
    - **Languages**: Supports 100+ languages with auto-detection
    - **Vocabulary**: Add names and technical terms in the prompt for better accuracy
    """)