File size: 3,736 Bytes
e1e635d
 
 
 
 
 
 
 
 
6d43bc6
e1e635d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d43bc6
e1e635d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from io import BytesIO
import os
import gradio as gr
import spaces
from pydub import AudioSegment
import json
import requests
from nemo.collections.asr.models import SortformerEncLabelModel

# Load the pretrained streaming Sortformer diarization checkpoint once at
# import time and switch it to evaluation mode.
diar_model = SortformerEncLabelModel.from_pretrained(
    "nvidia/diar_streaming_sortformer_4spk-v2.1"
)
diar_model.eval()

# Streaming configuration. NOTE(review): the values are presumably expressed
# in model frames and trade latency for accuracy -- confirm against the model
# card before changing any of them.
_sortformer_modules = diar_model.sortformer_modules
_sortformer_modules.chunk_len = 340
_sortformer_modules.chunk_right_context = 40
_sortformer_modules.fifo_len = 40
_sortformer_modules.spkcache_update_period = 300
_sortformer_modules.spkcache_len = 188
# Ask the module to check the combination of parameters set above.
_sortformer_modules._check_streaming_parameters()

def preprocess_audio(audio_path):
    """Convert audio to a mono, 16 kHz WAV file and return that file's path.

    Args:
        audio_path: Either a filesystem path (``str``) to an audio file, or
            the raw encoded audio content as a bytes-like object.

    Returns:
        Path of a newly created, uniquely named temporary WAV file. The
        caller is responsible for deleting it.

    Raises:
        ValueError: If no input was given, or if decoding, conversion or
            export fails (the original exception is chained as the cause).
    """
    import tempfile

    if audio_path is None:
        # Without this guard BytesIO(None) yields an empty buffer and the
        # failure surfaces later as an obscure decoder error.
        raise ValueError("Error preprocessing audio: no audio input provided")

    try:
        # Strings are treated as paths; anything else is assumed to be the
        # encoded audio bytes and is wrapped in an in-memory file.
        source = audio_path if isinstance(audio_path, str) else BytesIO(audio_path)
        audio = AudioSegment.from_file(source)
        # Convert to mono and set sample rate to 16kHz
        audio = audio.set_channels(1).set_frame_rate(16000)

        # A unique file per call: a fixed name would be shared by concurrent
        # requests, which could overwrite or delete each other's audio.
        fd, temp_wav = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            # export() hands back the open output file; close it right away
            # so the handle is not leaked.
            audio.export(temp_wav, format="wav").close()
        except Exception:
            # Do not leave a half-written file behind on failure.
            if os.path.exists(temp_wav):
                os.remove(temp_wav)
            raise
        return temp_wav
    except Exception as e:
        raise ValueError(f"Error preprocessing audio: {str(e)}") from e

def handle_audio(url, audio_path):
    """Diarize a remote or uploaded audio file and return the result as JSON.

    Args:
        url: Optional URL of an audio file. When non-empty it takes
            precedence over ``audio_path`` and its content is downloaded.
        audio_path: Path of the uploaded audio file; used when ``url`` is
            empty.

    Returns:
        A JSON string encoding whatever ``diarize_audio_diar1`` returned.

    Raises:
        requests.RequestException: If the download fails or the server
            answers with an HTTP error status.
        ValueError: If the audio cannot be preprocessed.
    """
    if url:
        # NOTE(review): the URL is user-supplied and fetched server-side;
        # consider restricting schemes/hosts if this app is exposed publicly.
        response = requests.get(url, timeout=60)
        # Fail fast on 4xx/5xx instead of feeding an error page to the
        # audio decoder.
        response.raise_for_status()
        audio_path = response.content

    wav_path = preprocess_audio(audio_path)
    try:
        res = diarize_audio_diar1(wav_path)
    finally:
        # Clean up the temporary file even if diarization raises.
        if os.path.exists(wav_path):
            os.remove(wav_path)
    return json.dumps(res)

    
@spaces.GPU(duration=120)
def diarize_audio_diar1(audio_path):
    """Run the diarization model on one audio file and format its output.

    Args:
        audio_path: Path of the audio file handed to ``diar_model.diarize``.

    Returns:
        On success, the value produced by ``format_results`` for the first
        entry of the model output. On any exception, a 2-tuple of
        ``("Error: <message>", "")`` is returned instead of raising.
    """
    try:
        per_file_segments = diar_model.diarize(audio=audio_path, batch_size=1)
        # A single file was submitted, so only the first entry is relevant.
        first_file_segments = per_file_segments[0]
        return format_results(first_file_segments)
    except Exception as err:
        message = f"Error: {str(err)}"
        return message, ""

def format_results(results):
    """Normalize diarization output into segment dicts sorted by start time.

    Args:
        results: A list of segments, or a JSON string that decodes to one.
            Each segment is either a whitespace-separated string of exactly
            three fields (start, end, speaker) or a dict with ``start``,
            ``end`` and ``speaker``/``speaker_id`` keys.

    Returns:
        A list of ``{"start", "end", "speaker_id"}`` dicts ordered by
        ``start``. Items of any other shape are skipped; a non-list input
        yields an empty list.
    """
    if isinstance(results, str):
        results = json.loads(results)

    if not isinstance(results, list):
        return []

    def _as_segment(entry):
        # Translate one raw item into a segment dict, or None to skip it.
        if isinstance(entry, str):
            fields = entry.strip().split()
            if len(fields) != 3:
                return None
            begin, finish, speaker = fields
            return {
                "start": float(begin),
                "end": float(finish),
                "speaker_id": speaker,
            }
        if isinstance(entry, dict):
            return {
                "start": entry.get("start", 0),
                "end": entry.get("end", 0),
                "speaker_id": entry.get("speaker", entry.get("speaker_id", "unknown")),
            }
        return None

    segments = [seg for seg in map(_as_segment, results) if seg is not None]
    return sorted(segments, key=lambda seg: seg["start"])

# Gradio interface: a URL box and an upload widget feed handle_audio, whose
# JSON string result is shown in a text box.
with gr.Blocks() as demo:
    gr.Markdown("# Speaker Diarization with nvidia/diar_streaming_sortformer_4spk-v2.1")
    # The app has no speaker-count input, so the instructions must not ask
    # for one; they also need to mention the URL option.
    gr.Markdown(
        "Provide an audio URL or upload an audio file to diarize the audio. "
        "If a URL is given, it takes precedence over the uploaded file."
    )

    with gr.Row():
        url_input = gr.Textbox(label="URL")
        # type="filepath" makes Gradio pass the upload to the handler as a path.
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")

    submit_btn = gr.Button("Diarize")

    with gr.Row():
        json_output = gr.Textbox(label="Diarization Results (JSON)")

    submit_btn.click(
        fn=handle_audio,
        inputs=[url_input, audio_input],
        outputs=[json_output],
        concurrency_limit=20,
    )

# Launch the Gradio app
demo.launch()