Spaces:
Sleeping
Sleeping
| import os, shutil, argparse | |
| import time | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| import logging | |
| import yaml | |
| import pdb | |
| from lmdb_data_loader_A import LmdbDataset | |
| from models.cache_resnet_conformer import ResnetConformer_sed_doa_nopool | |
| from lr_scheduler.tri_stage_lr_scheduler import TriStageLRScheduler | |
| from utils.cls_tools.cls_compute_seld_results import ComputeSELDResults | |
| from utils.write_csv import write_output_format_file | |
| from utils.sed_doa import SedDoaResult, process_foa_input_sed_doa_labels, SedDoaLoss | |
def set_random_seed(seed):
    """Apply one seed to the NumPy and PyTorch random number generators.

    The same value is handed, in order, to ``np.random.seed``,
    ``torch.manual_seed``, ``torch.cuda.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed: Seed value forwarded unchanged to every seeding function.

    Returns:
        None.
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    return None
def main(args):
    """Train the SED-DOA ResNet-Conformer, evaluating after every epoch.

    The function configures file logging, builds the model, the train/test
    LMDB datasets and loaders, an Adam optimizer and a tri-stage learning-rate
    schedule, then alternates one training epoch with one evaluation pass.
    After each evaluation the predictions are written as CSV files, scored
    with ``ComputeSELDResults`` and a checkpoint is saved. Training ends when
    ``args['train']['nb_steps']`` optimizer steps have been taken or when the
    SELD score (lower is treated as better) has not improved for 40
    consecutive epochs.

    Args:
        args: Nested configuration dict. Keys read by this function:
            ``result``: ``log_output_path``, ``log_interval``,
            ``dcase_output_dir``, ``checkpoint_output_dir``.
            ``model``: ``in_channel``, ``in_dim``, ``out_dim``,
            ``att_context_size``, ``num_conformer_layer``, ``encoder_dim``,
            ``pre-train``, ``pre-train_model``.
            ``data``: ``train_lmdb_dir``, ``test_lmdb_dir``, ``norm_file``,
            ``train_ignore``, ``test_ignore``, ``segment_len``,
            ``batch_size``, ``ref_files_dir``.
            ``train``: ``train_num_workers``, ``test_num_workers``, ``lr``,
            ``nb_steps``.

    Returns:
        None. Results are reported through the log file, the per-epoch CSV
        folders and the saved checkpoints.
    """
    # ---- Logging ---------------------------------------------------------
    log_output_path = args['result']['log_output_path']
    log_output_folder = os.path.dirname(log_output_path)
    # dirname() is '' for a bare file name and os.makedirs('') raises, so only
    # create the folder when the path actually has a directory component.
    if log_output_folder:
        os.makedirs(log_output_folder, exist_ok=True)
    logging.basicConfig(
        filename=log_output_path,
        filemode='w',
        level=logging.INFO,
        format='%(levelname)s: %(asctime)s: %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
    )
    logger = logging.getLogger(__name__)
    logger.info(args)

    seed = 12332
    # Seed before the model is built so that randomly initialised weights are
    # reproducible when no pre-trained checkpoint is loaded.
    set_random_seed(seed)

    data_process_fn = process_foa_input_sed_doa_labels
    result_class = SedDoaResult
    criterion = SedDoaLoss(loss_weight=[0.1, 1])
    model = ResnetConformer_sed_doa_nopool(
        in_channel=args['model']['in_channel'],
        in_dim=args['model']['in_dim'],
        out_dim=args['model']['out_dim'],
        att_context_size=args['model']['att_context_size'],
        num_conformer_layer=args['model']['num_conformer_layer'],
        encoder_dim=args['model']['encoder_dim'],
    )

    # ---- Training set ----------------------------------------------------
    # NOTE(review): the split lists are presumably dataset fold ids - confirm
    # against LmdbDataset.
    train_split = [1, 2, 3]
    train_dataset = LmdbDataset(
        args['data']['train_lmdb_dir'],
        train_split,
        normalized_features_wts_file=args['data']['norm_file'],
        ignore=args['data']['train_ignore'],
        segment_len=args['data']['segment_len'],
        data_process_fn=data_process_fn,
    )
    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=args['data']['batch_size'],
        shuffle=True,
        num_workers=args['train']['train_num_workers'],
        collate_fn=train_dataset.collater,
    )

    # ---- Test set --------------------------------------------------------
    test_split = [4]
    test_dataset = LmdbDataset(
        args['data']['test_lmdb_dir'],
        test_split,
        normalized_features_wts_file=args['data']['norm_file'],
        ignore=args['data']['test_ignore'],
        segment_len=args['data']['segment_len'],
        data_process_fn=data_process_fn,
    )
    test_dataloader = DataLoader(
        dataset=test_dataset,
        batch_size=args['data']['batch_size'],
        shuffle=False,
        num_workers=args['train']['test_num_workers'],
        collate_fn=test_dataset.collater,
    )

    # ---- Model placement and optional warm start -------------------------
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else "cpu")
    model = model.to(device)
    logger.info(model)
    # Re-seed at the point where the script has always seeded, so the random
    # stream seen by training does not depend on how many random numbers the
    # model constructor consumed.
    set_random_seed(seed)
    if args['model']['pre-train']:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        state_dict = torch.load(args['model']['pre-train_model'], map_location=device)
        model.load_state_dict(state_dict)

    # ---- Optimizer and LR schedule ---------------------------------------
    optimizer = optim.Adam(model.parameters(), lr=args['train']['lr'])
    total_steps = args['train']['nb_steps']
    # Tri-stage schedule: 10% warm-up, 60% hold, 30% decay of the step budget.
    warmup_steps = int(total_steps * 0.1)
    hold_steps = int(total_steps * 0.6)
    decay_steps = int(total_steps * 0.3)
    scheduler = TriStageLRScheduler(
        optimizer,
        peak_lr=args['train']['lr'],
        init_lr_scale=0.01,
        final_lr_scale=0.05,
        warmup_steps=warmup_steps,
        hold_steps=hold_steps,
        decay_steps=decay_steps,
    )

    epoch_count = 0
    step_count = 0

    # ---- Training loop state ---------------------------------------------
    stop_training = False
    best_seld_score = float('inf')  # best (lowest) SELD score seen so far
    best_epoch = 0                  # epoch that produced the best score
    best_checkpoint = ''            # checkpoint path of the best epoch
    patience = 40                   # epochs without improvement before stopping
    patience_counter = 0            # consecutive epochs without improvement

    while not stop_training:
        train_loss = []
        test_loss = []
        epoch_count += 1

        # ---- Train for one epoch (or until the step budget is spent) -----
        start_time = time.time()
        model.train()
        for data in train_dataloader:
            inputs = data['input'].to(device)
            target = data['target'].to(device)
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            scheduler.step()  # the schedule advances per optimizer step
            train_loss.append(loss.item())
            step_count += 1
            if step_count % args['result']['log_interval'] == 0:
                lr = optimizer.param_groups[0]['lr']
                logger.info('epoch: {}, step: {}/{}, lr:{:.6f}, train_loss:{:.4f}'.format(epoch_count, step_count, total_steps, lr, loss.item()))
            if step_count >= total_steps:
                # The partial epoch is still evaluated and checkpointed below.
                stop_training = True
                logger.info('Reached maximum number of steps')
                break
        # Release cached GPU blocks from training before the evaluation pass.
        torch.cuda.empty_cache()
        train_time = time.time() - start_time

        # ---- Evaluate on the test split ----------------------------------
        start_time = time.time()
        model.eval()
        test_result = result_class(segment_length=args['data']['segment_len'])
        with torch.no_grad():
            for data in test_dataloader:
                inputs = data['input'].to(device)
                target = data['target'].to(device)
                output = model(inputs)
                loss = criterion(output, target)
                test_loss.append(loss.item())
                test_result.add_items(data['wav_names'], output)
        output_dict = test_result.get_result()
        test_time = time.time() - start_time

        # ---- Write the per-file prediction CSVs --------------------------
        dcase_output_val_dir = os.path.join(args['result']['dcase_output_dir'], 'epoch{}_step{}'.format(epoch_count, step_count))
        os.makedirs(dcase_output_val_dir, exist_ok=True)
        for csv_name, perfile_out_dict in output_dict.items():
            output_file = os.path.join(dcase_output_val_dir, '{}.csv'.format(csv_name))
            write_output_format_file(output_file, perfile_out_dict)

        # ---- Score the CSVs that were just written -----------------------
        score_obj = ComputeSELDResults(ref_files_folder=args['data']['ref_files_dir'])
        val_ER, val_F, val_LE, val_LR, val_seld_scr, _classwise_val_scr = score_obj.get_SELD_Results(dcase_output_val_dir)
        logger.info('epoch: {}, step: {}/{}, train_time:{:.2f}, test_time:{:.2f}, average_train_loss:{:.4f}, average_test_loss:{:.4f}'.format(epoch_count, step_count, total_steps, train_time, test_time, np.mean(train_loss), np.mean(test_loss)))
        logger.info('ER/F/LE/LR/SELD: {}'.format('{:0.4f}/{:0.4f}/{:0.4f}/{:0.4f}/{:0.4f}'.format(val_ER, val_F, val_LE, val_LR, val_seld_scr)))

        # ---- Save a checkpoint for this epoch ----------------------------
        checkpoint_output_dir = args['result']['checkpoint_output_dir']
        os.makedirs(checkpoint_output_dir, exist_ok=True)
        model_path = os.path.join(checkpoint_output_dir, 'checkpoint_epoch{}_step{}.pth'.format(epoch_count, step_count))
        torch.save(model.state_dict(), model_path)
        logger.info('save checkpoint: {}'.format(model_path))

        # ---- Track the best score and the early-stopping counter ---------
        if val_seld_scr < best_seld_score:
            best_seld_score = val_seld_scr
            best_epoch = epoch_count
            best_checkpoint = model_path
            patience_counter = 0  # an improvement resets the counter
            logger.info('New best model found SELD score: {:.4f}'.format(best_seld_score))
        else:
            patience_counter += 1
            logger.info('No improvement for {} epochs. Best SELD score so far: {:.4f}'.format(
                patience_counter, best_seld_score))

        # Stop once the score has stalled for `patience` consecutive epochs.
        if patience_counter >= patience:
            logger.info('Early stopping triggered after {} epochs without improvement'.format(patience))
            stop_training = True

    # ---- Final summary ---------------------------------------------------
    logger.info('='*50)
    logger.info('Training completed!')
    if patience_counter >= patience:
        logger.info('Stopped due to: Early stopping criterion met')
    else:
        logger.info('Stopped due to: Maximum steps reached')
    logger.info('Best performance:')
    logger.info('Epoch: {}'.format(best_epoch))
    logger.info('SELD score: {:.4f}'.format(best_seld_score))
    logger.info('Checkpoint path: {}'.format(best_checkpoint))
    logger.info('Total epochs trained: {}'.format(epoch_count))
    logger.info('='*50)
if __name__ == "__main__":
    # Command line: the single option selects which configuration to train with.
    cli_parser = argparse.ArgumentParser('train')
    cli_parser.add_argument(
        '-c', '--config_name',
        type=str,
        default='foa_dev_multi_accdoa_nopool',
        help='name of config',
    )
    cli_args = cli_parser.parse_args()

    # Each task has its own YAML file, looked up as config/<config_name>.yaml.
    config_file = '{}.yaml'.format(cli_args.config_name)
    config_path = os.path.join('config', config_file)
    with open(config_path, 'r') as stream:
        config = yaml.safe_load(stream)
    main(config)