import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification

# Load the pretrained emotion classifier and its feature extractor, then move
# the model to the available device and put it in evaluation mode. Without the
# .to(device) call, inference would fail on GPU because the inputs are moved
# to `device` later on while the weights stay on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "Hatman/audio-emotion-detection"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
model.to(device)
model.eval()

# Mapping from the model's output class IDs to human-readable emotion labels.
EMOTION_LABELS = {
    0: "angry",
    1: "disgust",
    2: "fear",
    3: "happy",
    4: "neutral",
    5: "sad",
    6: "surprise",
}
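
# Note: this hard-coded mapping is assumed to match the model's own id2label
# configuration; as an alternative sketch, it could be derived from the config:
#
#   # EMOTION_LABELS = {int(i): label for i, label in model.config.id2label.items()}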


def preprocess_audio(audio):
    """Load an audio file, downmix it to mono, and resample it to 16 kHz."""
    waveform, sampling_rate = torchaudio.load(audio)

    # Downmix multi-channel (e.g. stereo) recordings to a single channel so the
    # flatten() below does not concatenate separate channels end to end.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # The Wav2Vec2 feature extractor expects 16 kHz input.
    resampled_waveform = torchaudio.transforms.Resample(
        orig_freq=sampling_rate,
        new_freq=16000
    )(waveform)

    return {
        'speech': resampled_waveform.numpy().flatten(),
        'sampling_rate': 16000
    }
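
# Example (hypothetical path): for a 44.1 kHz stereo clip "sample.wav" this
# returns a flat mono float array resampled to 16 kHz:
#
#   example = preprocess_audio("sample.wav")
#   example['speech'].shape      # (num_samples,)
#   example['sampling_rate']     # 16000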


def inference(audio):
    """Full inference function returning emotion, logits, and predicted IDs"""
    example = preprocess_audio(audio)

    # Turn the raw waveform into model-ready tensors and move them to `device`.
    inputs = feature_extractor(
        example['speech'],
        sampling_rate=16000,
        return_tensors="pt",
        padding=True
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Forward pass without gradient tracking; pick the highest-scoring class.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_ids = torch.argmax(logits, dim=-1)

    predicted_emotion = EMOTION_LABELS[predicted_ids.item()]
    return predicted_emotion, logits.tolist(), predicted_ids.tolist()
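
# Optional sketch (not wired into the UI): if normalized per-emotion scores are
# wanted instead of raw logits, they could be derived with a softmax, e.g.:
#
#   emotion, logits, _ = inference("sample.wav")   # "sample.wav" is a placeholder
#   probs = torch.softmax(torch.tensor(logits), dim=-1)[0]
#   scores = {EMOTION_LABELS[i]: float(p) for i, p in enumerate(probs)}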


def inference_label(audio):
    """Simplified inference function returning only the emotion label"""
    emotion, _, _ = inference(audio)
    return emotion


# Build a two-tab Gradio UI: a quick label-only view and a detailed view.
with gr.Blocks() as demo:
    gr.Markdown("# Audio Emotion Detection")

    with gr.Tab("Quick Analysis"):
        gr.Interface(
            fn=inference_label,
            inputs=gr.Audio(type="filepath"),
            outputs=gr.Label(label="Detected Emotion"),
            title="Audio Emotion Analysis",
            description="Upload or record audio to detect the emotional content."
        )

    with gr.Tab("Detailed Analysis"):
        gr.Interface(
            fn=inference,
            inputs=gr.Audio(type="filepath"),
            outputs=[
                gr.Label(label="Detected Emotion"),
                gr.JSON(label="Raw Logits"),
                gr.JSON(label="Predicted Class IDs")
            ],
            title="Audio Emotion Analysis (Detailed)",
            description="Get a detailed analysis, including the raw model logits and the predicted class ID."
        )

demo.launch(share=True)
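
# Note: share=True requests a temporary public gradio.live link in addition to
# the local server; use demo.launch() to serve locally only.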