liuyang committed
Commit e3d9c9e · Parent: a4568c6

Refactor model management and transcription process: Introduced a model registry for easier management of Whisper models, added functionality to download models on startup, and streamlined the audio processing pipeline to support both chunk and segment transcriptions with improved error handling and cleanup.
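For context on the refactor: the diff below keys every Whisper model by a short name in a `MODELS` registry, downloads the default entry at startup, and loads each model at most once per process via a lazy cache in `_load_models`. The snippet below is a minimal, standalone sketch of that registry-plus-lazy-cache shape only; `DummyModel`, `REGISTRY`, and `load` are illustrative stand-ins, not the app's real `WhisperModel`, `MODELS`, or `_load_models`.

from dataclasses import dataclass

# Hypothetical two-entry registry; the real MODELS table in app.py also carries local_dir paths.
REGISTRY = {
    "large-v3-turbo": {"repo_id": "deepdml/faster-whisper-large-v3-turbo-ct2"},
    "large-v3": {"repo_id": "Systran/faster-whisper-large-v3"},
}
DEFAULT = "large-v3-turbo"

@dataclass
class DummyModel:
    repo_id: str  # stands in for a loaded WhisperModel

_loaded = {}  # name -> model, filled on first use

def load(name: str = DEFAULT) -> DummyModel:
    if name not in REGISTRY:
        raise ValueError(f"unknown model '{name}'")
    if name not in _loaded:  # download/construct only once per process
        _loaded[name] = DummyModel(REGISTRY[name]["repo_id"])
    return _loaded[name]

if __name__ == "__main__":
    print(load().repo_id)            # default model
    print(load("large-v3").repo_id)  # second model, cached after this call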

Files changed (1)
  app.py  +341 -66
app.py CHANGED
@@ -112,8 +112,45 @@ def upload_data_to_r2(data, bucket_name, object_name, content_type='application/
     return False
 
 from huggingface_hub import snapshot_download
-MODEL_REPO = "deepdml/faster-whisper-large-v3-turbo-ct2" # CT2 format
-LOCAL_DIR = f"{CACHE_ROOT}/whisper_turbo"
+
+# -----------------------------------------------------------------------------
+# Model Management
+# -----------------------------------------------------------------------------
+MODELS = {
+    "large-v3-turbo": {
+        "repo_id": "deepdml/faster-whisper-large-v3-turbo-ct2",
+        "local_dir": f"{CACHE_ROOT}/whisper_turbo_v3"
+    },
+    "large-v3": {
+        "repo_id": "Systran/faster-whisper-large-v3",
+        "local_dir": f"{CACHE_ROOT}/whisper_large_v3"
+    },
+    "large-v2": {
+        "repo_id": "Systran/faster-whisper-large-v2",
+        "local_dir": f"{CACHE_ROOT}/whisper_large_v2"
+    },
+}
+DEFAULT_MODEL = "large-v3-turbo"
+
+def _download_model(model_name: str):
+    """Downloads a model from the hub if not already present."""
+    if model_name not in MODELS:
+        raise ValueError(f"Model '{model_name}' not found in MODELS registry.")
+
+    model_info = MODELS[model_name]
+    if not os.path.exists(model_info["local_dir"]):
+        print(f"Downloading model '{model_name}' from {model_info['repo_id']}...")
+        snapshot_download(
+            repo_id=model_info["repo_id"],
+            local_dir=model_info["local_dir"],
+            local_dir_use_symlinks=True,
+            resume_download=True
+        )
+    return model_info["local_dir"]
+
+# Download the default model on startup
+_download_model(DEFAULT_MODEL)
+
 
 # -----------------------------------------------------------------------------
 # Audio preprocess helper (from input_and_preprocess rule)
@@ -207,6 +244,55 @@ def _write_wav(path: str, pcm: bytes, sr: int = 16000):
 
 def prepare_and_save_audio_for_model(task: dict, out_dir: str) -> dict:
     """
+    1) Decode chunk(s) to mono 16k PCM.
+    2) Run VAD to locate head/tail silence.
+    3) Trim only if head or tail >= 10s.
+    4) Save the (possibly trimmed) WAV to local file(s).
+    5) Return timing metadata, including 'trimmed_start_ms' to preserve global timestamps.
+
+    Args:
+        task: dict containing either:
+            - "chunk": single chunk dict, or
+            - "chunk": list of chunk dicts
+        out_dir: output directory for WAV files
+
+    Returns:
+        A wrapper dict with general fields (e.g., job_id, channel, sr, filekey)
+        and a "chunks" array containing metadata dict(s) for each processed chunk.
+        This structure is returned for both single and multiple chunk inputs.
+    """
+    chunks = task["chunk"]
+    result = {
+        "job_id": task.get("job_id", "job"),
+        "channel": task["channel"],
+        "sr": 16000,
+        "options": task.get("options", None),
+        "filekey": task.get("filekey", None),
+    }
+
+    # Handle both single chunk and multiple chunks
+    if isinstance(chunks, list):
+        # Process multiple chunks
+        results = []
+        for chunk in chunks:
+            # Create a task for each chunk
+            single_chunk_task = task.copy()
+            single_chunk_task["chunk"] = chunk
+            result = _process_single_chunk(single_chunk_task, out_dir)
+            results.append(result)
+        # Compose wrapper dict with general fields applicable to all chunks
+        result["chunks"] = results
+    else:
+        # Process single chunk and wrap in the standard response structure
+        result = _process_single_chunk(task, out_dir)
+        result["chunk"] = result
+    return result
+
+
+def _process_single_chunk(task: dict, out_dir: str) -> dict:
+    """
+    Process a single chunk - extracted from the original prepare_and_save_audio_for_model logic.
+
     1) Decode chunk to mono 16k PCM.
     2) Run VAD to locate head/tail silence.
     3) Trim only if head or tail >= 10s.
@@ -294,17 +380,17 @@ def prepare_and_save_audio_for_model(task: dict, out_dir: str) -> dict:
     }
 
 # Download once; later runs are instant
-snapshot_download(
-    repo_id=MODEL_REPO,
-    local_dir=LOCAL_DIR,
-    local_dir_use_symlinks=True, # saves disk space
-    resume_download=True
-)
-model_cache_path = LOCAL_DIR # <-- this is what we pass to WhisperModel
+# snapshot_download(
+#     repo_id=MODEL_REPO,
+#     local_dir=LOCAL_DIR,
+#     local_dir_use_symlinks=True, # saves disk space
+#     resume_download=True
+# )
+# model_cache_path = LOCAL_DIR # <-- this is what we pass to WhisperModel
 
 # Lazy global holder ----------------------------------------------------------
-_whisper = None
-_batched_whisper = None
+_whisper_models = {}
+_batched_whisper_models = {}
 _diarizer = None
 _embedder = None
 
@@ -328,20 +414,32 @@ except Exception as e:
     _diarizer = None
 
 @spaces.GPU # GPU is guaranteed to exist *inside* this function
-def _load_models():
-    global _whisper, _batched_whisper, _diarizer
-    if _whisper is None:
-        print("Loading Whisper model...")
-        _whisper = WhisperModel(
+def _load_models(model_name: str = DEFAULT_MODEL):
+    global _whisper_models, _batched_whisper_models, _diarizer
+
+    if model_name not in _whisper_models:
+        print(f"Loading Whisper model '{model_name}'...")
+
+        model_cache_path = _download_model(model_name)
+
+        model = WhisperModel(
             model_cache_path,
             device="cuda",
             compute_type="float16",
         )
 
         # Create batched inference pipeline for improved performance
-        _batched_whisper = BatchedInferencePipeline(model=_whisper)
-        print("Whisper model and batched pipeline loaded successfully")
-    return _whisper, _batched_whisper, _diarizer
+        batched_model = BatchedInferencePipeline(model=model)
+
+        _whisper_models[model_name] = model
+        _batched_whisper_models[model_name] = batched_model
+
+        print(f"Whisper model '{model_name}' and batched pipeline loaded successfully")
+
+    whisper = _whisper_models[model_name]
+    batched_whisper = _batched_whisper_models[model_name]
+
+    return whisper, batched_whisper, _diarizer
 
 # -----------------------------------------------------------------------------
 class WhisperTranscriber:
@@ -362,11 +460,11 @@ class WhisperTranscriber:
         return meta
 
     @spaces.GPU # each call gets a GPU slice
-    def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16, base_offset_s: float = 0.0, clip_timestamps=None):
+    def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16, base_offset_s: float = 0.0, clip_timestamps=None, model_name: str = DEFAULT_MODEL, transcribe_options: dict = None):
         """Transcribe the entire audio file without speaker diarization using batched inference"""
-        whisper, batched_whisper, _ = _load_models() # models live on the GPU
+        whisper, batched_whisper, _ = _load_models(model_name) # models live on the GPU
 
-        print(f"Transcribing full audio with batch size {batch_size}...")
+        print(f"Transcribing full audio with '{model_name}' and batch size {batch_size}...")
         start_time = time.time()
 
         # Prepare options for batched inference
@@ -383,12 +481,14 @@ class WhisperTranscriber:
             options["vad_filter"] = False
             options["clip_timestamps"] = clip_timestamps
         else:
+            vad_options = transcribe_options.get("vad_parameters", None)
             options["vad_filter"] = True # VAD is enabled by default for batched transcription
-            options["vad_parameters"] = VadOptions(
+            options["vad_parameters"] = VadOptions(**vad_options) if vad_options else VadOptions(
                 max_speech_duration_s=whisper.feature_extractor.chunk_length,
-                min_speech_duration_ms=150, # ignore ultra-short blips
-                min_silence_duration_ms=150, # split on short Mandarin pauses (if supported) speech_pad_ms=100,
-                threshold=0.25,
+                min_speech_duration_ms=180, # ignore ultra-short blips
+                min_silence_duration_ms=120, # split on short Mandarin pauses (if supported)
+                speech_pad_ms=120,
+                threshold=0.35,
                 neg_threshold=0.2,
             )
         if batch_size > 1:
@@ -439,10 +539,7 @@ class WhisperTranscriber:
         return results, detected_language
 
     # Removed audio cutting; transcription is done once on the full (preprocessed) audio
-
-    @spaces.GPU # each call gets a GPU slice
-    # Removed segment-wise transcription; using single full-audio transcription
-
+
     @spaces.GPU # each call gets a GPU slice
     def perform_diarization(self, audio_path, num_speakers=None, base_offset_s: float = 0.0):
         """Perform speaker diarization; return segments with global timestamps and per-speaker embeddings."""
@@ -840,47 +937,215 @@ class WhisperTranscriber:
 
         return grouped_segments
 
-    @spaces.GPU # each call gets a GPU slice
-    def process_audio_full(self, task_json, language=None, translate=False, prompt=None, group_segments=True, batch_size=16):
-        """Process a single chunk using task JSON (no diarization)."""
+    @spaces.GPU
+    def process_audio_transcribe(self, task_json, language=None,
+                                 translate=False, prompt=None, batch_size=8, model_name: str = DEFAULT_MODEL):
+        """Main processing function with diarization using task JSON for a single chunk.
+
+        Transcribes full (preprocessed) audio once, performs diarization, merges speakers into transcription.
+        """
         if not task_json or not str(task_json).strip():
             return {"error": "No JSON provided"}
 
         pre_meta = None
         try:
-            print("Starting full transcription pipeline...")
+            print("Starting new processing pipeline...")
 
             # Step 1: Preprocess per chunk JSON
             print("Preprocessing chunk JSON...")
             pre_meta = self.preprocess_from_task_json(task_json)
-            if pre_meta.get("skip"):
-                return {"segments": [], "language": "unknown", "num_speakers": 1, "transcription_method": "full_audio_batched", "batch_size": batch_size}
-            wav_path = pre_meta["out_wav_path"]
-            # Adjust timestamps by trimmed_start_ms: abs_start_ms is already global start for saved file
-            base_offset_s = float(pre_meta.get("abs_start_ms", 0)) / 1000.0
-
-            # Step 2: Transcribe the entire audio with batching
+            transcribe_options = pre_meta.get("options", None)
+            if "chunk" in pre_meta:
+                self.transcribe_chunk(pre_meta, language, translate, prompt, batch_size, model_name, transcribe_options)
+            elif "segments" in pre_meta:
+                self.transcribe_segments(pre_meta, language, translate, prompt, batch_size, model_name, transcribe_options)
+        except Exception as e:
+            import traceback
+            traceback.print_exc()
+            return {"error": f"Processing failed: {str(e)}"}
+
+
+    @spaces.GPU
+    def transcribe_chunk(self, pre_meta, language=None,
+                         translate=False, prompt=None, batch_size=8, model_name: str = DEFAULT_MODEL, transcribe_options: dict = None):
+        """Main processing function with diarization using task JSON for a single chunk.
+
+        Transcribes full (preprocessed) audio once, performs diarization, merges speakers into transcription.
+        """
+        try:
+            print("Transcribing chunk...")
+            # Step 1: Preprocess per chunk JSON
+            if pre_meta["chunk"].get("skip"):
+                return {"segments": [], "language": "unknown", "num_speakers": 0, "transcription_method": "diarized_segments_batched", "batch_size": batch_size}
+            wav_path = pre_meta["chunk"]["out_wav_path"]
+            base_offset_s = float(pre_meta["chunk"].get("abs_start_ms", 0)) / 1000.0
+
+            # Step 2: Transcribe full audio once
             transcription_results, detected_language = self.transcribe_full_audio(
-                wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s
+                wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s, clip_timestamps=None, model_name=model_name, transcribe_options=transcribe_options
             )
 
-            # Step 3: Group segments if requested (based on time gaps and sentence endings)
-            if group_segments:
-                transcription_results = self.group_segments_by_speaker(transcription_results)
-
-            # Step 4: Return results
-            return {
+            # Step 6: Return results
+            result = {
                 "segments": transcription_results,
                 "language": detected_language,
-                "num_speakers": 1, # Single speaker assumption
-                "transcription_method": "full_audio_batched",
-                "batch_size": batch_size
+                "batch_size": batch_size,
             }
+            # job_id = pre_meta["job_id"]
+            # task_id = pre_meta["chunk_idx"]
+            filekey = pre_meta["filekey"] #f"ai-transcribe/split/{job_id}-{task_id}.json"
+            ret = upload_data_to_r2(json.dumps(result), "intermediate", filekey)
+            if ret:
+                return {"filekey": filekey}
+            else:
+                return {"error": "Failed to upload to R2"}
 
         except Exception as e:
             import traceback
             traceback.print_exc()
             return {"error": f"Processing failed: {str(e)}"}
+        finally:
+            # Clean up preprocessed wav
+            if pre_meta and pre_meta["chunk"].get("out_wav_path") and os.path.exists(pre_meta["chunk"]["out_wav_path"]):
+                try:
+                    os.unlink(pre_meta["chunk"]["out_wav_path"])
+                except Exception:
+                    pass
+
+    @spaces.GPU
+    def transcribe_segments(self, pre_meta, language=None,
+                            translate=False, prompt=None, batch_size=8, model_name: str = DEFAULT_MODEL, transcribe_options: dict = None):
+        """Main processing function with diarization using task JSON for a single chunk.
+
+        Transcribes full (preprocessed) audio once, performs diarization, merges speakers into transcription.
+        """
+        try:
+            print("Transcribing segments...")
+
+            # Step 1: Preprocess per chunk JSON
+            chunks = pre_meta["segments"]
+            for chunk in chunks:
+                if chunk.get("skip"):
+                    return {"segments": [], "language": "unknown", "num_speakers": 0, "transcription_method": "diarized_segments_batched", "batch_size": batch_size}
+                wav_path = chunk["out_wav_path"]
+                base_offset_s = float(chunk.get("abs_start_ms", 0)) / 1000.0
+
+                # Step 2: Transcribe full audio once
+                transcription_results, detected_language = self.transcribe_full_audio(
+                    wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s, clip_timestamps=None, model_name=model_name, transcribe_options=transcribe_options
+                )
+
+                # Step 6: Return results
+                result = {
+                    "chunk_idx": chunk["chunk_idx"],
+                    "channel": chunk["channel"],
+                    "job_id": pre_meta["job_id"],
+                    "segments": transcription_results,
+                    "language": detected_language,
+                    "batch_size": batch_size,
+                }
+                # job_id = pre_meta["job_id"]
+                # task_id = pre_meta["chunk_idx"]
+                filekey = pre_meta["filekey"] #f"ai-transcribe/split/{job_id}-{task_id}.json"
+                ret = upload_data_to_r2(json.dumps(result), "intermediate", filekey)
+                if ret:
+                    return {"filekey": filekey}
+                else:
+                    return {"error": "Failed to upload to R2"}
+
+        except Exception as e:
+            import traceback
+            traceback.print_exc()
+            return {"error": f"Processing failed: {str(e)}"}
+        finally:
+            # Clean up preprocessed wav
+            if pre_meta and pre_meta["segments"]:
+                for chunk in pre_meta["segments"]:
+                    if chunk.get("out_wav_path") and os.path.exists(chunk["out_wav_path"]):
+                        try:
+                            os.unlink(chunk["out_wav_path"])
+                        except Exception:
+                            pass
+
+    @spaces.GPU # each call gets a GPU slice
+    def process_audio_diarization(self, task_json, num_speakers=0):
+        """Process audio for diarization only, returning speaker information.
+
+        Args:
+            task_json: Task JSON containing audio processing information
+            num_speakers: Number of speakers (0 for auto-detection)
+
+        Returns:
+            str: filekey of uploaded JSON file containing diarization results
+        """
+        if not task_json or not str(task_json).strip():
+            return {"error": "No JSON provided"}
+
+        pre_meta = None
+        try:
+            print("Starting diarization-only pipeline...")
+
+            # Step 1: Preprocess from task JSON
+            print("Preprocessing chunk JSON...")
+            pre_meta = self.preprocess_from_task_json(task_json)
+            if pre_meta.get("skip"):
+                # Return minimal result for skipped audio
+                task = json.loads(task_json)
+                job_id = task.get("job_id", "job")
+                task_id = str(task["chunk"]["idx"])
+
+                result = {
+                    "num_speakers": 0,
+                    "speaker_embeddings": {}
+                }
+
+                filekey = pre_meta["filekey"] #f"ai-transcribe/split/{job_id}-{task_id}-diarization.json"
+                ret = upload_data_to_r2(json.dumps(result), "intermediate", filekey)
+                if ret:
+                    return filekey
+                else:
+                    return {"error": "Failed to upload to R2"}
+
+            wav_path = pre_meta["chunk"]["out_wav_path"]
+            base_offset_s = float(pre_meta["chunk"].get("abs_start_ms", 0)) / 1000.0
+
+            # Step 2: Perform diarization
+            print("Performing diarization...")
+            start_time = time.time()
+            diarization_segments, detected_num_speakers, speaker_embeddings = self.perform_diarization(
+                wav_path, num_speakers if num_speakers > 0 else None, base_offset_s=base_offset_s
+            )
+            diarization_time = time.time() - start_time
+            print(f"Diarization completed in {diarization_time:.2f} seconds")
+            # Step 3: Compose JSON response
+            result = {
+                "num_speakers": detected_num_speakers,
+                "speaker_embeddings": speaker_embeddings,
+                "diarization_segments": diarization_segments,
+
+            }
+            if pre_meta.get("channel", None):
+                result["channel"] = pre_meta["channel"]
+                # set channel in each diarization segment
+                for seg in diarization_segments:
+                    seg["channel"] = pre_meta["channel"]
+
+            # Step 4: Upload to R2
+            #job_id = pre_meta["job_id"]
+            #task_id = pre_meta["chunk_idx"]
+            #filekey = f"ai-transcribe/split/{job_id}-{task_id}-diarization.json"
+            filekey = pre_meta["filekey"]
+            ret = upload_data_to_r2(json.dumps(result), "intermediate", filekey)
+            if ret:
+                # Step 5: Return filekey
+                return filekey
+            else:
+                return {"error": "Failed to upload to R2"}
+
+        except Exception as e:
+            import traceback
+            traceback.print_exc()
+            return {"error": f"Diarization processing failed: {str(e)}"}
         finally:
             # Clean up preprocessed wav
             if pre_meta and pre_meta.get("out_wav_path") and os.path.exists(pre_meta["out_wav_path"]):
@@ -888,10 +1153,10 @@ class WhisperTranscriber:
                 os.unlink(pre_meta["out_wav_path"])
             except Exception:
                 pass
-
+
     @spaces.GPU # each call gets a GPU slice
     def process_audio(self, task_json, num_speakers=None, language=None,
-                      translate=False, prompt=None, group_segments=True, batch_size=8):
+                      translate=False, prompt=None, group_segments=True, batch_size=8, model_name: str = DEFAULT_MODEL):
        """Main processing function with diarization using task JSON for a single chunk.
 
         Transcribes full (preprocessed) audio once, performs diarization, merges speakers into transcription.
@@ -929,7 +1194,7 @@ class WhisperTranscriber:
 
             # Step 2: Transcribe full audio once
             transcription_results, detected_language = self.transcribe_full_audio(
-                wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s, clip_timestamps=None
+                wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s, clip_timestamps=None, model_name=model_name
            )
 
             unmatched_diarization_segments = []
@@ -967,6 +1232,7 @@ class WhisperTranscriber:
                         prompt=prompt,
                         batch_size=batch_size,
                         base_offset_s=d_start,
+                        model_name=model_name
                     )
                     extra_segments.extend(seg_transcription)
                 finally:
@@ -1051,30 +1317,31 @@ def format_segments_for_display(result):
     return output
 
 @spaces.GPU
-def process_audio_gradio(task_json, num_speakers, language, translate, prompt, group_segments, use_diarization, batch_size):
+def process_audio_gradio(task_json, num_speakers, language, translate, prompt, group_segments, use_diarization, batch_size, model_name):
     """Gradio interface function"""
-    if use_diarization:
-        result = transcriber.process_audio(
+
+    result = transcriber.process_audio_transcribe(
            task_json=task_json,
            num_speakers=num_speakers if num_speakers > 0 else None,
            language=language if language != "auto" else None,
            translate=translate,
            prompt=prompt if prompt and prompt.strip() else None,
            group_segments=group_segments,
-           batch_size=batch_size
+           batch_size=batch_size,
+           model_name=model_name
        )
-    else:
-        result = transcriber.process_audio_full(
+    '''
+    result = transcriber.process_audio_transcribe(
            task_json=task_json,
            language=language if language != "auto" else None,
            translate=translate,
            prompt=prompt if prompt and prompt.strip() else None,
-           group_segments=group_segments,
-           batch_size=batch_size
+           batch_size=batch_size,
+           model_name=model_name
        )
-
-    formatted_output = format_segments_for_display(result)
-    return formatted_output, result
+    '''
+    #formatted_output = format_segments_for_display(result)
+    return "OK", result
 
 # Create Gradio interface
 demo = gr.Blocks(
@@ -1101,6 +1368,13 @@ with demo:
         )
 
     with gr.Accordion("⚙️ Advanced Settings", open=False):
+        model_name_dropdown = gr.Dropdown(
+            label="Whisper Model",
+            choices=list(MODELS.keys()),
+            value=DEFAULT_MODEL,
+            info="Select the Whisper model to use for transcription."
+        )
+
         use_diarization = gr.Checkbox(
            label="Enable Speaker Diarization",
            value=True,
@@ -1178,7 +1452,8 @@ with demo:
            prompt,
            group_segments,
            use_diarization,
-           batch_size
+           batch_size,
+           model_name_dropdown
        ],
        outputs=[output_text, output_json]
    )
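
As added in this commit, process_audio_transcribe routes the preprocessed task by shape: a "chunk" key goes to transcribe_chunk, a "segments" key goes to transcribe_segments, and each result is uploaded to R2 under the task's filekey. The following is only a minimal sketch of that dispatch under assumed inputs; handle_chunk and handle_segments are stubs standing in for the real GPU-backed methods, and the raw JSON stands in for the output of preprocess_from_task_json.

import json

def handle_chunk(meta):
    # stub for transcribe_chunk: would transcribe one WAV and upload the result
    return {"handled": "chunk", "filekey": meta.get("filekey")}

def handle_segments(meta):
    # stub for transcribe_segments: would iterate per-segment WAVs
    return {"handled": "segments", "filekey": meta.get("filekey")}

def dispatch(task_json: str):
    if not task_json or not task_json.strip():
        return {"error": "No JSON provided"}
    meta = json.loads(task_json)  # stands in for preprocess_from_task_json
    if "chunk" in meta:
        return handle_chunk(meta)
    if "segments" in meta:
        return handle_segments(meta)
    return {"error": "task has neither 'chunk' nor 'segments'"}

print(dispatch(json.dumps({"chunk": {"idx": 0}, "filekey": "demo.json"})))
print(dispatch(json.dumps({"segments": [{"idx": 0}], "filekey": "demo.json"})))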