| """ | |
| Implementation of objective functions used in the task 'End-to-end Remastering System' | |
| """ | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| import os | |
| import sys | |
| currentdir = os.path.dirname(os.path.realpath(__file__)) | |
| sys.path.append(os.path.dirname(currentdir)) | |
| from modules.training_utils import * | |
| from modules.front_back_end import * | |


'''
Normalized Temperature-scaled Cross Entropy (NT-Xent) Loss
    the class NT_Xent below is replicated from the GitHub repository https://github.com/Spijkervet/SimCLR
    the original implementation can be found at https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/nt_xent.py
'''
class NT_Xent(nn.Module):
    def __init__(self, batch_size, temperature, world_size):
        super(NT_Xent, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.world_size = world_size

        self.mask = self.mask_correlated_samples(batch_size, world_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size, world_size):
        # mask out self-similarities and the positive pairs, which sit
        # batch_size * world_size entries off the main diagonal
        # (matching the torch.diag offsets used in forward)
        N = 2 * batch_size * world_size
        mask = torch.ones((N, N), dtype=bool)
        mask = mask.fill_diagonal_(0)
        for i in range(batch_size * world_size):
            mask[i, batch_size * world_size + i] = 0
            mask[batch_size * world_size + i, i] = 0
        return mask

    def forward(self, z_i, z_j):
        """
        We do not sample negative examples explicitly.
        Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N - 1) augmented examples within a minibatch as negative examples.
        """
        N = 2 * self.batch_size * self.world_size

        z = torch.cat((z_i, z_j), dim=0)
        # combine embeddings from all GPUs
        if self.world_size > 1:
            z = torch.cat(GatherLayer.apply(z), dim=0)

        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature
        sim_i_j = torch.diag(sim, self.batch_size * self.world_size)
        sim_j_i = torch.diag(sim, -self.batch_size * self.world_size)

        # We have 2N samples, but with distributed training every GPU gets N examples too, resulting in: 2xNxN
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        negative_samples = sim[self.mask].reshape(N, -1)

        labels = torch.zeros(N).to(positive_samples.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N
        return loss
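

# A minimal sketch of how NT_Xent is typically driven (single process, world_size=1).
# The embedding dimension and batch size below are illustrative placeholders, not
# values taken from this repository.
def _demo_nt_xent():
    batch_size = 4
    criterion = NT_Xent(batch_size, temperature=0.5, world_size=1)
    z_i = torch.randn(batch_size, 128)   # projections of view 1
    z_j = torch.randn(batch_size, 128)   # projections of view 2 (positives of z_i)
    return criterion(z_i, z_j)           # scalar contrastive loss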


# Root Mean Squared Loss
#   penalizes the volume factor with non-linearity
class RMSLoss(nn.Module):
    def __init__(self, reduce, loss_type="l2"):
        super(RMSLoss, self).__init__()
        self.weight_factor = 100.
        if loss_type=="l2":
            # element-wise loss: the per-sample weights in forward are applied before reduction
            self.loss = nn.MSELoss(reduction="none")

    def forward(self, est_targets, targets):
        # flatten (batch, channel, samples) into (batch*channel, samples)
        est_targets = est_targets.reshape(est_targets.shape[0]*est_targets.shape[1], est_targets.shape[2])
        targets = targets.reshape(targets.shape[0]*targets.shape[1], targets.shape[2])

        # per-channel RMS levels of the estimate and the target
        normalized_est = torch.sqrt(torch.mean(est_targets**2, dim=-1))
        normalized_tgt = torch.sqrt(torch.mean(targets**2, dim=-1))

        weight = torch.clamp(torch.abs(normalized_tgt-normalized_est), min=1/self.weight_factor) * self.weight_factor

        return torch.mean(weight**1.5 * self.loss(normalized_est, normalized_tgt))
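

# A quick sketch of RMSLoss on random stereo batches. The (batch, channel, samples)
# shape is inferred from the reshape in forward above; the tensors here are placeholders.
def _demo_rms_loss():
    gain_loss = RMSLoss(reduce=True)
    est = torch.randn(2, 2, 44100)   # (batch, channels, samples)
    tgt = torch.randn(2, 2, 44100)
    return gain_loss(est, tgt)       # scalar penalty on per-channel RMS mismatch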


# Multi-Scale Spectral Loss proposed in the paper "DDSP: Differentiable Digital Signal Processing" (https://arxiv.org/abs/2001.04643)
#   we extend this loss by applying it to mid/side channels
class MultiScale_Spectral_Loss_MidSide_DDSP(nn.Module):
    def __init__(self, mode='midside', \
                    reduce=True, \
                    n_filters=None, \
                    windows_size=None, \
                    hops_size=None, \
                    window="hann", \
                    eps=1e-7, \
                    device=torch.device("cpu")):
        super(MultiScale_Spectral_Loss_MidSide_DDSP, self).__init__()
        self.mode = mode
        self.eps = eps
        self.mid_weight = 0.5       # value in the range of 0.0 ~ 1.0
        self.logmag_weight = 0.1

        if n_filters is None:
            n_filters = [4096, 2048, 1024, 512]
        if windows_size is None:
            windows_size = [4096, 2048, 1024, 512]
        if hops_size is None:
            hops_size = [1024, 512, 256, 128]

        self.multiscales = []
        for i in range(len(windows_size)):
            cur_scale = {'window_size' : float(windows_size[i])}
            if self.mode=='midside':
                cur_scale['front_end'] = FrontEnd(channel='mono', \
                                                    n_fft=n_filters[i], \
                                                    hop_length=hops_size[i], \
                                                    win_length=windows_size[i], \
                                                    window=window, \
                                                    device=device)
            elif self.mode=='ori':
                cur_scale['front_end'] = FrontEnd(channel='stereo', \
                                                    n_fft=n_filters[i], \
                                                    hop_length=hops_size[i], \
                                                    win_length=windows_size[i], \
                                                    window=window, \
                                                    device=device)
            self.multiscales.append(cur_scale)

        self.objective_l1 = nn.L1Loss(reduction="mean" if reduce else "none")
        self.objective_l2 = nn.MSELoss(reduction="mean" if reduce else "none")

    def forward(self, est_targets, targets):
        if self.mode=='midside':
            return self.forward_midside(est_targets, targets)
        elif self.mode=='ori':
            return self.forward_ori(est_targets, targets)

    def forward_ori(self, est_targets, targets):
        total_mag_loss = 0.0
        total_logmag_loss = 0.0
        for cur_scale in self.multiscales:
            est_mag = cur_scale['front_end'](est_targets, mode=["mag"])
            tgt_mag = cur_scale['front_end'](targets, mode=["mag"])

            mag_loss = self.magnitude_loss(est_mag, tgt_mag)
            logmag_loss = self.log_magnitude_loss(est_mag, tgt_mag)
            total_mag_loss += mag_loss
            total_logmag_loss += logmag_loss

        return (1-self.logmag_weight)*total_mag_loss + \
                (self.logmag_weight)*total_logmag_loss

    def forward_midside(self, est_targets, targets):
        est_mid, est_side = self.to_mid_side(est_targets)
        tgt_mid, tgt_side = self.to_mid_side(targets)

        total_mag_loss = 0.0
        total_logmag_loss = 0.0
        for cur_scale in self.multiscales:
            est_mid_mag = cur_scale['front_end'](est_mid, mode=["mag"])
            est_side_mag = cur_scale['front_end'](est_side, mode=["mag"])
            tgt_mid_mag = cur_scale['front_end'](tgt_mid, mode=["mag"])
            tgt_side_mag = cur_scale['front_end'](tgt_side, mode=["mag"])

            mag_loss = self.mid_weight*self.magnitude_loss(est_mid_mag, tgt_mid_mag) + \
                        (1-self.mid_weight)*self.magnitude_loss(est_side_mag, tgt_side_mag)
            logmag_loss = self.mid_weight*self.log_magnitude_loss(est_mid_mag, tgt_mid_mag) + \
                        (1-self.mid_weight)*self.log_magnitude_loss(est_side_mag, tgt_side_mag)
            total_mag_loss += mag_loss
            total_logmag_loss += logmag_loss

        return (1-self.logmag_weight)*total_mag_loss + \
                (self.logmag_weight)*total_logmag_loss

    def to_mid_side(self, stereo_in):
        # mid/side encoding: mid = L + R, side = L - R
        mid = stereo_in[:,0] + stereo_in[:,1]
        side = stereo_in[:,0] - stereo_in[:,1]
        return mid, side

    def magnitude_loss(self, est_mag_spec, tgt_mag_spec):
        return torch.norm(self.objective_l1(est_mag_spec, tgt_mag_spec))

    def log_magnitude_loss(self, est_mag_spec, tgt_mag_spec):
        est_log_mag_spec = torch.log10(est_mag_spec+self.eps)
        tgt_log_mag_spec = torch.log10(tgt_mag_spec+self.eps)
        return self.objective_l2(est_log_mag_spec, tgt_log_mag_spec)
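

# Sketch of calling the multi-scale spectral loss on stereo audio. This assumes
# FrontEnd (imported from modules.front_back_end) accepts (batch, 2, samples)
# tensors and returns magnitude spectrograms for mode=["mag"], as used above;
# the input shapes are placeholders.
def _demo_multiscale_spectral():
    criterion = MultiScale_Spectral_Loss_MidSide_DDSP(mode='midside')
    est = torch.randn(2, 2, 44100)
    tgt = torch.randn(2, 2, 44100)
    return criterion(est, tgt)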


# hinge loss for the discriminator
def dis_hinge(dis_fake, dis_real):
    return torch.mean(torch.relu(1. - dis_real)) + torch.mean(torch.relu(1. + dis_fake))


# hinge loss for the generator
def gen_hinge(dis_fake, dis_real=None):
    return -torch.mean(dis_fake)
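

# A minimal sketch of how the hinge objectives above are typically used in a
# GAN update. The stand-in tensors represent raw (pre-activation) discriminator
# scores; real training code would compute them from the discriminator.
def _demo_hinge():
    dis_real = torch.randn(8, 1)                     # D(x) on real examples
    dis_fake = torch.randn(8, 1)                     # D(G(z)) on generated examples
    d_loss = dis_hinge(dis_fake.detach(), dis_real)  # discriminator step
    g_loss = gen_hinge(dis_fake)                     # generator step
    return d_loss, g_loss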


# DirectCLR's implementation of the InfoNCE loss
#   the argument names (nn, p) follow the original DirectCLR code;
#   note that 'nn' shadows the torch.nn alias inside this function
def infoNCE(nn, p, temperature=0.1):
    nn = torch.nn.functional.normalize(nn, dim=1)
    p = torch.nn.functional.normalize(p, dim=1)
    nn = gather_from_all(nn)
    p = gather_from_all(p)
    logits = nn @ p.T
    logits /= temperature
    n = p.shape[0]
    labels = torch.arange(0, n, dtype=torch.long, device=logits.device)
    loss = torch.nn.functional.cross_entropy(logits, labels)
    return loss
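

# Sketch of the infoNCE call on random embeddings. It assumes gather_from_all
# (imported from modules.training_utils) degenerates to an identity/concat
# outside distributed training; inputs are normalized inside the function, so
# raw embeddings can be passed directly.
def _demo_infonce():
    z1 = torch.randn(16, 256)   # anchor embeddings
    z2 = torch.randn(16, 256)   # positive embeddings (same order as z1)
    return infoNCE(z1, z2, temperature=0.1)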


# Class of available loss functions
class Loss:
    def __init__(self, args, reduce=True):
        device = torch.device("cpu")
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{args.gpu}")

        self.l1 = nn.L1Loss(reduction="mean" if reduce else "none")
        self.mse = nn.MSELoss(reduction="mean" if reduce else "none")
        self.ce = nn.CrossEntropyLoss()
        self.triplet = nn.TripletMarginLoss(margin=1., p=2)
        self.ntxent = NT_Xent(args.batch_size_total*(args.num_strong_negatives+1), args.temperature, world_size=1)
        self.multi_scale_spectral_midside = MultiScale_Spectral_Loss_MidSide_DDSP(mode='midside', eps=args.eps, device=device)
        self.multi_scale_spectral_ori = MultiScale_Spectral_Loss_MidSide_DDSP(mode='ori', eps=args.eps, device=device)
        self.gain = RMSLoss(reduce=reduce)
        self.infonce = infoNCE
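

# Sketch of wiring the Loss container into a training step. The args fields
# (gpu, batch_size_total, num_strong_negatives, temperature, eps) mirror the
# attributes accessed in __init__ above; the values are placeholders, not the
# repository's training configuration.
def _demo_loss_container():
    from types import SimpleNamespace
    args = SimpleNamespace(gpu=0, batch_size_total=8, num_strong_negatives=1,
                            temperature=0.5, eps=1e-7)
    criteria = Loss(args)
    est = torch.randn(2, 2, 44100)
    tgt = torch.randn(2, 2, 44100)
    return criteria.multi_scale_spectral_midside(est, tgt) + criteria.gain(est, tgt)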