# Commit 6d43bc6 by luckyhookin: "update"
from io import BytesIO
import os
import gradio as gr
import spaces
from pydub import AudioSegment
import json
import requests
from nemo.collections.asr.models import SortformerEncLabelModel
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2.1")
diar_model.eval()

# Streaming configuration for the Sortformer modules. The values are applied
# in this exact order and then checked by the model's own validator.
# NOTE(review): units/meaning of each value are defined by NeMo — confirm there.
for _param_name, _param_value in (
    ("chunk_len", 340),
    ("chunk_right_context", 40),
    ("fifo_len", 40),
    ("spkcache_update_period", 300),
    ("spkcache_len", 188),
):
    setattr(diar_model.sortformer_modules, _param_name, _param_value)
del _param_name, _param_value
diar_model.sortformer_modules._check_streaming_parameters()
def preprocess_audio(audio_path):
    """Convert audio to a mono, 16 kHz WAV file for the diarization model.

    Args:
        audio_path: Either a filesystem path (as produced by the Gradio
            ``Audio(type="filepath")`` input) or the raw bytes of an audio
            file (e.g. a downloaded HTTP response body).

    Returns:
        Path to a newly created temporary ``.wav`` file. Every call gets its
        own file, so concurrent requests cannot overwrite each other's audio.
        The caller is responsible for deleting it.

    Raises:
        ValueError: If the audio cannot be decoded, converted or written.
    """
    import contextlib
    import tempfile

    temp_wav = None
    try:
        # In-memory data must be wrapped in a file-like object for pydub;
        # paths (and anything else) are handed to pydub unchanged.
        if isinstance(audio_path, (bytes, bytearray, memoryview)):
            source = BytesIO(audio_path)
        else:
            source = audio_path
        audio = AudioSegment.from_file(source)
        # Downmix to mono and resample to 16 kHz.
        audio = audio.set_channels(1).set_frame_rate(16000)
        # Unique temp file per call: the request handler may run concurrently,
        # and a shared fixed filename would let requests clobber each other.
        fd, temp_wav = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        audio.export(temp_wav, format="wav")
        return temp_wav
    except Exception as e:
        # Best-effort removal of a partially written temp file.
        if temp_wav is not None:
            with contextlib.suppress(OSError):
                os.remove(temp_wav)
        raise ValueError(f"Error preprocessing audio: {str(e)}") from e
def handle_audio(url, audio_path):
    """Fetch or read the requested audio, diarize it and return JSON.

    Args:
        url: Optional URL of an audio file. When non-blank it takes
            precedence over ``audio_path``.
        audio_path: Path of an uploaded audio file, used when no URL is given.

    Returns:
        A JSON string encoding the value returned by ``diarize_audio_diar1``.

    Raises:
        ValueError: If neither a URL nor an uploaded file was provided, or if
            the audio cannot be preprocessed.
        requests.RequestException: If the download fails or the server
            answers with an HTTP error status.
    """
    url = (url or "").strip()
    if url:
        response = requests.get(url, timeout=60)
        # Fail on 4xx/5xx instead of feeding an error page to the decoder.
        response.raise_for_status()
        source = response.content
    elif audio_path:
        source = audio_path
    else:
        raise ValueError("Provide an audio URL or upload an audio file.")

    wav_path = preprocess_audio(source)
    try:
        res = diarize_audio_diar1(wav_path)
    finally:
        # Always remove the temporary WAV, even if diarization raises.
        if os.path.exists(wav_path):
            os.remove(wav_path)
    return json.dumps(res)
@spaces.GPU(duration=120)
def diarize_audio_diar1(audio_path):
    """Run Sortformer diarization on an audio file and format the segments.

    Args:
        audio_path: Path to the (preprocessed) audio file to diarize.

    Returns:
        On success, the list of segment dicts produced by ``format_results``.
        On failure, a dict ``{"error": "Error: <message>"}``. This keeps the
        JSON sent to the client self-describing, instead of the previous bare
        ``["Error: ...", ""]`` pair.
    """
    import logging

    try:
        predicted_segments = diar_model.diarize(audio=audio_path, batch_size=1)
        # A single file is submitted, so take the first entry of the output
        # (presumably one entry per input file — confirm against NeMo docs).
        return format_results(predicted_segments[0])
    except Exception as e:
        # Keep the no-raise contract for the UI, but don't lose the traceback.
        logging.getLogger(__name__).exception("Diarization failed")
        return {"error": f"Error: {str(e)}"}
def format_results(results):
    """Normalize raw diarization output into a sorted list of segment dicts.

    Args:
        results: One of:
            * a JSON string, which is decoded first;
            * a list/tuple whose items are either strings of the form
              ``"<start> <end> <speaker>"`` or dicts with ``start``/``end``
              and ``speaker`` (or ``speaker_id``) keys.

    Returns:
        A list of ``{"start", "end", "speaker_id"}`` dicts sorted by
        ``start``. String items that do not split into exactly three fields,
        and items of any other type, are skipped. Input that is not a
        list/tuple (after JSON decoding) yields ``[]``.

    Raises:
        json.JSONDecodeError: If ``results`` is a string that is not valid JSON.
        ValueError: If a three-field string item has non-numeric start/end.
    """
    if isinstance(results, str):
        results = json.loads(results)
    if not isinstance(results, (list, tuple)):
        return []

    formatted_results = []
    for item in results:
        if isinstance(item, str):
            # str.split() with no argument also trims surrounding whitespace.
            parts = item.split()
            if len(parts) == 3:
                start, end, speaker = parts
                formatted_results.append({
                    "start": float(start),
                    "end": float(end),
                    "speaker_id": speaker,
                })
        elif isinstance(item, dict):
            formatted_results.append({
                "start": item.get("start", 0),
                "end": item.get("end", 0),
                "speaker_id": item.get("speaker", item.get("speaker_id", "unknown")),
            })
    formatted_results.sort(key=lambda segment: segment["start"])
    return formatted_results
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Speaker Diarization with nvidia/diar_streaming_sortformer_4spk-v2.1")
    # The UI has no speaker-count input; describe the inputs that actually exist.
    gr.Markdown(
        "Paste an audio URL or upload an audio file, then click **Diarize**. "
        "If a URL is given it is used instead of the uploaded file. "
        "No speaker count is needed; the model supports up to 4 speakers."
    )
    with gr.Row():
        url_input = gr.Textbox(label="URL")
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
    submit_btn = gr.Button("Diarize")
    with gr.Row():
        json_output = gr.Textbox(label="Diarization Results (JSON)")
    submit_btn.click(
        fn=handle_audio,
        inputs=[url_input, audio_input],
        outputs=[json_output],
        # Up to 20 simultaneous runs of handle_audio for this event.
        concurrency_limit=20,
    )

# Launch the Gradio app
demo.launch()