student wav2small 17K params

app.py CHANGED
@@ -17,21 +17,6 @@ plt.style.use('seaborn-v0_8-whitegrid')
 
 
 
-def _prenorm(x, attention_mask=None):
-    '''mean/var normalization'''
-    if attention_mask is not None:
-        N = attention_mask.sum(1, keepdim=True)  # 0=ignored 1=valid
-        x -= x.sum(1, keepdim=True) / N
-        var = (x * x).sum(1, keepdim=True) / N
-
-    else:
-        x -= x.mean(1, keepdim=True)  # mean is the ONNX operator ReduceMean; saves some ops compared to casting the integer N to float and dividing
-        var = (x * x).mean(1, keepdim=True)
-    return x / torch.sqrt(var + 1e-7)
-
-
-
 class ADV(nn.Module):
 
     def __init__(self, config):
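The deleted _prenorm (re-added further down inside the new Wav2Small module) normalizes each waveform to zero mean and unit variance before the spectrogram, computing the statistics only over valid samples when an attention mask is given. A minimal sketch of that masked normalization with made-up tensors (illustration only, not part of the commit); it assumes padded samples are zero:

# Illustration only: masked mean/variance normalization as in _prenorm.
import torch

wav = torch.randn(2, 16000)
mask = torch.ones(2, 16000)
mask[1, 8000:] = 0                             # second clip is half padding
wav = wav * mask                               # padded samples must be zero

N = mask.sum(1, keepdim=True)                  # valid samples per clip
x = wav - mask * wav.sum(1, keepdim=True) / N  # zero mean over valid samples
var = (x * x).sum(1, keepdim=True) / N
x_norm = x / torch.sqrt(var + 1e-7)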
@@ -96,16 +81,275 @@ dawn = Dawn.from_pretrained(
 ).to(device).eval()
 
 
-def wav2small(x):
-    return .5 * dawn(x) + .5 * base(x)
 
 
-fig_error, ax = plt.subplots(figsize=(8, 6))
 
-# Set the text to display
-error_message = "Error: No .wav or Mic. audio provided."
 
-
 ax.text(0.5, 0.5, error_message,
         ha='center',
         va='center',
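The deleted helper is the late-fusion "Wav2Small teacher": the average of the arousal/dominance/valence logits of the two teachers, Dawn (wav2vec2) and the WavLM baseline. The commit replaces it with a pretrained 17K-parameter student but keeps the same average for plot 2. Sketched for reference (teacher_adv is a hypothetical name; dawn and base are the two models loaded earlier in app.py):

# Illustration only: the fusion teacher the deleted helper computed.
import torch

def teacher_adv(x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        return 0.5 * dawn(x) + 0.5 * base(x)  # average A/D/V logits of both teachers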
@@ -113,125 +357,164 @@ ax.text(0.5, 0.5, error_message,
         color='gray',
         fontweight='bold',
         transform=ax.transAxes)
-
-# Hide the axis ticks and labels for a cleaner look
 ax.set_xticks([])
 ax.set_yticks([])
 ax.set_xticklabels([])
 ax.set_yticklabels([])
-
-# Optional: Add a border around the text to make it stand out more
 ax.set_frame_on(True)
 ax.spines['top'].set_visible(False)
 ax.spines['right'].set_visible(False)
 ax.spines['bottom'].set_visible(False)
 ax.spines['left'].set_visible(False)
 
-
-
-
-
 def process_audio(audio_filepath):
     if audio_filepath is None:
-        return fig_error
 
-
-    waveform, sample_rate = librosa.load(audio_filepath)
 
-    #
-
-    # Resample audio to 16kHz if necessary
     if sample_rate != 16000:
         resampled_waveform_np = audresample.resample(waveform, sample_rate, 16000)
-
-
     with torch.no_grad():
         logits_dawn = dawn(x).cpu().numpy()[0, :]
-        logits_wavlm = base(x).cpu().numpy()[0, :]
 
-
 
-    # left_bars_data = np.array([0.75, 0.5, 0.9])
-    # right_bars_data = np.array([0.3, 0.8, 0.65])
     left_bars_data = logits_dawn.clip(0, 1)
     right_bars_data = logits_wav2small.clip(0, 1)
 
-
     bar_labels = ['\nArousal', '\nDominance', '\nValence']
     y_pos = np.arange(len(bar_labels))
 
-    # Define
-    # Using Greys for Dominance as requested
     category_colormaps = [plt.cm.Blues, plt.cm.Greys, plt.cm.Oranges]
 
-    # Define color shades for left and right for each category
     left_filled_colors = []
     right_filled_colors = []
     background_colors = []
 
     for i, cmap in enumerate(category_colormaps):
-
-
-        # Pick a slightly lighter shade for the right filled bar
-        right_filled_colors.append(cmap(0.64))  # 0.5
-        # Pick a very light shade for the transparent background bar
         background_colors.append(cmap(0.1))
 
-    #
-    fig, ax = plt.subplots(figsize=(10, 6))
-
-    # Plot the background bars with transparency
     for i in range(len(bar_labels)):
-        # Left background bar (transparent, light shade of category color)
         ax.barh(y_pos[i], -1, color=background_colors[i], alpha=0.3, height=0.6)
-        # Right background bar (transparent, light shade of category color)
         ax.barh(y_pos[i], 1, color=background_colors[i], alpha=0.3, height=0.6)
 
-    # Plot the filled bars for
     for i in range(len(bar_labels)):
-        # Left filled bar (opaque, darker shade of category color)
         ax.barh(y_pos[i], -left_bars_data[i], color=left_filled_colors[i], alpha=1, height=0.6)
-        # Right filled bar (opaque, lighter shade of category color)
         ax.barh(y_pos[i], right_bars_data[i], color=right_filled_colors[i], alpha=1, height=0.6)
 
-    # Add a central axis divider
     ax.axvline(0, color='black', linewidth=0.8, linestyle='--')
 
-    # Set x-axis limits and y-axis ticks
     ax.set_xlim(-1, 1)
     ax.set_yticks(y_pos)
     ax.set_yticklabels(bar_labels, fontsize=12)
 
-
     def abs_tick_formatter(x, pos):
         return f'{int(abs(x) * 100)}%'
     ax.xaxis.set_major_formatter(plt.FuncFormatter(abs_tick_formatter))
 
-    #
     ax.set_title('', fontsize=16, pad=20)
-    ax.set_xlabel('
 
-    # Remove
     ax.spines['top'].set_visible(False)
     ax.spines['right'].set_visible(False)
     ax.spines['left'].set_visible(False)
 
-    # Add annotations to the filled bars
     for i in range(len(bar_labels)):
-        # Left annotation (uses left_filled_colors for text color)
         ax.text(-left_bars_data[i] - 0.05, y_pos[i], f'{int(left_bars_data[i] * 100)}%',
                 va='center', ha='right', color=left_filled_colors[i], fontweight='bold')
-        # Right annotation (uses right_filled_colors for text color)
         ax.text(right_bars_data[i] + 0.05, y_pos[i], f'{int(right_bars_data[i] * 100)}%',
                 va='center', ha='left', color=right_filled_colors[i], fontweight='bold')
 
 
-    return fig
 
 
 iface = gr.Interface(
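Both the deleted plotting body above and its replacement in the added code further down draw the same kind of figure: a diverging bar chart with one model's arousal/dominance/valence values extending left of zero and the other's extending right, plus an absolute-value tick formatter so both halves read 0-100%. A self-contained miniature with made-up values (illustration only; none of the app's models needed):

# Illustration only: minimal mirrored A/D/V bar chart.
import numpy as np
import matplotlib.pyplot as plt

left = np.array([0.6, 0.4, 0.7])     # model A, drawn leftwards
right = np.array([0.5, 0.45, 0.65])  # model B, drawn rightwards
y = np.arange(3)

fig, ax = plt.subplots()
ax.barh(y, -left, height=0.6)
ax.barh(y, right, height=0.6)
ax.axvline(0, color='black', linewidth=0.8, linestyle='--')
ax.set_xlim(-1, 1)
ax.set_yticks(y, ['Arousal', 'Dominance', 'Valence'])
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda v, _: f'{int(abs(v) * 100)}%'))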
@@ -242,25 +525,27 @@ iface = gr.Interface(
         label=''
     ),
     outputs=[
-        gr.Plot(label="
    ],
     title='',
     description='',
-    flagging_mode="never",  #
     examples=[
         "female-46-neutral.wav",
         "female-20-happy.wav",
         "male-60-angry.wav",
         "male-27-sad.wav",
     ],
-    css="footer {visibility: hidden}"
 )
 
 with gr.Blocks() as demo:
-
-    # https://discuss.huggingface.co/t/how-to-get-the-microphone-streaming-input-file-when-using-blocks/37204/3
     with gr.Tab(label="Arousal / Dominance / Valence"):
         iface.render()
     with gr.Tab(label="CCC"):
         gr.Markdown('''<table style="width:500px"><tr><th colspan=5 >CCC MSP Podcast v1.7</th></tr>
 <tr> <td> </td><td>Arousal</td> <td>Dominance</td> <td>Valence</td> <td> Associated Paper </td> </tr>
 
@@ -269,5 +554,6 @@ with gr.Blocks() as demo:
 </table>
 ''')
 
 if __name__ == "__main__":
     demo.launch(share=False)
 ).to(device).eval()
 
 
+# Wav2Small
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+import librosa
+from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel, Wav2Vec2Model
+from torch import nn
+from transformers import PretrainedConfig
+
+
+def _prenorm(x, attention_mask=None):
+    '''mean/var normalization'''
+    if attention_mask is not None:
+        N = attention_mask.sum(1, keepdim=True)  # here the attention mask is unprocessed, just the original input
+        x -= x.sum(1, keepdim=True) / N
+        var = (x * x).sum(1, keepdim=True) / N
+
+    else:
+        x -= x.mean(1, keepdim=True)  # mean is the ONNX operator ReduceMean; saves some ops compared to casting the integer N to float and dividing
+        var = (x * x).mean(1, keepdim=True)
+    return x / torch.sqrt(var + 1e-7)
+
+
+class Spectrogram(nn.Module):
+    def __init__(self,
+                 n_fft=64,    # rows of the DFT matrix (frequency bins)
+                 n_time=64,   # cols of the DFT matrix (window length)
+                 hop_length=32,
+                 freeze_parameters=True):
+
+        super().__init__()
+
+        fft_window = librosa.filters.get_window('hann', n_time, fftbins=True)
+        fft_window = librosa.util.pad_center(fft_window, size=n_time)
+
+        out_channels = n_fft // 2 + 1
+
+        (x, y) = np.meshgrid(np.arange(n_time), np.arange(n_fft))
+        omega = np.exp(-2 * np.pi * 1j / n_time)
+        dft_matrix = np.power(omega, x * y)  # (n_fft, n_time)
+        dft_matrix = dft_matrix * fft_window[None, :]
+        dft_matrix = dft_matrix[0:out_channels, :]
+        dft_matrix = dft_matrix[:, None, :]
+
+        # ---- asymmetric (non-square) DFT
+        self.conv_real = nn.Conv1d(1, out_channels, n_fft, stride=hop_length, padding=0, bias=False)
+        self.conv_imag = nn.Conv1d(1, out_channels, n_fft, stride=hop_length, padding=0, bias=False)
+        self.conv_real.weight.data = torch.tensor(np.real(dft_matrix), dtype=self.conv_real.weight.dtype).to(self.conv_real.weight.device)
+        self.conv_imag.weight.data = torch.tensor(np.imag(dft_matrix), dtype=self.conv_imag.weight.dtype).to(self.conv_imag.weight.device)
+        if freeze_parameters:
+            for param in self.parameters():
+                param.requires_grad = False
+
+    def forward(self, input):
+        x = input[:, None, :]
+        real = self.conv_real(x)
+        imag = self.conv_imag(x)
+        return real ** 2 + imag ** 2  # bs, mel, time-frames
+
+
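The frozen Conv1d pair above is a windowed DFT laid out as convolution kernels: with n_fft = n_time = 64 and hop 32 it reproduces the power spectrogram of a 64-point periodic-Hann STFT. A quick equivalence check (illustration only, assuming the Spectrogram class as added above):

# Illustration only: Spectrogram() vs torch.stft on random audio.
import torch

spec = Spectrogram()
wav = torch.randn(1, 16000)
with torch.no_grad():
    p_conv = spec(wav)                            # (1, 33, 499)

win = torch.hann_window(64, periodic=True)
stft = torch.stft(wav, n_fft=64, hop_length=32, win_length=64,
                  window=win, center=False, return_complex=True)
p_stft = stft.abs() ** 2                          # (1, 33, 499)

print(torch.allclose(p_conv, p_stft, atol=1e-2))  # True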
+class LogmelFilterBank(nn.Module):
+    def __init__(self,
+                 sr=16000,
+                 n_fft=64,
+                 n_mels=26,  # halved to 13 by the maxpool
+                 fmin=0.0,
+                 freeze_parameters=True):
+
+        super().__init__()
+
+        fmax = sr // 2
+
+        W2 = librosa.filters.mel(sr=sr,
+                                 n_fft=n_fft,
+                                 n_mels=n_mels,
+                                 fmin=fmin,
+                                 fmax=fmax).T
+
+        self.register_buffer('melW', torch.Tensor(W2))
+        self.register_buffer('amin', torch.Tensor([1e-10]))
+
+    def forward(self, x):
+        x = torch.matmul(x[:, None, :, :].transpose(2, 3), self.melW)  # changes the mel axis, not the number of frames
+        x = torch.where(x > self.amin, x, self.amin)  # clamp, not in place
+        x = 10 * torch.log10(x)
+        return x
+
+
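Chaining the two frozen modules gives the student's log-mel front-end; for one second of 16 kHz audio the shapes work out as below (illustration only, using the classes added above):

# Illustration only: shape walk-through of the front-end.
import torch

spec = Spectrogram()
mel = LogmelFilterBank()
wav = torch.randn(1, 16000)
with torch.no_grad():
    m = mel(spec(wav))
print(m.shape)  # torch.Size([1, 1, 499, 26]) = (batch, 1, frames, mels)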
+def length_after_conv_layer(_length, k=None, pad=None, stride=None):
+    return torch.floor((_length + 2 * pad - k) / stride + 1)
+
+
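Vgg7.final_length() below applies this formula once for the STFT "layer" (kernel 64, stride 32, pad 0) and once for the max-pool (kernel 3, stride 2, pad 1). For a 1-second clip: floor((16000 - 64)/32 + 1) = 499 frames, then floor((499 + 2 - 3)/2 + 1) = 250. Illustration only:

# Illustration only: the frame counts final_length() computes for 1 s of audio.
import torch

L = torch.tensor([16000.])
L = length_after_conv_layer(L, k=64, pad=0, stride=32)  # tensor([499.])
L = length_after_conv_layer(L, k=3, pad=1, stride=2)    # tensor([250.])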
+class Conv(nn.Module):
+
+    def __init__(self, c_in, c_out, k=3, stride=1, padding=1):
+        super().__init__()
+        self.conv = nn.Conv2d(c_in, c_out, k, stride=stride, padding=padding, bias=False)
+        self.norm = nn.BatchNorm2d(c_out)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return torch.relu_(x)
+
+
+class Vgg7(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.l1 = Conv( 1, 13)
+        self.l2 = Conv(13, 13)
+        self.l3 = Conv(13, 13)
+        self.maxpool_A = nn.MaxPool2d(3, stride=2, padding=1)
+        self.l4 = Conv(13, 13)
+        self.l5 = Conv(13, 13)
+        self.l6 = Conv(13, 13)
+        self.l7 = Conv(13, 13)
+        self.lin = nn.Conv2d(13, 13, 1, padding=0, stride=1)
+        self.sof = nn.Conv2d(13, 13, 1, padding=0, stride=1)  # pools time; mel is reshaped into channels after pooling
+        self.spectrogram_extractor = Spectrogram()
+        self.logmel_extractor = LogmelFilterBank()
+
+    def final_length(self, L):
+        conv_kernel = [64, 3]  # [n_fft, maxpool]
+        conv_stride = [32, 2]  # [hop_length, maxpool stride]; only layers of stride > 1 matter
+        conv_pad = [0, 1]      # [stft pad, maxpool pad]
+        for k, stride, pad in zip(conv_kernel, conv_stride, conv_pad):
+            L = length_after_conv_layer(L, k=k, stride=stride, pad=pad)
+        return L
+
+    def final_attention_mask(self, feature_vector_length, attention_mask=None):
+        non_padded_lengths = attention_mask.sum(1)
+        out_lengths = self.final_length(non_padded_lengths)  # how can non_padded_lengths be exactly 0 here? does that mean the attention mask was never filled?
+        out_lengths = out_lengths.to(torch.long)
+        bs, _ = attention_mask.shape
+        attention_mask = torch.ones((bs, feature_vector_length),
+                                    dtype=attention_mask.dtype,
+                                    device=attention_mask.device)
+        for b, _len in enumerate(out_lengths):
+            attention_mask[b, _len:] = 0
+        return attention_mask
+
+    def forward(self, x, attention_mask=None):
+        x = _prenorm(x, attention_mask=attention_mask)
+        x = self.spectrogram_extractor(x)
+        x = self.logmel_extractor(x)
+        x = self.l1(x)
+        x = self.l2(x)
+        x = self.l3(x)
+        x = self.maxpool_A(x)  # reshape here? then the following convs would have a large kernel
+        x = self.l4(x)
+        x = self.l5(x)
+        x = self.l6(x)
+        x = self.l7(x)
+        if attention_mask is not None:
+            bs, _, t, _ = x.shape
+            a = self.final_attention_mask(feature_vector_length=t,
+                                          attention_mask=attention_mask)[:, None, :, None]
+            x = torch.masked_fill(x, a < 1, 0)
+            # the mask also affects lin
+            x = self.lin(x) * (self.sof(x) - 10000. * torch.logical_not(a)).softmax(2)
+        else:
+            x = self.lin(x) * self.sof(x).softmax(2)
+
+        x = x.sum(2)  # bs, ch, time-frames, HALF_MEL -> bs, ch, HALF_MEL
+        xT = x.transpose(1, 2)
+        x = torch.cat([x,
+                       torch.bmm(x, xT),    # corr (ch x mel) x (mel x ch)
+                       # torch.bmm(x, x),   # corr ch * ch
+                       # torch.bmm(xT, xT)  # corr mel * mel
+                       ], 2)
+        return x.reshape(-1, 338)
+
+
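The hard-coded 338 is 13 x 26: after the softmax-attention pooling over time, x is (batch, 13 channels, 13 mels); concatenating it with its 13 x 13 channel-mel correlation (the bmm) gives (batch, 13, 26), flattened to 338 features, matching config.hidden = 2 * 13 * 13 below. A quick shape check (illustration only):

# Illustration only: verify the 338-dim feature vector.
import torch

vgg = Vgg7().eval()
wav = torch.randn(2, 16000)
with torch.no_grad():
    feats = vgg(wav)
print(feats.shape)  # torch.Size([2, 338])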
+class Wav2SmallConfig(PretrainedConfig):
+    model_type = "wav2vec2"
+
+    def __init__(self,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.half_mel = 13
+        self.n_fft = 64
+        self.n_time = 64
+        self.hidden = 2 * self.half_mel * self.half_mel
+        self.hop = self.n_time // 2
+
+
+class Wav2Small(Wav2Vec2PreTrainedModel):
+
+    def __init__(self,
+                 config):
+        super().__init__(config)
+        self.vgg7 = Vgg7()
+        self.adv = nn.Linear(config.hidden, 3)  # 0=arousal, 1=dominance, 2=valence
+
+    def forward(self, x, attention_mask=None):
+        x = self.vgg7(x, attention_mask=attention_mask)
+        return self.adv(x)
+
+
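End-to-end, the 17K-parameter student maps a raw 16 kHz waveform straight to the three A/D/V values; the checkpoint name is the one loaded further down in this diff. A minimal usage sketch (illustration only; downloads the checkpoint):

# Illustration only: run the student on made-up audio.
import torch

model = Wav2Small.from_pretrained('audeering/wav2small').eval()
wav = torch.randn(1, 16000)
with torch.no_grad():
    adv = model(wav)[0]  # [arousal, dominance, valence]; the app clips these to [0, 1]
print(adv)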
+def _ccc(x, y):
+    '''If len(x) = len(y) = 1 we get 0/0; since a and b can both be negative we add 1e-7 to the
+    denominator while protecting its sign: find the sign of the denominator and add 1e-7 if
+    sgn >= 0, or -1e-7 if sgn < 0.'''
+
+    mean_y = y.mean()
+    mean_x = x.mean()
+    a = x - mean_x
+    b = y - mean_y
+    L = (mean_x - mean_y).abs() * .1 * x.shape[0]
+    # print(L / ((mean_x - mean_y) ** 2 * x.shape[0]))
+    numerator = torch.dot(a, b)  # with scalar a, b this is 0; the L term below keeps the ratio from becoming 0/0 [the official CCC has L only in the denominator]
+    denominator = torch.dot(a, a) + torch.dot(b, b) + L  # if a and b are equal scalars all the dot products are zero
+    denominator = torch.where(denominator.sign() < 0,
+                              denominator - 1e-7,
+                              denominator + 1e-7)
+    ccc = numerator / denominator
+
+    return -ccc  # + F.l1_loss(a, b)
+
+
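For reference, the textbook Concordance Correlation Coefficient that _ccc approximates; the version above replaces the squared mean difference with the |mean difference| term L and returns the negation for use as a loss. A minimal sketch (illustration only):

# Illustration only: standard CCC for two 1-D tensors.
import torch

def ccc(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    cov = torch.dot(x - x.mean(), y - y.mean()) / x.shape[0]
    return 2 * cov / (x.var(unbiased=False) + y.var(unbiased=False)
                      + (x.mean() - y.mean()) ** 2 + 1e-7)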
+wav2small = Wav2Small.from_pretrained('audeering/wav2small').to(device).eval()
+
+
+# Error figure returned when no audio is provided
+fig_error, ax = plt.subplots(figsize=(8, 6))
+error_message = "Error: No .wav or Mic. audio provided."
 ax.text(0.5, 0.5, error_message,
         ha='center',
         va='center',
         color='gray',
         fontweight='bold',
         transform=ax.transAxes)
 ax.set_xticks([])
 ax.set_yticks([])
 ax.set_xticklabels([])
 ax.set_yticklabels([])
 ax.set_frame_on(True)
 ax.spines['top'].set_visible(False)
 ax.spines['right'].set_visible(False)
 ax.spines['bottom'].set_visible(False)
 ax.spines['left'].set_visible(False)
 
 def process_audio(audio_filepath):
     if audio_filepath is None:
+        return fig_error, fig_error
 
+    waveform, sample_rate = librosa.load(audio_filepath, sr=None)
 
+    # Resample audio to 16kHz if the sample rate is different
     if sample_rate != 16000:
         resampled_waveform_np = audresample.resample(waveform, sample_rate, 16000)
+    else:
+        resampled_waveform_np = waveform[None, :]
+
+    x = torch.from_numpy(resampled_waveform_np).to(torch.float)
+
     with torch.no_grad():
         logits_dawn = dawn(x).cpu().numpy()[0, :]
 
+        logits_wavlm = base(x).cpu().numpy()[0, :]
+
+        # 17K params
+        logits_wav2small = wav2small(x).cpu().numpy()[0, :]
+
+
+    # --- Plot 1: Wav2Vec2 (Dawn) vs the Wav2Small student (17K params) ---
+
+    fig, ax = plt.subplots(figsize=(10, 6))
 
     left_bars_data = logits_dawn.clip(0, 1)
     right_bars_data = logits_wav2small.clip(0, 1)
 
     bar_labels = ['\nArousal', '\nDominance', '\nValence']
     y_pos = np.arange(len(bar_labels))
 
+    # Define colormaps for each category to ensure distinct colors
     category_colormaps = [plt.cm.Blues, plt.cm.Greys, plt.cm.Oranges]
 
     left_filled_colors = []
     right_filled_colors = []
     background_colors = []
 
+    # Assign specific shades for filled bars and background bars
     for i, cmap in enumerate(category_colormaps):
+        left_filled_colors.append(cmap(0.74))
+        right_filled_colors.append(cmap(0.64))
         background_colors.append(cmap(0.1))
 
+    # Plot transparent background bars
     for i in range(len(bar_labels)):
         ax.barh(y_pos[i], -1, color=background_colors[i], alpha=0.3, height=0.6)
         ax.barh(y_pos[i], 1, color=background_colors[i], alpha=0.3, height=0.6)
 
+    # Plot the filled bars for actual data
     for i in range(len(bar_labels)):
         ax.barh(y_pos[i], -left_bars_data[i], color=left_filled_colors[i], alpha=1, height=0.6)
         ax.barh(y_pos[i], right_bars_data[i], color=right_filled_colors[i], alpha=1, height=0.6)
 
+    # Add a central vertical axis divider
     ax.axvline(0, color='black', linewidth=0.8, linestyle='--')
 
+    # Set x-axis limits and y-axis ticks/labels
     ax.set_xlim(-1, 1)
     ax.set_yticks(y_pos)
     ax.set_yticklabels(bar_labels, fontsize=12)
 
+    # Custom formatter for x-axis to show absolute percentage values
     def abs_tick_formatter(x, pos):
         return f'{int(abs(x) * 100)}%'
     ax.xaxis.set_major_formatter(plt.FuncFormatter(abs_tick_formatter))
 
+    # Set plot title and x-axis label
     ax.set_title('', fontsize=16, pad=20)
+    ax.set_xlabel('Wav2Vec2 (Dawn) Wav2Small (17K param.)', fontsize=12)
 
+    # Remove top, right, and left spines for a cleaner look
     ax.spines['top'].set_visible(False)
     ax.spines['right'].set_visible(False)
     ax.spines['left'].set_visible(False)
 
+    # Add annotations (percentage values) to the filled bars
     for i in range(len(bar_labels)):
         ax.text(-left_bars_data[i] - 0.05, y_pos[i], f'{int(left_bars_data[i] * 100)}%',
                 va='center', ha='right', color=left_filled_colors[i], fontweight='bold')
         ax.text(right_bars_data[i] + 0.05, y_pos[i], f'{int(right_bars_data[i] * 100)}%',
                 va='center', ha='left', color=right_filled_colors[i], fontweight='bold')
 
 
+    # -- Plot 2: WavLM vs the Wav2Small teacher (fusion of both teachers) --
+
+    fig_2, ax_2 = plt.subplots(figsize=(10, 6))
+
+    left_bars_data = logits_wavlm.clip(0, 1)
+    right_bars_data = (.5 * logits_dawn + .5 * logits_wavlm).clip(0, 1)
+
+    bar_labels = ['\nArousal', '\nDominance', '\nValence']
+    y_pos = np.arange(len(bar_labels))
+
+    # Define colormaps for each category to ensure distinct colors
+    category_colormaps = [plt.cm.Blues, plt.cm.Greys, plt.cm.Oranges]
+
+    left_filled_colors = []
+    right_filled_colors = []
+    background_colors = []
+
+    # Assign specific shades for filled bars and background bars
+    for i, cmap in enumerate(category_colormaps):
+        left_filled_colors.append(cmap(0.74))
+        right_filled_colors.append(cmap(0.64))
+        background_colors.append(cmap(0.1))
+
+    # Plot transparent background bars
+    for i in range(len(bar_labels)):
+        ax_2.barh(y_pos[i], -1, color=background_colors[i], alpha=0.3, height=0.6)
+        ax_2.barh(y_pos[i], 1, color=background_colors[i], alpha=0.3, height=0.6)
+
+    # Plot the filled bars for actual data
+    for i in range(len(bar_labels)):
+        ax_2.barh(y_pos[i], -left_bars_data[i], color=left_filled_colors[i], alpha=1, height=0.6)
+        ax_2.barh(y_pos[i], right_bars_data[i], color=right_filled_colors[i], alpha=1, height=0.6)
+
+    # Add a central vertical axis divider
+    ax_2.axvline(0, color='black', linewidth=0.8, linestyle='--')
+
+    # Set x-axis limits and y-axis ticks/labels
+    ax_2.set_xlim(-1, 1)
+    ax_2.set_yticks(y_pos)
+    ax_2.set_yticklabels(bar_labels, fontsize=12)
+
+    # Custom formatter for x-axis to show absolute percentage values
+    def abs_tick_formatter(x, pos):
+        return f'{int(abs(x) * 100)}%'
+    ax_2.xaxis.set_major_formatter(plt.FuncFormatter(abs_tick_formatter))
+    ax_2.set_title('', fontsize=16, pad=20)
+    ax_2.set_xlabel('WavLM (Baseline) Wav2Small Teacher (0.4B param.)', fontsize=12)
+    ax_2.spines['top'].set_visible(False)
+    ax_2.spines['right'].set_visible(False)
+    ax_2.spines['left'].set_visible(False)
+
+    # Add annotations (percentage values) to the filled bars
+    for i in range(len(bar_labels)):
+        ax_2.text(-left_bars_data[i] - 0.05, y_pos[i], f'{int(left_bars_data[i] * 100)}%',
+                  va='center', ha='right', color=left_filled_colors[i], fontweight='bold')
+        ax_2.text(right_bars_data[i] + 0.05, y_pos[i], f'{int(right_bars_data[i] * 100)}%',
+                  va='center', ha='left', color=right_filled_colors[i], fontweight='bold')
+    return fig, fig_2
 
 
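process_audio() now returns both figures, one per gr.Plot output below; Gradio calls it with the filepath from the microphone/upload widget or one of the bundled examples. Illustration only:

# Illustration only: call the handler directly on a bundled example file.
fig1, fig2 = process_audio("male-27-sad.wav")
fig1.savefig("dawn_vs_student.png")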
 iface = gr.Interface(
         label=''
     ),
     outputs=[
+        gr.Plot(label="Wav2Vec2 vs Wav2Small (17K params) Plot"),  # First plot output
+        gr.Plot(label="WavLM vs Wav2Small Teacher Plot"),  # Second plot output
     ],
     title='',
     description='',
+    flagging_mode="never",  # Disables flagging feature
     examples=[
         "female-46-neutral.wav",
         "female-20-happy.wav",
         "male-60-angry.wav",
         "male-27-sad.wav",
     ],
+    css="footer {visibility: hidden}"  # Hides the Gradio footer
 )
 
+# Gradio Blocks for tabbed interface
 with gr.Blocks() as demo:
+    # First tab for the existing Arousal/Dominance/Valence plots
     with gr.Tab(label="Arousal / Dominance / Valence"):
         iface.render()
+    # Second tab for CCC (Concordance Correlation Coefficient) information
     with gr.Tab(label="CCC"):
         gr.Markdown('''<table style="width:500px"><tr><th colspan=5 >CCC MSP Podcast v1.7</th></tr>
 <tr> <td> </td><td>Arousal</td> <td>Dominance</td> <td>Valence</td> <td> Associated Paper </td> </tr>
 </table>
 ''')
 
+# Launch the Gradio application
 if __name__ == "__main__":
     demo.launch(share=False)