liuyang committed on
Commit
70438f0
·
1 Parent(s): 7c60c3b

Implement audio preprocessing and speaker diarization enhancements in WhisperTranscriber. Introduce methods for audio chunk preparation, VAD-based trimming, and speaker embedding extraction. Update process_audio methods to utilize task JSON for improved workflow and metadata handling. Add webrtcvad dependency for voice activity detection.
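
The new prepare_and_save_audio_for_model helper consumes a per-chunk task dict. Below is a minimal sketch of that structure, reconstructed only from the fields the diff actually reads (source_uri, ingest_recipe.channel_extract_filter, ffmpeg_seek.pre_ss_sec/post_ss_sec/t_sec, channel, chunk.idx/dur_ms/global_offset_ms, job_id); every value is illustrative, not a fixed schema.

# Hypothetical per-chunk task JSON, limited to the fields that
# prepare_and_save_audio_for_model() reads in this commit; the URI,
# filter, and numbers are placeholders, not a documented schema.
example_task = {
    "job_id": "demo-job",
    "source_uri": "https://example.com/recording.wav",  # any ffmpeg-readable source
    "channel": "L",                                      # label used in output file names
    "ingest_recipe": {
        "channel_extract_filter": "pan=mono|c0=c0",      # optional ffmpeg -af filter
    },
    "ffmpeg_seek": {
        "pre_ss_sec": 600.0,   # coarse seek placed before -i
        "post_ss_sec": 0.0,    # fine seek placed after -i
        "t_sec": 300.0,        # duration to decode for this chunk
    },
    "chunk": {
        "idx": 3,
        "dur_ms": 300_000,           # planned chunk duration
        "global_offset_ms": 600_000, # chunk start within the full recording
    },
}

# The Gradio textbox receives this as a JSON string, e.g.:
# transcriber.process_audio(task_json=json.dumps(example_task), batch_size=8)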

Files changed (2)
  1. app.py +365 -187
  2. requirements.txt +2 -1
app.py CHANGED
@@ -32,9 +32,11 @@ from faster_whisper import WhisperModel, BatchedInferencePipeline
32
  from faster_whisper.vad import VadOptions
33
  import requests
34
  import base64
35
- from pyannote.audio import Pipeline
36
 
37
  import os, sys, importlib.util, pathlib, ctypes, tempfile, wave, math
 
 
38
  spec = importlib.util.find_spec("nvidia.cudnn")
39
  if spec is None:
40
  sys.exit("❌ nvidia-cudnn-cu12 wheel not found. Run: pip install nvidia-cudnn-cu12")
@@ -53,6 +55,183 @@ from huggingface_hub import snapshot_download
53
  MODEL_REPO = "deepdml/faster-whisper-large-v3-turbo-ct2" # CT2 format
54
  LOCAL_DIR = f"{CACHE_ROOT}/whisper_turbo"
55
 
56
  # Download once; later runs are instant
57
  snapshot_download(
58
  repo_id=MODEL_REPO,
@@ -66,6 +245,7 @@ model_cache_path = LOCAL_DIR # <-- this is what we pass to WhisperModel
66
  _whisper = None
67
  _batched_whisper = None
68
  _diarizer = None
 
69
 
70
  # Create global diarization pipeline
71
  try:
@@ -108,24 +288,20 @@ class WhisperTranscriber:
108
  # do **not** create the models here!
109
  pass
110
 
111
- def convert_audio_format(self, audio_path):
112
- """Convert audio to 16kHz mono WAV format"""
113
- temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
114
- temp_wav_path = temp_wav.name
115
- temp_wav.close()
116
-
117
  try:
118
- subprocess.run([
119
- "ffmpeg", "-i", audio_path,
120
- "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le",
121
- temp_wav_path, "-y"
122
- ], check=True, capture_output=True)
123
- return temp_wav_path
124
- except subprocess.CalledProcessError as e:
125
- raise RuntimeError(f"Audio conversion failed: {e}")
126
 
127
  @spaces.GPU # each call gets a GPU slice
128
- def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16):
129
  """Transcribe the entire audio file without speaker diarization using batched inference"""
130
  whisper, batched_whisper, _ = _load_models() # models live on the GPU
131
 
@@ -168,16 +344,16 @@ class WhisperTranscriber:
168
  if seg.words:
169
  for word in seg.words:
170
  words_list.append({
171
- "start": float(word.start),
172
- "end": float(word.end),
173
  "word": word.word,
174
  "probability": word.probability,
175
  "speaker": "SPEAKER_00" # No speaker identification in full transcription
176
  })
177
 
178
  results.append({
179
- "start": float(seg.start),
180
- "end": float(seg.end),
181
  "text": seg.text,
182
  "speaker": "SPEAKER_00", # Single speaker assumption
183
  "avg_logprob": seg.avg_logprob,
@@ -190,118 +366,14 @@ class WhisperTranscriber:
190
  print(results)
191
  return results, detected_language
192
 
193
- def cut_audio_segments(self, audio_path, diarization_segments):
194
- """Cut audio into segments based on diarization results"""
195
- print("Cutting audio into segments...")
196
-
197
- # Load the full audio
198
- waveform, sample_rate = torchaudio.load(audio_path)
199
-
200
- audio_segments = []
201
- for segment in diarization_segments:
202
- start_sample = int(segment["start"] * sample_rate)
203
- end_sample = int(segment["end"] * sample_rate)
204
-
205
- # Extract the segment
206
- segment_waveform = waveform[:, start_sample:end_sample]
207
-
208
- # Create temporary file for this segment
209
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
210
- temp_file.close()
211
-
212
- # Save the segment
213
- torchaudio.save(temp_file.name, segment_waveform, sample_rate)
214
-
215
- audio_segments.append({
216
- "audio_path": temp_file.name,
217
- "start": segment["start"],
218
- "end": segment["end"],
219
- "speaker": segment["speaker"]
220
- })
221
-
222
- return audio_segments
223
 
224
  @spaces.GPU # each call gets a GPU slice
225
- def transcribe_audio_segments(self, audio_segments, language=None, translate=False, prompt=None, batch_size=8):
226
- """Transcribe multiple audio segments using faster_whisper with batching"""
227
- whisper, batched_whisper, _ = _load_models() # models live on the GPU
228
-
229
- print(f"Transcribing {len(audio_segments)} audio segments with batch size {batch_size}...")
230
- start_time = time.time()
231
-
232
- # Prepare options
233
- options = dict(
234
- language=language,
235
- beam_size=5,
236
- vad_filter=True,
237
- vad_parameters=VadOptions(
238
- max_speech_duration_s=whisper.feature_extractor.chunk_length,
239
- min_speech_duration_ms=100,
240
- speech_pad_ms=100,
241
- threshold=0.25,
242
- neg_threshold=0.2,
243
- ),
244
- word_timestamps=True,
245
- initial_prompt=prompt,
246
- language_detection_segments=1,
247
- task="translate" if translate else "transcribe",
248
- )
249
-
250
- results = []
251
- detected_language = None
252
-
253
- for i, segment in enumerate(audio_segments):
254
- print(f"Processing segment {i+1}/{len(audio_segments)}")
255
-
256
- # Use batched inference for each segment
257
- segments, transcript_info = batched_whisper.transcribe(
258
- segment["audio_path"],
259
- batch_size=batch_size,
260
- **options
261
- )
262
- segments = list(segments)
263
-
264
- # Get detected language from first segment
265
- if detected_language is None:
266
- detected_language = transcript_info.language
267
-
268
- # Process each transcribed segment
269
- for seg in segments:
270
- # Create result entry with detailed format
271
- words_list = []
272
- if seg.words:
273
- for word in seg.words:
274
- words_list.append({
275
- "start": float(word.start) + segment["start"],
276
- "end": float(word.end) + segment["start"],
277
- "word": word.word,
278
- "probability": word.probability,
279
- "speaker": segment["speaker"]
280
- })
281
-
282
- results.append({
283
- "start": float(seg.start) + segment["start"],
284
- "end": float(seg.end) + segment["start"],
285
- "text": seg.text,
286
- "speaker": segment["speaker"],
287
- "avg_logprob": seg.avg_logprob,
288
- "words": words_list,
289
- "duration": float(seg.end - seg.start)
290
- })
291
-
292
- # Clean up temporary files
293
- for segment in audio_segments:
294
- if os.path.exists(segment["audio_path"]):
295
- os.unlink(segment["audio_path"])
296
-
297
- transcription_time = time.time() - start_time
298
- print(f"All segments transcribed in {transcription_time:.2f} seconds using batch size {batch_size}")
299
-
300
- return results, detected_language
301
 
302
  @spaces.GPU # each call gets a GPU slice
303
- def perform_diarization(self, audio_path, num_speakers=None):
304
- """Perform speaker diarization"""
305
  _, _, diarizer = _load_models() # models live on the GPU
306
 
307
  if diarizer is None:
@@ -309,11 +381,20 @@ class WhisperTranscriber:
309
  # Load audio to get duration
310
  waveform, sample_rate = torchaudio.load(audio_path)
311
  duration = waveform.shape[1] / sample_rate
312
  return [{
313
- "start": 0.0,
314
- "end": duration,
315
  "speaker": "SPEAKER_00"
316
- }], 1
317
 
318
  print("Starting diarization...")
319
  start_time = time.time()
@@ -330,21 +411,100 @@ class WhisperTranscriber:
330
  # Convert to list format
331
  diarize_segments = []
332
  diarization_list = list(diarization.itertracks(yield_label=True))
333
-
334
  for turn, _, speaker in diarization_list:
335
  diarize_segments.append({
336
- "start": turn.start,
337
- "end": turn.end,
338
  "speaker": speaker
339
  })
340
 
341
  unique_speakers = {speaker for segment in diarize_segments for speaker in [segment["speaker"]]}
342
  detected_num_speakers = len(unique_speakers)
343
 
344
  diarization_time = time.time() - start_time
345
  print(f"Diarization completed in {diarization_time:.2f} seconds")
346
 
347
- return diarize_segments, detected_num_speakers
348
 
349
  def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
350
  """Group consecutive segments from the same speaker"""
@@ -388,22 +548,27 @@ class WhisperTranscriber:
388
  return grouped_segments
389
 
390
  @spaces.GPU # each call gets a GPU slice
391
- def process_audio_full(self, audio_file, language=None, translate=False, prompt=None, group_segments=True, batch_size=16):
392
- """Process audio with full transcription (no speaker diarization)"""
393
- if audio_file is None:
394
- return {"error": "No audio file provided"}
395
-
396
- converted_audio_path = None
397
  try:
398
  print("Starting full transcription pipeline...")
399
 
400
- # Step 1: Convert audio format
401
- print("Converting audio format...")
402
- converted_audio_path = self.convert_audio_format(audio_file)
403
 
404
  # Step 2: Transcribe the entire audio with batching
405
  transcription_results, detected_language = self.transcribe_full_audio(
406
- converted_audio_path, language, translate, prompt, batch_size
407
  )
408
 
409
  # Step 3: Group segments if requested (based on time gaps and sentence endings)
@@ -424,38 +589,47 @@ class WhisperTranscriber:
424
  traceback.print_exc()
425
  return {"error": f"Processing failed: {str(e)}"}
426
  finally:
427
- # Clean up converted audio file
428
- if converted_audio_path and os.path.exists(converted_audio_path):
429
- os.unlink(converted_audio_path)
430
- print("Cleaned up converted audio file")
 
 
431
 
432
  @spaces.GPU # each call gets a GPU slice
433
- def process_audio(self, audio_file, num_speakers=None, language=None,
434
  translate=False, prompt=None, group_segments=True, batch_size=8):
435
- """Main processing function - diarization first, then transcription"""
436
- if audio_file is None:
437
- return {"error": "No audio file provided"}
438
-
439
- converted_audio_path = None
 
 
 
440
  try:
441
  print("Starting new processing pipeline...")
442
 
443
- # Step 1: Convert audio format first
444
- print("Converting audio format...")
445
- converted_audio_path = self.convert_audio_format(audio_file)
446
 
447
- # Step 2: Perform diarization on converted audio
448
- diarization_segments, detected_num_speakers = self.perform_diarization(
449
- converted_audio_path, num_speakers
450
  )
451
-
452
- # Step 3: Cut audio into segments based on diarization
453
- audio_segments = self.cut_audio_segments(converted_audio_path, diarization_segments)
454
-
455
- # Step 4: Transcribe each segment with batching
456
- transcription_results, detected_language = self.transcribe_audio_segments(
457
- audio_segments, language, translate, prompt, batch_size
458
  )
 
 
 
459
 
460
  # Step 5: Group segments if requested
461
  if group_segments:
@@ -467,7 +641,8 @@ class WhisperTranscriber:
467
  "language": detected_language,
468
  "num_speakers": detected_num_speakers,
469
  "transcription_method": "diarized_segments_batched",
470
- "batch_size": batch_size
 
471
  }
472
 
473
  except Exception as e:
@@ -475,10 +650,12 @@ class WhisperTranscriber:
475
  traceback.print_exc()
476
  return {"error": f"Processing failed: {str(e)}"}
477
  finally:
478
- # Clean up converted audio file
479
- if converted_audio_path and os.path.exists(converted_audio_path):
480
- os.unlink(converted_audio_path)
481
- print("Cleaned up converted audio file")
 
 
482
 
483
  # Initialize transcriber
484
  transcriber = WhisperTranscriber()
@@ -515,11 +692,11 @@ def format_segments_for_display(result):
515
  return output
516
 
517
  @spaces.GPU
518
- def process_audio_gradio(audio_file, num_speakers, language, translate, prompt, group_segments, use_diarization, batch_size):
519
  """Gradio interface function"""
520
  if use_diarization:
521
  result = transcriber.process_audio(
522
- audio_file=audio_file,
523
  num_speakers=num_speakers if num_speakers > 0 else None,
524
  language=language if language != "auto" else None,
525
  translate=translate,
@@ -529,7 +706,7 @@ def process_audio_gradio(audio_file, num_speakers, language, translate, prompt,
529
  )
530
  else:
531
  result = transcriber.process_audio_full(
532
- audio_file=audio_file,
533
  language=language if language != "auto" else None,
534
  translate=translate,
535
  prompt=prompt if prompt and prompt.strip() else None,
@@ -558,9 +735,10 @@ with demo:
558
 
559
  with gr.Row():
560
  with gr.Column():
561
- audio_input = gr.Audio(
562
- label="๐ŸŽต Upload Audio File",
563
- type="filepath",
 
564
  )
565
 
566
  with gr.Accordion("⚙️ Advanced Settings", open=False):
@@ -572,7 +750,7 @@ with demo:
572
 
573
  batch_size = gr.Slider(
574
  minimum=1,
575
- maximum=32,
576
  value=16,
577
  step=1,
578
  label="Batch Size",
@@ -615,7 +793,7 @@ with demo:
615
  with gr.Column():
616
  output_text = gr.Markdown(
617
  label="๐Ÿ“ Transcription Results",
618
- value="Upload an audio file and click 'Transcribe Audio' to get started!"
619
  )
620
 
621
  output_json = gr.JSON(
@@ -634,7 +812,7 @@ with demo:
634
  process_btn.click(
635
  fn=process_audio_gradio,
636
  inputs=[
637
- audio_input,
638
  num_speakers,
639
  language,
640
  translate,
@@ -649,11 +827,11 @@ with demo:
649
  # Examples
650
  gr.Markdown("### 📋 Usage Tips:")
651
  gr.Markdown("""
652
- - **Supported formats**: MP3, WAV, M4A, FLAC, OGG, and more
653
- - **Batch Size**: Higher values (16-24) = faster processing but more GPU memory
654
- - **Speaker diarization**: Enable for speaker identification (slower), disable for faster transcription
655
- - **Languages**: Supports 100+ languages with auto-detection
656
- - **Vocabulary**: Add names and technical terms in the prompt for better accuracy
657
  """)
658
 
659
  if __name__ == "__main__":
 
32
  from faster_whisper.vad import VadOptions
33
  import requests
34
  import base64
35
+ from pyannote.audio import Pipeline, Inference
36
 
37
  import os, sys, importlib.util, pathlib, ctypes, tempfile, wave, math
38
+ import json
39
+ import webrtcvad
40
  spec = importlib.util.find_spec("nvidia.cudnn")
41
  if spec is None:
42
  sys.exit("โŒ nvidia-cudnn-cu12 wheel not found. Run: pip install nvidia-cudnn-cu12")
 
55
  MODEL_REPO = "deepdml/faster-whisper-large-v3-turbo-ct2" # CT2 format
56
  LOCAL_DIR = f"{CACHE_ROOT}/whisper_turbo"
57
 
58
+ # -----------------------------------------------------------------------------
59
+ # Audio preprocess helper (from input_and_preprocess rule)
60
+ # -----------------------------------------------------------------------------
61
+
62
+ TRIM_THRESHOLD_MS = 10_000 # 10 seconds
63
+ DEFAULT_PAD_MS = 250 # safety context around detected speech
64
+ FRAME_MS = 30 # VAD frame
65
+ HANG_MS = 240 # hangover (keep speech "on" after silence)
66
+ VAD_LEVEL = 2 # 0-3
67
+
68
+ def _decode_chunk_to_pcm(task: dict) -> bytes:
69
+ """Use ffmpeg to decode the chunk to s16le mono @ 16k PCM bytes."""
70
+ src = task["source_uri"]
71
+ ing = task["ingest_recipe"]
72
+ seek = task["ffmpeg_seek"]
73
+
74
+ cmd = [
75
+ "ffmpeg", "-nostdin", "-hide_banner", "-v", "error",
76
+ "-ss", f"{max(0.0, float(seek['pre_ss_sec'])):.3f}",
77
+ "-i", src,
78
+ "-map", "0:a:0",
79
+ "-ss", f"{float(seek['post_ss_sec']):.2f}",
80
+ "-t", f"{float(seek['t_sec']):.3f}",
81
+ ]
82
+
83
+ # Optional L/R extraction
84
+ if ing.get("channel_extract_filter"):
85
+ cmd += ["-af", ing["channel_extract_filter"]]
86
+
87
+ # Force mono 16k s16le to stdout
88
+ cmd += ["-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le", "-f", "s16le", "pipe:1"]
89
+
90
+ p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
91
+ pcm, err = p.communicate()
92
+ if p.returncode != 0:
93
+ raise RuntimeError(f"ffmpeg failed: {err.decode('utf-8', 'ignore')}")
94
+ return pcm
95
+
96
+ def _find_head_tail_speech_ms(
97
+ pcm: bytes,
98
+ sr: int = 16000,
99
+ frame_ms: int = FRAME_MS,
100
+ vad_level: int = VAD_LEVEL,
101
+ hang_ms: int = HANG_MS,
102
+ ):
103
+ """Return (first_ms, last_ms) speech boundaries using webrtcvad with hangover."""
104
+ if not pcm:
105
+ return None, None
106
+ vad = webrtcvad.Vad(int(vad_level))
107
+ bpf = 2 # bytes per sample (s16)
108
+ samples_per_ms = sr // 1000 # 16
109
+ bytes_per_frame = samples_per_ms * bpf * frame_ms
110
+
111
+ n_frames = len(pcm) // bytes_per_frame
112
+ if n_frames == 0:
113
+ return None, None
114
+
115
+ first_ms, last_ms = None, None
116
+ t_ms = 0
117
+ in_speech = False
118
+ silence_run = 0
119
+
120
+ view = memoryview(pcm)[: n_frames * bytes_per_frame]
121
+ for i in range(n_frames):
122
+ frame = view[i * bytes_per_frame : (i + 1) * bytes_per_frame]
123
+ if vad.is_speech(frame, sr):
124
+ if first_ms is None:
125
+ first_ms = t_ms
126
+ in_speech = True
127
+ silence_run = 0
128
+ else:
129
+ if in_speech:
130
+ silence_run += frame_ms
131
+ if silence_run >= hang_ms:
132
+ last_ms = t_ms - (silence_run - hang_ms)
133
+ in_speech = False
134
+ silence_run = 0
135
+ t_ms += frame_ms
136
+ if in_speech:
137
+ last_ms = t_ms
138
+ return first_ms, last_ms
139
+
140
+ def _write_wav(path: str, pcm: bytes, sr: int = 16000):
141
+ os.makedirs(os.path.dirname(path), exist_ok=True)
142
+ with wave.open(path, "wb") as w:
143
+ w.setnchannels(1)
144
+ w.setsampwidth(2) # s16
145
+ w.setframerate(sr)
146
+ w.writeframes(pcm)
147
+
148
+ def prepare_and_save_audio_for_model(task: dict, out_dir: str) -> dict:
149
+ """
150
+ 1) Decode chunk to mono 16k PCM.
151
+ 2) Run VAD to locate head/tail silence.
152
+ 3) Trim only if head or tail silence >= 10s.
153
+ 4) Save the (possibly trimmed) WAV to local file.
154
+ 5) Return timing metadata, including 'trimmed_start_ms' to preserve global timestamps.
155
+ """
156
+ # 0) Names & constants
157
+ sr = 16000
158
+ bpf = 2
159
+ samples_per_ms = sr // 1000
160
+
161
+ def bytes_from_ms(ms: int) -> int:
162
+ return int(ms * samples_per_ms) * bpf
163
+
164
+ ch = task["channel"]
165
+ ck = task["chunk"]
166
+ job = task.get("job_id", "job")
167
+ idx = str(ck["idx"])
168
+
169
+ # 1) Decode chunk
170
+ pcm = _decode_chunk_to_pcm(task)
171
+ planned_dur_ms = int(ck["dur_ms"])
172
+
173
+ # 2) VAD head/tail detection
174
+ first_ms, last_ms = _find_head_tail_speech_ms(pcm, sr=sr)
175
+ head_sil_ms = int(first_ms) if first_ms is not None else planned_dur_ms
176
+ tail_sil_ms = int(planned_dur_ms - last_ms) if last_ms is not None else planned_dur_ms
177
+
178
+ # 3) Decide trimming (only if head or tail >= 10s)
179
+ trim_applied = False
180
+ eff_start_ms = 0
181
+ eff_end_ms = planned_dur_ms
182
+ trimmed_pcm = pcm
183
+
184
+ if (head_sil_ms >= TRIM_THRESHOLD_MS) or (tail_sil_ms >= TRIM_THRESHOLD_MS):
185
+ # If no speech found at all, mark skip
186
+ if first_ms is None or last_ms is None or last_ms <= first_ms:
187
+ out_wav_path = os.path.join(out_dir, f"{job}_{ch}_{idx}_nospeech.wav")
188
+ _write_wav(out_wav_path, b"", sr)
189
+ return {
190
+ "out_wav_path": out_wav_path,
191
+ "sr": sr,
192
+ "trim_applied": False,
193
+ "trimmed_start_ms": 0,
194
+ "head_silence_ms": head_sil_ms,
195
+ "tail_silence_ms": tail_sil_ms,
196
+ "effective_start_ms": 0,
197
+ "effective_dur_ms": 0,
198
+ "abs_start_ms": ck["global_offset_ms"],
199
+ "chunk_idx": idx,
200
+ "channel": ch,
201
+ "skip": True,
202
+ }
203
+
204
+ # Apply padding & slice
205
+ start_ms = max(0, int(first_ms) - DEFAULT_PAD_MS)
206
+ end_ms = min(planned_dur_ms, int(last_ms) + DEFAULT_PAD_MS)
207
+
208
+ if end_ms > start_ms:
209
+ eff_start_ms = start_ms
210
+ eff_end_ms = end_ms
211
+ trimmed_pcm = pcm[bytes_from_ms(start_ms) : bytes_from_ms(end_ms)]
212
+ trim_applied = True
213
+
214
+ # 4) Write WAV to local file (trimmed or original)
215
+ tag = "trim" if trim_applied else "full"
216
+ out_wav_path = os.path.join(out_dir, f"{job}_{ch}_{idx}_{tag}.wav")
217
+ _write_wav(out_wav_path, trimmed_pcm, sr)
218
+
219
+ # 5) Return metadata
220
+ return {
221
+ "out_wav_path": out_wav_path,
222
+ "sr": sr,
223
+ "trim_applied": trim_applied,
224
+ "trimmed_start_ms": eff_start_ms if trim_applied else 0,
225
+ "head_silence_ms": head_sil_ms,
226
+ "tail_silence_ms": tail_sil_ms,
227
+ "effective_start_ms": eff_start_ms,
228
+ "effective_dur_ms": eff_end_ms - eff_start_ms,
229
+ "abs_start_ms": int(ck["global_offset_ms"]) + eff_start_ms,
230
+ "chunk_idx": idx,
231
+ "channel": ch,
232
+ "skip": False if (trim_applied or len(pcm) > 0) else True,
233
+ }
234
+
235
  # Download once; later runs are instant
236
  snapshot_download(
237
  repo_id=MODEL_REPO,
 
245
  _whisper = None
246
  _batched_whisper = None
247
  _diarizer = None
248
+ _embedder = None
249
 
250
  # Create global diarization pipeline
251
  try:
 
288
  # do **not** create the models here!
289
  pass
290
 
291
+ def preprocess_from_task_json(self, task_json: str) -> dict:
292
+ """Parse task JSON and run prepare_and_save_audio_for_model, returning metadata."""
293
  try:
294
+ task = json.loads(task_json)
295
+ except Exception as e:
296
+ raise RuntimeError(f"Invalid JSON: {e}")
297
+
298
+ out_dir = os.path.join(CACHE_ROOT, "preprocessed")
299
+ os.makedirs(out_dir, exist_ok=True)
300
+ meta = prepare_and_save_audio_for_model(task, out_dir)
301
+ return meta
302
 
303
  @spaces.GPU # each call gets a GPU slice
304
+ def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16, base_offset_s: float = 0.0):
305
  """Transcribe the entire audio file without speaker diarization using batched inference"""
306
  whisper, batched_whisper, _ = _load_models() # models live on the GPU
307
 
 
344
  if seg.words:
345
  for word in seg.words:
346
  words_list.append({
347
+ "start": float(word.start) + float(base_offset_s),
348
+ "end": float(word.end) + float(base_offset_s),
349
  "word": word.word,
350
  "probability": word.probability,
351
  "speaker": "SPEAKER_00" # No speaker identification in full transcription
352
  })
353
 
354
  results.append({
355
+ "start": float(seg.start) + float(base_offset_s),
356
+ "end": float(seg.end) + float(base_offset_s),
357
  "text": seg.text,
358
  "speaker": "SPEAKER_00", # Single speaker assumption
359
  "avg_logprob": seg.avg_logprob,
 
366
  print(results)
367
  return results, detected_language
368
 
369
+ # Removed audio cutting; transcription is done once on the full (preprocessed) audio
370
 
371
  @spaces.GPU # each call gets a GPU slice
372
+ # Removed segment-wise transcription; using single full-audio transcription
373
 
374
  @spaces.GPU # each call gets a GPU slice
375
+ def perform_diarization(self, audio_path, num_speakers=None, base_offset_s: float = 0.0):
376
+ """Perform speaker diarization; return segments with global timestamps and per-speaker embeddings."""
377
  _, _, diarizer = _load_models() # models live on the GPU
378
 
379
  if diarizer is None:
 
381
  # Load audio to get duration
382
  waveform, sample_rate = torchaudio.load(audio_path)
383
  duration = waveform.shape[1] / sample_rate
384
+ # Try to compute a single-speaker embedding
385
+ speaker_embeddings = {}
386
+ try:
387
+ embedder = self._load_embedder()
388
+ # waveform is (1, T); embedder expects mono 1D
389
+ emb = embedder({"waveform": waveform.squeeze(0), "sample_rate": sample_rate})
390
+ speaker_embeddings["SPEAKER_00"] = emb.squeeze().tolist()
391
+ except Exception:
392
+ pass
393
  return [{
394
+ "start": 0.0 + float(base_offset_s),
395
+ "end": duration + float(base_offset_s),
396
  "speaker": "SPEAKER_00"
397
+ }], 1, speaker_embeddings
398
 
399
  print("Starting diarization...")
400
  start_time = time.time()
 
411
  # Convert to list format
412
  diarize_segments = []
413
  diarization_list = list(diarization.itertracks(yield_label=True))
414
+ print(diarization_list)
415
  for turn, _, speaker in diarization_list:
416
  diarize_segments.append({
417
+ "start": float(turn.start) + float(base_offset_s),
418
+ "end": float(turn.end) + float(base_offset_s),
419
  "speaker": speaker
420
  })
421
 
422
  unique_speakers = {speaker for segment in diarize_segments for speaker in [segment["speaker"]]}
423
  detected_num_speakers = len(unique_speakers)
424
 
425
+ # Compute per-speaker embeddings by averaging segment embeddings
426
+ speaker_embeddings = {}
427
+ try:
428
+ embedder = self._load_embedder()
429
+ spk_to_embs = {spk: [] for spk in unique_speakers}
430
+ for turn, _, speaker in diarization_list:
431
+ start_sample = int(float(turn.start) * sample_rate)
432
+ end_sample = int(float(turn.end) * sample_rate)
433
+ if end_sample > start_sample:
434
+ seg_wav = waveform[0, start_sample:end_sample].contiguous()
435
+ emb = embedder({"waveform": seg_wav, "sample_rate": sample_rate})
436
+ spk_to_embs[speaker].append(emb.squeeze())
437
+ # average
438
+ for spk, embs in spk_to_embs.items():
439
+ if len(embs) == 0:
440
+ continue
441
+ # stack and mean
442
+ try:
443
+ import torch as _torch
444
+ embs_tensor = _torch.stack([_torch.as_tensor(e) for e in embs], dim=0)
445
+ centroid = embs_tensor.mean(dim=0)
446
+ # L2 normalize
447
+ centroid = centroid / (centroid.norm(p=2) + 1e-12)
448
+ speaker_embeddings[spk] = centroid.cpu().tolist()
449
+ except Exception:
450
+ # fallback to first embedding
451
+ speaker_embeddings[spk] = embs[0].cpu().tolist()
452
+ except Exception:
453
+ pass
454
+
455
  diarization_time = time.time() - start_time
456
  print(f"Diarization completed in {diarization_time:.2f} seconds")
457
 
458
+ return diarize_segments, detected_num_speakers, speaker_embeddings
459
+
460
+ def _load_embedder(self):
461
+ """Lazy-load speaker embedding inference model on GPU."""
462
+ global _embedder
463
+ if _embedder is None:
464
+ # window="whole" to compute one embedding per provided chunk
465
+ _embedder = Inference("pyannote/embedding", window="whole", device=torch.device("cuda"))
466
+ return _embedder
467
+
468
+ def assign_speakers_to_transcription(self, transcription_results, diarization_segments):
469
+ """Assign speakers to words and segments based on overlap with diarization segments."""
470
+ if not diarization_segments:
471
+ return transcription_results
472
+ # simple helper to find speaker at given time
473
+ def speaker_at(t: float):
474
+ for seg in diarization_segments:
475
+ if seg["start"] <= t < seg["end"]:
476
+ return seg["speaker"]
477
+ # if not inside, return closest segment's speaker
478
+ closest = None
479
+ best = float("inf")
480
+ for seg in diarization_segments:
481
+ if t < seg["start"]:
482
+ d = seg["start"] - t
483
+ elif t > seg["end"]:
484
+ d = t - seg["end"]
485
+ else:
486
+ d = 0.0
487
+ if d < best:
488
+ best = d
489
+ closest = seg
490
+ return closest["speaker"] if closest else "SPEAKER_00"
491
+
492
+ for seg in transcription_results:
493
+ # Assign per-word speakers
494
+ if seg.get("words"):
495
+ speaker_counts = {}
496
+ for w in seg["words"]:
497
+ mid = (float(w["start"]) + float(w["end"])) / 2.0
498
+ spk = speaker_at(mid)
499
+ w["speaker"] = spk
500
+ speaker_counts[spk] = speaker_counts.get(spk, 0) + (float(w["end"]) - float(w["start"]))
501
+ # Segment speaker = speaker with max accumulated word duration
502
+ if speaker_counts:
503
+ seg["speaker"] = max(speaker_counts.items(), key=lambda kv: kv[1])[0]
504
+ else:
505
+ mid = (float(seg["start"]) + float(seg["end"])) / 2.0
506
+ seg["speaker"] = speaker_at(mid)
507
+ return transcription_results
508
 
509
  def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
510
  """Group consecutive segments from the same speaker"""
 
548
  return grouped_segments
549
 
550
  @spaces.GPU # each call gets a GPU slice
551
+ def process_audio_full(self, task_json, language=None, translate=False, prompt=None, group_segments=True, batch_size=16):
552
+ """Process a single chunk using task JSON (no diarization)."""
553
+ if not task_json or not str(task_json).strip():
554
+ return {"error": "No JSON provided"}
555
+
556
+ pre_meta = None
557
  try:
558
  print("Starting full transcription pipeline...")
559
 
560
+ # Step 1: Preprocess per chunk JSON
561
+ print("Preprocessing chunk JSON...")
562
+ pre_meta = self.preprocess_from_task_json(task_json)
563
+ if pre_meta.get("skip"):
564
+ return {"segments": [], "language": "unknown", "num_speakers": 1, "transcription_method": "full_audio_batched", "batch_size": batch_size}
565
+ wav_path = pre_meta["out_wav_path"]
566
+ # Adjust timestamps by trimmed_start_ms: abs_start_ms is already global start for saved file
567
+ base_offset_s = float(pre_meta.get("abs_start_ms", 0)) / 1000.0
568
 
569
  # Step 2: Transcribe the entire audio with batching
570
  transcription_results, detected_language = self.transcribe_full_audio(
571
+ wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s
572
  )
573
 
574
  # Step 3: Group segments if requested (based on time gaps and sentence endings)
 
589
  traceback.print_exc()
590
  return {"error": f"Processing failed: {str(e)}"}
591
  finally:
592
+ # Clean up preprocessed wav
593
+ if pre_meta and pre_meta.get("out_wav_path") and os.path.exists(pre_meta["out_wav_path"]):
594
+ try:
595
+ os.unlink(pre_meta["out_wav_path"])
596
+ except Exception:
597
+ pass
598
 
599
  @spaces.GPU # each call gets a GPU slice
600
+ def process_audio(self, task_json, num_speakers=None, language=None,
601
  translate=False, prompt=None, group_segments=True, batch_size=8):
602
+ """Main processing function with diarization using task JSON for a single chunk.
603
+
604
+ Transcribes full (preprocessed) audio once, performs diarization, merges speakers into transcription.
605
+ """
606
+ if not task_json or not str(task_json).strip():
607
+ return {"error": "No JSON provided"}
608
+
609
+ pre_meta = None
610
  try:
611
  print("Starting new processing pipeline...")
612
 
613
+ # Step 1: Preprocess per chunk JSON
614
+ print("Preprocessing chunk JSON...")
615
+ pre_meta = self.preprocess_from_task_json(task_json)
616
+ if pre_meta.get("skip"):
617
+ return {"segments": [], "language": "unknown", "num_speakers": 0, "transcription_method": "diarized_segments_batched", "batch_size": batch_size}
618
+ wav_path = pre_meta["out_wav_path"]
619
+ base_offset_s = float(pre_meta.get("abs_start_ms", 0)) / 1000.0
620
 
621
+ # Step 2: Transcribe full audio once
622
+ transcription_results, detected_language = self.transcribe_full_audio(
623
+ wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s
624
  )
625
+
626
+ # Step 3: Perform diarization with global offset
627
+ diarization_segments, detected_num_speakers, speaker_embeddings = self.perform_diarization(
628
+ wav_path, num_speakers, base_offset_s=base_offset_s
 
 
 
629
  )
630
+
631
+ # Step 4: Merge diarization into transcription (assign speakers)
632
+ transcription_results = self.assign_speakers_to_transcription(transcription_results, diarization_segments)
633
 
634
  # Step 5: Group segments if requested
635
  if group_segments:
 
641
  "language": detected_language,
642
  "num_speakers": detected_num_speakers,
643
  "transcription_method": "diarized_segments_batched",
644
+ "batch_size": batch_size,
645
+ "speaker_embeddings": speaker_embeddings,
646
  }
647
 
648
  except Exception as e:
 
650
  traceback.print_exc()
651
  return {"error": f"Processing failed: {str(e)}"}
652
  finally:
653
+ # Clean up preprocessed wav
654
+ if pre_meta and pre_meta.get("out_wav_path") and os.path.exists(pre_meta["out_wav_path"]):
655
+ try:
656
+ os.unlink(pre_meta["out_wav_path"])
657
+ except Exception:
658
+ pass
659
 
660
  # Initialize transcriber
661
  transcriber = WhisperTranscriber()
 
692
  return output
693
 
694
  @spaces.GPU
695
+ def process_audio_gradio(task_json, num_speakers, language, translate, prompt, group_segments, use_diarization, batch_size):
696
  """Gradio interface function"""
697
  if use_diarization:
698
  result = transcriber.process_audio(
699
+ task_json=task_json,
700
  num_speakers=num_speakers if num_speakers > 0 else None,
701
  language=language if language != "auto" else None,
702
  translate=translate,
 
706
  )
707
  else:
708
  result = transcriber.process_audio_full(
709
+ task_json=task_json,
710
  language=language if language != "auto" else None,
711
  translate=translate,
712
  prompt=prompt if prompt and prompt.strip() else None,
 
735
 
736
  with gr.Row():
737
  with gr.Column():
738
+ task_json_input = gr.Textbox(
739
+ label="🧾 Paste Task JSON",
740
+ placeholder="Paste the per-chunk task JSON here...",
741
+ lines=16,
742
  )
743
 
744
  with gr.Accordion("⚙️ Advanced Settings", open=False):
 
750
 
751
  batch_size = gr.Slider(
752
  minimum=1,
753
+ maximum=128,
754
  value=16,
755
  step=1,
756
  label="Batch Size",
 
793
  with gr.Column():
794
  output_text = gr.Markdown(
795
  label="๐Ÿ“ Transcription Results",
796
+ value="Paste task JSON and click 'Transcribe Audio' to get started!"
797
  )
798
 
799
  output_json = gr.JSON(
 
812
  process_btn.click(
813
  fn=process_audio_gradio,
814
  inputs=[
815
+ task_json_input,
816
  num_speakers,
817
  language,
818
  translate,
 
827
  # Examples
828
  gr.Markdown("### 📋 Usage Tips:")
829
  gr.Markdown("""
830
+ - Paste a single-chunk task JSON matching the preprocess schema
831
+ - Batch Size: Higher values (16-24) = faster but uses more GPU memory
832
+ - Speaker diarization: Enable for speaker identification (slower)
833
+ - Languages: Supports 100+ languages with auto-detection
834
+ - Vocabulary: Add names and technical terms in the prompt for better accuracy
835
  """)
836
 
837
  if __name__ == "__main__":
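
For orientation, here is a sketch of the dict shape that the updated process_audio returns on success. The keys are taken from the assignments in this diff; the values shown are purely illustrative.

# Illustrative result shape for process_audio(); keys come from this diff,
# values are made up. Timestamps are global seconds (chunk offset applied).
example_result = {
    "segments": [
        {
            "start": 600.42,
            "end": 603.10,
            "text": " Hello there.",
            "speaker": "SPEAKER_01",
            "avg_logprob": -0.21,
            "words": [
                {"start": 600.42, "end": 600.80, "word": " Hello",
                 "probability": 0.98, "speaker": "SPEAKER_01"},
            ],
            "duration": 2.68,
        },
    ],
    "language": "en",
    "num_speakers": 2,
    "transcription_method": "diarized_segments_batched",
    "batch_size": 8,
    "speaker_embeddings": {  # L2-normalized per-speaker centroids (truncated here)
        "SPEAKER_00": [0.012, -0.034],
        "SPEAKER_01": [-0.008, 0.051],
    },
}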
requirements.txt CHANGED
@@ -19,4 +19,5 @@ librosa>=0.10.0
19
  soundfile>=0.12.0
20
  ffmpeg-python>=0.2.0
21
  requests>=2.28.0
22
- nvidia-cudnn-cu12==9.1.0.70 # any 9.1.x that pip can find is fine
 
 
19
  soundfile>=0.12.0
20
  ffmpeg-python>=0.2.0
21
  requests>=2.28.0
22
+ nvidia-cudnn-cu12==9.1.0.70 # any 9.1.x that pip can find is fine
23
+ webrtcvad>=2.0.10
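
The newly added webrtcvad dependency accepts only 16-bit mono PCM at 8, 16, 32, or 48 kHz in 10/20/30 ms frames, which is why the preprocessing helper decodes to s16le 16 kHz and walks fixed 30 ms frames. A minimal standalone sanity check (illustrative only):

# Quick check of the webrtcvad contract the new helper depends on.
import webrtcvad

sr = 16000                                        # sample rate supported by webrtcvad
frame_ms = 30                                     # allowed frame sizes: 10, 20, or 30 ms
vad = webrtcvad.Vad(2)                            # aggressiveness 0-3 (matches VAD_LEVEL)
silence = b"\x00\x00" * (sr // 1000) * frame_ms   # one 30 ms frame of s16le silence
print(vad.is_speech(silence, sr))                 # expected: False for pure silence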