import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification
# Initialize device and model
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "Hatman/audio-emotion-detection"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name).to(device)
model.eval()  # inference only: disable dropout
# Define emotion labels
EMOTION_LABELS = {
    0: "angry",
    1: "disgust",
    2: "fear",
    3: "happy",
    4: "neutral",
    5: "sad",
    6: "surprise"
}
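# Note: this mapping assumes the checkpoint was trained with the label order
# above; the mapping stored with the model itself is available as
# model.config.id2label and could be used here instead.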
def preprocess_audio(audio):
    """Load an audio file, downmix to mono, and resample to 16 kHz"""
    waveform, sampling_rate = torchaudio.load(audio)
    # Downmix multi-channel audio to mono so flatten() cannot interleave
    # samples from different channels
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    resampled_waveform = torchaudio.transforms.Resample(
        orig_freq=sampling_rate,
        new_freq=16000
    )(waveform)
    return {
        'speech': resampled_waveform.numpy().flatten(),
        'sampling_rate': 16000
    }
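# Usage sketch (hypothetical file name): preprocess_audio("clip.wav") yields a
# dict whose 'speech' entry is a 1-D numpy array sampled at 16 kHz, the input
# format the Wav2Vec2 feature extractor expects.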
def inference(audio):
    """Full inference function returning emotion, logits, and predicted IDs"""
    example = preprocess_audio(audio)
    inputs = feature_extractor(
        example['speech'],
        sampling_rate=16000,
        return_tensors="pt",
        padding=True
    )
    # Move inputs to the same device as the model
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_ids = torch.argmax(logits, dim=-1)
    predicted_emotion = EMOTION_LABELS[predicted_ids.item()]
    return predicted_emotion, logits.tolist(), predicted_ids.tolist()
def inference_label(audio):
    """Simplified inference function returning only the emotion label"""
    emotion, _, _ = inference(audio)
    return emotion
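# A minimal sketch (not wired into the UI below) of converting the raw logits
# into per-emotion probabilities with a softmax, should confidence-style
# scores be preferred over logits.
def inference_probabilities(audio):
    """Return a {label: probability} dict via a softmax over the logits"""
    _, logits, _ = inference(audio)
    probs = torch.softmax(torch.tensor(logits), dim=-1).squeeze()
    return {EMOTION_LABELS[i]: float(p) for i, p in enumerate(probs)}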
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Audio Emotion Detection")
    with gr.Tab("Quick Analysis"):
        gr.Interface(
            fn=inference_label,
            inputs=gr.Audio(type="filepath"),
            outputs=gr.Label(label="Detected Emotion"),
            title="Audio Emotion Analysis",
            description="Upload or record audio to detect the emotional content."
        )
    with gr.Tab("Detailed Analysis"):
        gr.Interface(
            fn=inference,
            inputs=gr.Audio(type="filepath"),
            outputs=[
                gr.Label(label="Detected Emotion"),
                gr.JSON(label="Raw Logits"),
                gr.JSON(label="Predicted Class IDs")
            ],
            title="Audio Emotion Analysis (Detailed)",
            description="Get a detailed analysis including the raw logits for each emotion."
        )
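# Note: share=True requests a temporary public gradio.live link when running
# locally; on Hugging Face Spaces the app is already publicly served and
# Gradio ignores this flag.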
# Launch the app
demo.launch(share=True)