Spaces: Running on Zero
import json
import os
import tempfile
from io import BytesIO

import gradio as gr
import requests
import spaces
from nemo.collections.asr.models import SortformerEncLabelModel
from pydub import AudioSegment
# Load the pretrained streaming Sortformer diarization model once, at import
# time, and put it in eval (inference) mode.
diar_model = SortformerEncLabelModel.from_pretrained(
    "nvidia/diar_streaming_sortformer_4spk-v2.1"
)
diar_model.eval()

# Streaming settings written onto diar_model.sortformer_modules, applied in
# this exact order. NOTE(review): the units of these values (frames vs.
# seconds) are not visible in this file — confirm against the NeMo docs.
_STREAMING_SETTINGS = {
    "chunk_len": 340,
    "chunk_right_context": 40,
    "fifo_len": 40,
    "spkcache_update_period": 300,
    "spkcache_len": 188,
}
_sortformer_modules = diar_model.sortformer_modules
for _setting_name, _setting_value in _STREAMING_SETTINGS.items():
    setattr(_sortformer_modules, _setting_name, _setting_value)
# NeMo-internal helper; presumably validates the combination set above.
_sortformer_modules._check_streaming_parameters()
def preprocess_audio(audio_path):
    """Decode audio and write it to a fresh mono, 16 kHz WAV temp file.

    Args:
        audio_path: Either a filesystem path (``str`` or ``os.PathLike``) to an
            audio file, or the raw encoded audio bytes (e.g. an HTTP response
            body).

    Returns:
        Path of a newly created ``.wav`` temporary file. Each call gets its own
        file; the caller owns it and is responsible for deleting it.

    Raises:
        ValueError: If the audio cannot be decoded, converted, or exported.
    """
    # Anything that is not a path is treated as in-memory encoded audio.
    is_raw = not isinstance(audio_path, (str, os.PathLike))
    temp_wav = None
    try:
        source = BytesIO(audio_path) if is_raw else audio_path
        audio = AudioSegment.from_file(source)
        # Downmix to mono and resample to 16 kHz before handing to the model.
        audio = audio.set_channels(1).set_frame_rate(16000)
        # Unique file per call: a fixed name would let concurrent requests
        # overwrite (and delete) each other's audio.
        fd, temp_wav = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        # pydub's export() returns the open output file handle; close it so
        # the descriptor is not leaked.
        audio.export(temp_wav, format="wav").close()
    except Exception as e:
        # Do not leave a half-written temp file behind on failure.
        if temp_wav is not None and os.path.exists(temp_wav):
            os.remove(temp_wav)
        raise ValueError(f"Error preprocessing audio: {str(e)}") from e
    return temp_wav
def handle_audio(url, audio_path):
    """Fetch or read audio, run diarization, and return the result as JSON.

    Args:
        url: Optional URL of an audio file. When non-blank it takes precedence
            over ``audio_path``.
        audio_path: Path to an uploaded audio file, used when ``url`` is blank.

    Returns:
        A JSON string encoding whatever ``diarize_audio_diar1`` returned.

    Raises:
        ValueError: If neither a URL nor an audio file is supplied, or the
            audio cannot be preprocessed.
        requests.RequestException: If downloading ``url`` fails or the server
            answers with an HTTP error status.
    """
    url = (url or "").strip()
    if url:
        response = requests.get(url, timeout=60)
        # Fail loudly on 4xx/5xx instead of trying to decode an error page.
        response.raise_for_status()
        source = response.content
    elif audio_path:
        source = audio_path
    else:
        raise ValueError("Please provide an audio URL or upload an audio file.")

    wav_path = preprocess_audio(source)
    try:
        res = diarize_audio_diar1(wav_path)
    finally:
        # Always remove the temporary WAV, even if diarization raises.
        if os.path.exists(wav_path):
            os.remove(wav_path)
    return json.dumps(res)
def diarize_audio_diar1(audio_path):
    """Run the Sortformer diarization model on a single audio file.

    Args:
        audio_path: Path to a WAV file (as produced by ``preprocess_audio``).

    Returns:
        On success, the list of segment dicts from ``format_results``.
        On any failure, a dict ``{"error": "Error: <message>"}``. This replaces
        the former ``(message, "")`` tuple, which serialized to a confusing
        two-element JSON list.
    """
    # NOTE(review): the file imports `spaces` (ZeroGPU) but no function is
    # decorated with `@spaces.GPU` — confirm whether GPU execution is intended.
    try:
        predicted_segments = diar_model.diarize(audio=audio_path, batch_size=1)
        # Output appears to be per input; take the first (only) entry.
        return format_results(predicted_segments[0])
    except Exception as e:
        return {"error": f"Error: {e}"}
def format_results(results):
    """Normalize raw diarization output into segment dicts sorted by start.

    Args:
        results: A list of segments, or a JSON string encoding such a list.
            Each segment is either a string ``"<start> <end> <speaker>"`` or a
            dict with ``start``/``end`` and ``speaker`` (or ``speaker_id``).

    Returns:
        A list of ``{"start": float, "end": float, "speaker_id": ...}`` dicts
        sorted by ``start``. Non-list input yields ``[]``. String items that do
        not have exactly three whitespace-separated fields, and items of any
        other type, are skipped.

    Raises:
        json.JSONDecodeError: If ``results`` is a string that is not valid JSON.
        ValueError, TypeError: If a start/end value is not numeric.
    """
    if isinstance(results, str):
        results = json.loads(results)
    if not isinstance(results, list):
        return []

    formatted_results = []
    for item in results:
        if isinstance(item, str):
            parts = item.split()
            if len(parts) != 3:
                continue
            start, end, speaker = parts
        elif isinstance(item, dict):
            start = item.get("start", 0)
            end = item.get("end", 0)
            speaker = item.get("speaker", item.get("speaker_id", "unknown"))
        else:
            continue
        # Coerce both input shapes to float so sorting and the JSON output are
        # type-consistent (previously only the string path converted).
        formatted_results.append(
            {"start": float(start), "end": float(end), "speaker_id": speaker}
        )

    formatted_results.sort(key=lambda seg: seg["start"])
    return formatted_results
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Speaker Diarization with nvidia/diar_streaming_sortformer_4spk-v2.1")
    # The UI offers no speaker-count input, so the description must not ask
    # for one; it now describes the URL/upload controls actually present.
    gr.Markdown(
        "Provide an audio URL or upload an audio file to diarize it "
        "(the model distinguishes up to 4 speakers)."
    )
    with gr.Row():
        url_input = gr.Textbox(label="URL")
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
    submit_btn = gr.Button("Diarize")
    with gr.Row():
        json_output = gr.Textbox(label="Diarization Results (JSON)")
    submit_btn.click(
        fn=handle_audio,
        inputs=[url_input, audio_input],
        outputs=[json_output],
        concurrency_limit=20,
    )

# Launch the Gradio app
demo.launch()