liuyang committed · Commit 70438f0 · 1 parent: 7c60c3b

Implement audio preprocessing and speaker diarization enhancements in WhisperTranscriber. Introduce methods for audio chunk preparation, VAD-based trimming, and speaker embedding extraction. Update the process_audio methods to use a per-chunk task JSON for improved workflow and metadata handling. Add the webrtcvad dependency for voice activity detection.

Files changed:
- app.py  +365 -187
- requirements.txt  +2 -1
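The updated pipeline no longer takes an uploaded audio file; both process_audio and process_audio_full now expect a per-chunk task JSON string. The schema is not spelled out in the commit itself, so the following is only an inferred sketch, built from the fields that prepare_and_save_audio_for_model and _decode_chunk_to_pcm actually read (every value is illustrative):

# Hypothetical per-chunk task JSON, inferred from the field accesses in this
# commit; all values below are made up for illustration.
example_task = {
    "job_id": "job-0001",                          # used in output WAV file names
    "source_uri": "https://example.com/call.mp3",  # passed to ffmpeg as -i
    "channel": "L",                                # channel tag for naming
    "ingest_recipe": {
        "channel_extract_filter": "pan=mono|c0=c0",  # optional ffmpeg -af filter
    },
    "ffmpeg_seek": {
        "pre_ss_sec": 0.0,    # coarse seek before -i
        "post_ss_sec": 0.0,   # fine seek after -i
        "t_sec": 600.0,       # duration to decode, in seconds
    },
    "chunk": {
        "idx": 0,                  # chunk index within the job
        "dur_ms": 600_000,         # planned chunk duration
        "global_offset_ms": 0,     # chunk start within the whole recording
    },
}

Serialized with json.dumps, this string is what the new "Paste Task JSON" textbox and preprocess_from_task_json expect.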
app.py
CHANGED

@@ -32,9 +32,11 @@ from faster_whisper import WhisperModel, BatchedInferencePipeline
 from faster_whisper.vad import VadOptions
 import requests
 import base64
-from pyannote.audio import Pipeline
+from pyannote.audio import Pipeline, Inference

 import os, sys, importlib.util, pathlib, ctypes, tempfile, wave, math
+import json
+import webrtcvad
 spec = importlib.util.find_spec("nvidia.cudnn")
 if spec is None:
     sys.exit("❌ nvidia-cudnn-cu12 wheel not found. Run: pip install nvidia-cudnn-cu12")

@@ -53,6 +55,183 @@ from huggingface_hub import snapshot_download
 MODEL_REPO = "deepdml/faster-whisper-large-v3-turbo-ct2"  # CT2 format
 LOCAL_DIR = f"{CACHE_ROOT}/whisper_turbo"

+# -----------------------------------------------------------------------------
+# Audio preprocess helper (from input_and_preprocess rule)
+# -----------------------------------------------------------------------------
+
+TRIM_THRESHOLD_MS = 10_000  # 10 seconds
+DEFAULT_PAD_MS = 250        # safety context around detected speech
+FRAME_MS = 30               # VAD frame
+HANG_MS = 240               # hangover (keep speech "on" after silence)
+VAD_LEVEL = 2               # 0-3
+
+def _decode_chunk_to_pcm(task: dict) -> bytes:
+    """Use ffmpeg to decode the chunk to s16le mono @ 16k PCM bytes."""
+    src = task["source_uri"]
+    ing = task["ingest_recipe"]
+    seek = task["ffmpeg_seek"]
+
+    cmd = [
+        "ffmpeg", "-nostdin", "-hide_banner", "-v", "error",
+        "-ss", f"{max(0.0, float(seek['pre_ss_sec'])):.3f}",
+        "-i", src,
+        "-map", "0:a:0",
+        "-ss", f"{float(seek['post_ss_sec']):.2f}",
+        "-t", f"{float(seek['t_sec']):.3f}",
+    ]
+
+    # Optional L/R extraction
+    if ing.get("channel_extract_filter"):
+        cmd += ["-af", ing["channel_extract_filter"]]
+
+    # Force mono 16k s16le to stdout
+    cmd += ["-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le", "-f", "s16le", "pipe:1"]
+
+    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    pcm, err = p.communicate()
+    if p.returncode != 0:
+        raise RuntimeError(f"ffmpeg failed: {err.decode('utf-8', 'ignore')}")
+    return pcm
+
+def _find_head_tail_speech_ms(
+    pcm: bytes,
+    sr: int = 16000,
+    frame_ms: int = FRAME_MS,
+    vad_level: int = VAD_LEVEL,
+    hang_ms: int = HANG_MS,
+):
+    """Return (first_ms, last_ms) speech boundaries using webrtcvad with hangover."""
+    if not pcm:
+        return None, None
+    vad = webrtcvad.Vad(int(vad_level))
+    bpf = 2  # bytes per sample (s16)
+    samples_per_ms = sr // 1000  # 16
+    bytes_per_frame = samples_per_ms * bpf * frame_ms
+
+    n_frames = len(pcm) // bytes_per_frame
+    if n_frames == 0:
+        return None, None
+
+    first_ms, last_ms = None, None
+    t_ms = 0
+    in_speech = False
+    silence_run = 0
+
+    view = memoryview(pcm)[: n_frames * bytes_per_frame]
+    for i in range(n_frames):
+        frame = view[i * bytes_per_frame : (i + 1) * bytes_per_frame]
+        if vad.is_speech(frame, sr):
+            if first_ms is None:
+                first_ms = t_ms
+            in_speech = True
+            silence_run = 0
+        else:
+            if in_speech:
+                silence_run += frame_ms
+                if silence_run >= hang_ms:
+                    last_ms = t_ms - (silence_run - hang_ms)
+                    in_speech = False
+                    silence_run = 0
+        t_ms += frame_ms
+    if in_speech:
+        last_ms = t_ms
+    return first_ms, last_ms
+
+def _write_wav(path: str, pcm: bytes, sr: int = 16000):
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    with wave.open(path, "wb") as w:
+        w.setnchannels(1)
+        w.setsampwidth(2)  # s16
+        w.setframerate(sr)
+        w.writeframes(pcm)
+
+def prepare_and_save_audio_for_model(task: dict, out_dir: str) -> dict:
+    """
+    1) Decode chunk to mono 16k PCM.
+    2) Run VAD to locate head/tail silence.
+    3) Trim only if head or tail silence >= 10s.
+    4) Save the (possibly trimmed) WAV to local file.
+    5) Return timing metadata, including 'trimmed_start_ms' to preserve global timestamps.
+    """
+    # 0) Names & constants
+    sr = 16000
+    bpf = 2
+    samples_per_ms = sr // 1000
+
+    def bytes_from_ms(ms: int) -> int:
+        return int(ms * samples_per_ms) * bpf
+
+    ch = task["channel"]
+    ck = task["chunk"]
+    job = task.get("job_id", "job")
+    idx = str(ck["idx"])
+
+    # 1) Decode chunk
+    pcm = _decode_chunk_to_pcm(task)
+    planned_dur_ms = int(ck["dur_ms"])
+
+    # 2) VAD head/tail detection
+    first_ms, last_ms = _find_head_tail_speech_ms(pcm, sr=sr)
+    head_sil_ms = int(first_ms) if first_ms is not None else planned_dur_ms
+    tail_sil_ms = int(planned_dur_ms - last_ms) if last_ms is not None else planned_dur_ms
+
+    # 3) Decide trimming (only if head or tail >= 10s)
+    trim_applied = False
+    eff_start_ms = 0
+    eff_end_ms = planned_dur_ms
+    trimmed_pcm = pcm
+
+    if (head_sil_ms >= TRIM_THRESHOLD_MS) or (tail_sil_ms >= TRIM_THRESHOLD_MS):
+        # If no speech found at all, mark skip
+        if first_ms is None or last_ms is None or last_ms <= first_ms:
+            out_wav_path = os.path.join(out_dir, f"{job}_{ch}_{idx}_nospeech.wav")
+            _write_wav(out_wav_path, b"", sr)
+            return {
+                "out_wav_path": out_wav_path,
+                "sr": sr,
+                "trim_applied": False,
+                "trimmed_start_ms": 0,
+                "head_silence_ms": head_sil_ms,
+                "tail_silence_ms": tail_sil_ms,
+                "effective_start_ms": 0,
+                "effective_dur_ms": 0,
+                "abs_start_ms": ck["global_offset_ms"],
+                "chunk_idx": idx,
+                "channel": ch,
+                "skip": True,
+            }
+
+        # Apply padding & slice
+        start_ms = max(0, int(first_ms) - DEFAULT_PAD_MS)
+        end_ms = min(planned_dur_ms, int(last_ms) + DEFAULT_PAD_MS)
+
+        if end_ms > start_ms:
+            eff_start_ms = start_ms
+            eff_end_ms = end_ms
+            trimmed_pcm = pcm[bytes_from_ms(start_ms) : bytes_from_ms(end_ms)]
+            trim_applied = True
+
+    # 4) Write WAV to local file (trimmed or original)
+    tag = "trim" if trim_applied else "full"
+    out_wav_path = os.path.join(out_dir, f"{job}_{ch}_{idx}_{tag}.wav")
+    _write_wav(out_wav_path, trimmed_pcm, sr)
+
+    # 5) Return metadata
+    return {
+        "out_wav_path": out_wav_path,
+        "sr": sr,
+        "trim_applied": trim_applied,
+        "trimmed_start_ms": eff_start_ms if trim_applied else 0,
+        "head_silence_ms": head_sil_ms,
+        "tail_silence_ms": tail_sil_ms,
+        "effective_start_ms": eff_start_ms,
+        "effective_dur_ms": eff_end_ms - eff_start_ms,
+        "abs_start_ms": int(ck["global_offset_ms"]) + eff_start_ms,
+        "chunk_idx": idx,
+        "channel": ch,
+        "skip": False if (trim_applied or len(pcm) > 0) else True,
+    }
+
 # Download once; later runs are instant
 snapshot_download(
     repo_id=MODEL_REPO,

@@ -66,6 +245,7 @@ model_cache_path = LOCAL_DIR  # <── this is what we pass to WhisperModel
 _whisper = None
 _batched_whisper = None
 _diarizer = None
+_embedder = None

 # Create global diarization pipeline
 try:

@@ -108,24 +288,20 @@ class WhisperTranscriber:
         # do **not** create the models here!
         pass

-    def
-        """
-        temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
-        temp_wav_path = temp_wav.name
-        temp_wav.close()
+    def preprocess_from_task_json(self, task_json: str) -> dict:
+        """Parse task JSON and run prepare_and_save_audio_for_model, returning metadata."""
         try:
+            task = json.loads(task_json)
+        except Exception as e:
+            raise RuntimeError(f"Invalid JSON: {e}")
+
+        out_dir = os.path.join(CACHE_ROOT, "preprocessed")
+        os.makedirs(out_dir, exist_ok=True)
+        meta = prepare_and_save_audio_for_model(task, out_dir)
+        return meta

     @spaces.GPU  # each call gets a GPU slice
-    def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16):
+    def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16, base_offset_s: float = 0.0):
         """Transcribe the entire audio file without speaker diarization using batched inference"""
         whisper, batched_whisper, _ = _load_models()  # models live on the GPU

@@ -168,16 +344,16 @@ class WhisperTranscriber:
             if seg.words:
                 for word in seg.words:
                     words_list.append({
-                        "start": float(word.start),
-                        "end": float(word.end),
+                        "start": float(word.start) + float(base_offset_s),
+                        "end": float(word.end) + float(base_offset_s),
                         "word": word.word,
                         "probability": word.probability,
                         "speaker": "SPEAKER_00"  # No speaker identification in full transcription
                     })

             results.append({
-                "start": float(seg.start),
-                "end": float(seg.end),
+                "start": float(seg.start) + float(base_offset_s),
+                "end": float(seg.end) + float(base_offset_s),
                 "text": seg.text,
                 "speaker": "SPEAKER_00",  # Single speaker assumption
                 "avg_logprob": seg.avg_logprob,

@@ -190,118 +366,14 @@ class WhisperTranscriber:
             print(results)
         return results, detected_language

-        """Cut audio into segments based on diarization results"""
-        print("Cutting audio into segments...")
-
-        # Load the full audio
-        waveform, sample_rate = torchaudio.load(audio_path)
-
-        audio_segments = []
-        for segment in diarization_segments:
-            start_sample = int(segment["start"] * sample_rate)
-            end_sample = int(segment["end"] * sample_rate)
-
-            # Extract the segment
-            segment_waveform = waveform[:, start_sample:end_sample]
-
-            # Create temporary file for this segment
-            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
-            temp_file.close()
-
-            # Save the segment
-            torchaudio.save(temp_file.name, segment_waveform, sample_rate)
-
-            audio_segments.append({
-                "audio_path": temp_file.name,
-                "start": segment["start"],
-                "end": segment["end"],
-                "speaker": segment["speaker"]
-            })
-
-        return audio_segments
+    # Removed audio cutting; transcription is done once on the full (preprocessed) audio

     @spaces.GPU  # each call gets a GPU slice
-        """Transcribe multiple audio segments using faster_whisper with batching"""
-        whisper, batched_whisper, _ = _load_models()  # models live on the GPU
-
-        print(f"Transcribing {len(audio_segments)} audio segments with batch size {batch_size}...")
-        start_time = time.time()
-
-        # Prepare options
-        options = dict(
-            language=language,
-            beam_size=5,
-            vad_filter=True,
-            vad_parameters=VadOptions(
-                max_speech_duration_s=whisper.feature_extractor.chunk_length,
-                min_speech_duration_ms=100,
-                speech_pad_ms=100,
-                threshold=0.25,
-                neg_threshold=0.2,
-            ),
-            word_timestamps=True,
-            initial_prompt=prompt,
-            language_detection_segments=1,
-            task="translate" if translate else "transcribe",
-        )
-
-        results = []
-        detected_language = None
-
-        for i, segment in enumerate(audio_segments):
-            print(f"Processing segment {i+1}/{len(audio_segments)}")
-
-            # Use batched inference for each segment
-            segments, transcript_info = batched_whisper.transcribe(
-                segment["audio_path"],
-                batch_size=batch_size,
-                **options
-            )
-            segments = list(segments)
-
-            # Get detected language from first segment
-            if detected_language is None:
-                detected_language = transcript_info.language
-
-            # Process each transcribed segment
-            for seg in segments:
-                # Create result entry with detailed format
-                words_list = []
-                if seg.words:
-                    for word in seg.words:
-                        words_list.append({
-                            "start": float(word.start) + segment["start"],
-                            "end": float(word.end) + segment["start"],
-                            "word": word.word,
-                            "probability": word.probability,
-                            "speaker": segment["speaker"]
-                        })
-
-                results.append({
-                    "start": float(seg.start) + segment["start"],
-                    "end": float(seg.end) + segment["start"],
-                    "text": seg.text,
-                    "speaker": segment["speaker"],
-                    "avg_logprob": seg.avg_logprob,
-                    "words": words_list,
-                    "duration": float(seg.end - seg.start)
-                })
-
-        # Clean up temporary files
-        for segment in audio_segments:
-            if os.path.exists(segment["audio_path"]):
-                os.unlink(segment["audio_path"])
-
-        transcription_time = time.time() - start_time
-        print(f"All segments transcribed in {transcription_time:.2f} seconds using batch size {batch_size}")
-
-        return results, detected_language
+    # Removed segment-wise transcription; using single full-audio transcription

     @spaces.GPU  # each call gets a GPU slice
-    def perform_diarization(self, audio_path, num_speakers=None):
-        """Perform speaker diarization"""
+    def perform_diarization(self, audio_path, num_speakers=None, base_offset_s: float = 0.0):
+        """Perform speaker diarization; return segments with global timestamps and per-speaker embeddings."""
         _, _, diarizer = _load_models()  # models live on the GPU

         if diarizer is None:

@@ -309,11 +381,20 @@ class WhisperTranscriber:
             # Load audio to get duration
             waveform, sample_rate = torchaudio.load(audio_path)
             duration = waveform.shape[1] / sample_rate
+            # Try to compute a single-speaker embedding
+            speaker_embeddings = {}
+            try:
+                embedder = self._load_embedder()
+                # waveform is (1, T); embedder expects mono 1D
+                emb = embedder({"waveform": waveform.squeeze(0), "sample_rate": sample_rate})
+                speaker_embeddings["SPEAKER_00"] = emb.squeeze().tolist()
+            except Exception:
+                pass
             return [{
-                "start": 0.0,
-                "end": duration,
+                "start": 0.0 + float(base_offset_s),
+                "end": duration + float(base_offset_s),
                 "speaker": "SPEAKER_00"
-            }], 1
+            }], 1, speaker_embeddings

         print("Starting diarization...")
         start_time = time.time()

@@ -330,21 +411,100 @@ class WhisperTranscriber:
         # Convert to list format
         diarize_segments = []
         diarization_list = list(diarization.itertracks(yield_label=True))
+        print(diarization_list)
         for turn, _, speaker in diarization_list:
             diarize_segments.append({
-                "start": turn.start,
-                "end": turn.end,
+                "start": float(turn.start) + float(base_offset_s),
+                "end": float(turn.end) + float(base_offset_s),
                 "speaker": speaker
             })

         unique_speakers = {speaker for segment in diarize_segments for speaker in [segment["speaker"]]}
         detected_num_speakers = len(unique_speakers)

+        # Compute per-speaker embeddings by averaging segment embeddings
+        speaker_embeddings = {}
+        try:
+            embedder = self._load_embedder()
+            spk_to_embs = {spk: [] for spk in unique_speakers}
+            for turn, _, speaker in diarization_list:
+                start_sample = int(float(turn.start) * sample_rate)
+                end_sample = int(float(turn.end) * sample_rate)
+                if end_sample > start_sample:
+                    seg_wav = waveform[0, start_sample:end_sample].contiguous()
+                    emb = embedder({"waveform": seg_wav, "sample_rate": sample_rate})
+                    spk_to_embs[speaker].append(emb.squeeze())
+            # average
+            for spk, embs in spk_to_embs.items():
+                if len(embs) == 0:
+                    continue
+                # stack and mean
+                try:
+                    import torch as _torch
+                    embs_tensor = _torch.stack([_torch.as_tensor(e) for e in embs], dim=0)
+                    centroid = embs_tensor.mean(dim=0)
+                    # L2 normalize
+                    centroid = centroid / (centroid.norm(p=2) + 1e-12)
+                    speaker_embeddings[spk] = centroid.cpu().tolist()
+                except Exception:
+                    # fallback to first embedding
+                    speaker_embeddings[spk] = embs[0].cpu().tolist()
+        except Exception:
+            pass
+
         diarization_time = time.time() - start_time
         print(f"Diarization completed in {diarization_time:.2f} seconds")

-        return diarize_segments, detected_num_speakers
+        return diarize_segments, detected_num_speakers, speaker_embeddings
+
+    def _load_embedder(self):
+        """Lazy-load speaker embedding inference model on GPU."""
+        global _embedder
+        if _embedder is None:
+            # window="whole" to compute one embedding per provided chunk
+            _embedder = Inference("pyannote/embedding", window="whole", device=torch.device("cuda"))
+        return _embedder
+
+    def assign_speakers_to_transcription(self, transcription_results, diarization_segments):
+        """Assign speakers to words and segments based on overlap with diarization segments."""
+        if not diarization_segments:
+            return transcription_results
+        # simple helper to find speaker at given time
+        def speaker_at(t: float):
+            for seg in diarization_segments:
+                if seg["start"] <= t < seg["end"]:
+                    return seg["speaker"]
+            # if not inside, return closest segment's speaker
+            closest = None
+            best = float("inf")
+            for seg in diarization_segments:
+                if t < seg["start"]:
+                    d = seg["start"] - t
+                elif t > seg["end"]:
+                    d = t - seg["end"]
+                else:
+                    d = 0.0
+                if d < best:
+                    best = d
+                    closest = seg
+            return closest["speaker"] if closest else "SPEAKER_00"
+
+        for seg in transcription_results:
+            # Assign per-word speakers
+            if seg.get("words"):
+                speaker_counts = {}
+                for w in seg["words"]:
+                    mid = (float(w["start"]) + float(w["end"])) / 2.0
+                    spk = speaker_at(mid)
+                    w["speaker"] = spk
+                    speaker_counts[spk] = speaker_counts.get(spk, 0) + (float(w["end"]) - float(w["start"]))
+                # Segment speaker = speaker with max accumulated word duration
+                if speaker_counts:
+                    seg["speaker"] = max(speaker_counts.items(), key=lambda kv: kv[1])[0]
+            else:
+                mid = (float(seg["start"]) + float(seg["end"])) / 2.0
+                seg["speaker"] = speaker_at(mid)
+        return transcription_results

     def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
         """Group consecutive segments from the same speaker"""

@@ -388,22 +548,27 @@ class WhisperTranscriber:
         return grouped_segments

     @spaces.GPU  # each call gets a GPU slice
-    def process_audio_full(self,
-        """Process
-        if
-            return {"error": "No
+    def process_audio_full(self, task_json, language=None, translate=False, prompt=None, group_segments=True, batch_size=16):
+        """Process a single chunk using task JSON (no diarization)."""
+        if not task_json or not str(task_json).strip():
+            return {"error": "No JSON provided"}
+
+        pre_meta = None
         try:
             print("Starting full transcription pipeline...")

-            # Step 1:
-            print("
+            # Step 1: Preprocess per chunk JSON
+            print("Preprocessing chunk JSON...")
+            pre_meta = self.preprocess_from_task_json(task_json)
+            if pre_meta.get("skip"):
+                return {"segments": [], "language": "unknown", "num_speakers": 1, "transcription_method": "full_audio_batched", "batch_size": batch_size}
+            wav_path = pre_meta["out_wav_path"]
+            # Adjust timestamps by trimmed_start_ms: abs_start_ms is already global start for saved file
+            base_offset_s = float(pre_meta.get("abs_start_ms", 0)) / 1000.0

             # Step 2: Transcribe the entire audio with batching
             transcription_results, detected_language = self.transcribe_full_audio(
+                wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s
             )

             # Step 3: Group segments if requested (based on time gaps and sentence endings)

@@ -424,38 +589,47 @@ class WhisperTranscriber:
             traceback.print_exc()
             return {"error": f"Processing failed: {str(e)}"}
         finally:
-            # Clean up
-            if
+            # Clean up preprocessed wav
+            if pre_meta and pre_meta.get("out_wav_path") and os.path.exists(pre_meta["out_wav_path"]):
+                try:
+                    os.unlink(pre_meta["out_wav_path"])
+                except Exception:
+                    pass

     @spaces.GPU  # each call gets a GPU slice
-    def process_audio(self,
+    def process_audio(self, task_json, num_speakers=None, language=None,
                       translate=False, prompt=None, group_segments=True, batch_size=8):
-        """Main processing function
+        """Main processing function with diarization using task JSON for a single chunk.
+
+        Transcribes full (preprocessed) audio once, performs diarization, merges speakers into transcription.
+        """
+        if not task_json or not str(task_json).strip():
+            return {"error": "No JSON provided"}
+
+        pre_meta = None
         try:
             print("Starting new processing pipeline...")

-            # Step 1:
-            print("
+            # Step 1: Preprocess per chunk JSON
+            print("Preprocessing chunk JSON...")
+            pre_meta = self.preprocess_from_task_json(task_json)
+            if pre_meta.get("skip"):
+                return {"segments": [], "language": "unknown", "num_speakers": 0, "transcription_method": "diarized_segments_batched", "batch_size": batch_size}
+            wav_path = pre_meta["out_wav_path"]
+            base_offset_s = float(pre_meta.get("abs_start_ms", 0)) / 1000.0

-            # Step 2:
+            # Step 2: Transcribe full audio once
+            transcription_results, detected_language = self.transcribe_full_audio(
+                wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s
             )
-            # Step 3:
-            # Step 4: Transcribe each segment with batching
-            transcription_results, detected_language = self.transcribe_audio_segments(
-                audio_segments, language, translate, prompt, batch_size
+
+            # Step 3: Perform diarization with global offset
+            diarization_segments, detected_num_speakers, speaker_embeddings = self.perform_diarization(
+                wav_path, num_speakers, base_offset_s=base_offset_s
             )
+
+            # Step 4: Merge diarization into transcription (assign speakers)
+            transcription_results = self.assign_speakers_to_transcription(transcription_results, diarization_segments)

             # Step 5: Group segments if requested
             if group_segments:

@@ -467,7 +641,8 @@ class WhisperTranscriber:
                 "language": detected_language,
                 "num_speakers": detected_num_speakers,
                 "transcription_method": "diarized_segments_batched",
-                "batch_size": batch_size
+                "batch_size": batch_size,
+                "speaker_embeddings": speaker_embeddings,
             }

         except Exception as e:

@@ -475,10 +650,12 @@ class WhisperTranscriber:
             traceback.print_exc()
             return {"error": f"Processing failed: {str(e)}"}
         finally:
-            # Clean up
-            if
+            # Clean up preprocessed wav
+            if pre_meta and pre_meta.get("out_wav_path") and os.path.exists(pre_meta["out_wav_path"]):
+                try:
+                    os.unlink(pre_meta["out_wav_path"])
+                except Exception:
+                    pass

 # Initialize transcriber
 transcriber = WhisperTranscriber()

@@ -515,11 +692,11 @@ def format_segments_for_display(result):
     return output

 @spaces.GPU
-def process_audio_gradio(
+def process_audio_gradio(task_json, num_speakers, language, translate, prompt, group_segments, use_diarization, batch_size):
     """Gradio interface function"""
     if use_diarization:
         result = transcriber.process_audio(
+            task_json=task_json,
             num_speakers=num_speakers if num_speakers > 0 else None,
             language=language if language != "auto" else None,
             translate=translate,

@@ -529,7 +706,7 @@ def process_audio_gradio(audio_file, num_speakers, language, translate, prompt,
         )
     else:
         result = transcriber.process_audio_full(
+            task_json=task_json,
             language=language if language != "auto" else None,
             translate=translate,
             prompt=prompt if prompt and prompt.strip() else None,

@@ -558,9 +735,10 @@ with demo:

     with gr.Row():
         with gr.Column():
-                label="
+            task_json_input = gr.Textbox(
+                label="🧾 Paste Task JSON",
+                placeholder="Paste the per-chunk task JSON here...",
+                lines=16,
             )

             with gr.Accordion("⚙️ Advanced Settings", open=False):

@@ -572,7 +750,7 @@ with demo:

                 batch_size = gr.Slider(
                     minimum=1,
-                    maximum=
+                    maximum=128,
                     value=16,
                     step=1,
                     label="Batch Size",

@@ -615,7 +793,7 @@ with demo:
         with gr.Column():
             output_text = gr.Markdown(
                 label="Transcription Results",
-                value="
+                value="Paste task JSON and click 'Transcribe Audio' to get started!"
             )

             output_json = gr.JSON(

@@ -634,7 +812,7 @@ with demo:
     process_btn.click(
         fn=process_audio_gradio,
         inputs=[
+            task_json_input,
            num_speakers,
            language,
            translate,

@@ -649,11 +827,11 @@ with demo:
     # Examples
     gr.Markdown("### Usage Tips:")
     gr.Markdown("""
+    - Paste a single-chunk task JSON matching the preprocess schema
+    - Batch Size: Higher values (16-24) = faster but uses more GPU memory
+    - Speaker diarization: Enable for speaker identification (slower)
+    - Languages: Supports 100+ languages with auto-detection
+    - Vocabulary: Add names and technical terms in the prompt for better accuracy
    """)

 if __name__ == "__main__":
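A note on the VAD framing used by _find_head_tail_speech_ms above: webrtcvad only accepts 10, 20, or 30 ms frames of 16-bit mono PCM at 8/16/32/48 kHz, so with the constants in this commit the frame and hangover sizes work out as below (a quick sanity check, not part of the commit):

sr = 16000                       # sample rate forced by the ffmpeg decode
frame_ms = 30                    # FRAME_MS; webrtcvad accepts 10/20/30 ms frames
bytes_per_sample = 2             # s16le
samples_per_ms = sr // 1000      # 16 samples per millisecond
bytes_per_frame = samples_per_ms * bytes_per_sample * frame_ms
hang_frames = 240 // frame_ms    # HANG_MS = 240 ms of trailing silence
print(bytes_per_frame, hang_frames)   # 960 bytes per frame, 8 frames of hangover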
requirements.txt
CHANGED

@@ -19,4 +19,5 @@ librosa>=0.10.0
 soundfile>=0.12.0
 ffmpeg-python>=0.2.0
 requests>=2.28.0
-nvidia-cudnn-cu12==9.1.0.70  # any 9.1.x that pip can find is fine
+nvidia-cudnn-cu12==9.1.0.70  # any 9.1.x that pip can find is fine
+webrtcvad>=2.0.10
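For completeness, a minimal sketch of how the new per-chunk flow could be driven outside the Gradio UI, reusing the hypothetical example_task from above. The method names come from this commit, but calling them standalone like this is an assumption, since they are wrapped in @spaces.GPU for the Spaces runtime, and the exact shape of the success result is not fully visible in the diff:

import json

transcriber = WhisperTranscriber()  # defined in app.py

# Diarization-free path: decode + VAD-trim the chunk, then transcribe it once.
result = transcriber.process_audio_full(
    task_json=json.dumps(example_task),
    language=None,      # auto-detect
    translate=False,
    batch_size=16,
)

# Timestamps are global: global_offset_ms plus any trimmed head silence is
# folded in via base_offset_s, so chunks can be merged without re-aligning.
for seg in result.get("segments", []):
    print(f'{seg["start"]:7.2f}-{seg["end"]:7.2f} {seg["speaker"]}: {seg["text"]}')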
|