SELD_Streaming_RC / train_A_Stream_RC.py
dongyuxuan
Open-source release: add training, finetuning, and preprocessing pipelines
8d23519
import os, shutil, argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import logging
import yaml
import pdb
from lmdb_data_loader_A import LmdbDataset
from models.cache_resnet_conformer import ResnetConformer_sed_doa_nopool
from lr_scheduler.tri_stage_lr_scheduler import TriStageLRScheduler
from utils.cls_tools.cls_compute_seld_results import ComputeSELDResults
from utils.write_csv import write_output_format_file
from utils.sed_doa import SedDoaResult, process_foa_input_sed_doa_labels, SedDoaLoss
def set_random_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
return None
def main(args):
# 设置log
log_output_folder = os.path.dirname(args['result']['log_output_path'])
os.makedirs(log_output_folder, exist_ok=True)
logging.basicConfig(filename=args['result']['log_output_path'], filemode='w', level=logging.INFO, format='%(levelname)s: %(asctime)s: %(message)s', datefmt='%m/%d/%Y %H:%M:%S')
logger = logging.getLogger(__name__)
logger.info(args)
data_process_fn = process_foa_input_sed_doa_labels
result_class = SedDoaResult
criterion = SedDoaLoss(loss_weight=[0.1,1])
model = ResnetConformer_sed_doa_nopool(in_channel=args['model']['in_channel'],
in_dim=args['model']['in_dim'],
out_dim=args['model']['out_dim'],
att_context_size = args['model']['att_context_size'],
num_conformer_layer = args['model']['num_conformer_layer'],
encoder_dim = args['model']['encoder_dim'])
# 训练集初始化
train_split = [1,2,3]
train_dataset = LmdbDataset(args['data']['train_lmdb_dir'], train_split, normalized_features_wts_file=args['data']['norm_file'],
ignore=args['data']['train_ignore'], segment_len=args['data']['segment_len'], data_process_fn=data_process_fn)
train_dataloader = DataLoader(
dataset=train_dataset, batch_size=args['data']['batch_size'], shuffle=True,
num_workers=args['train']['train_num_workers'], collate_fn=train_dataset.collater
)
# 测试集初始化
test_split = [4]
test_dataset = LmdbDataset(args['data']['test_lmdb_dir'], test_split, normalized_features_wts_file=args['data']['norm_file'],
ignore=args['data']['test_ignore'], segment_len=args['data']['segment_len'], data_process_fn=data_process_fn)
test_dataloader = DataLoader(
dataset=test_dataset, batch_size=args['data']['batch_size'], shuffle=False,
num_workers=args['train']['test_num_workers'], collate_fn=test_dataset.collater
)
# 模型初始化
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else "cpu")
model = model.to(device)
logger.info(model)
set_random_seed(12332)
if args['model']['pre-train']:
model.load_state_dict(torch.load(args['model']['pre-train_model']))
# logger.info(model)
# 优化器初始化
optimizer = optim.Adam(model.parameters(), lr=args['train']['lr'])
total_steps = args['train']['nb_steps']
warmup_steps = int(total_steps*0.1)
hold_steps = int(total_steps*0.6)
decay_steps = int(total_steps*0.3)
scheduler = TriStageLRScheduler(optimizer, peak_lr=args['train']['lr'], init_lr_scale=0.01, final_lr_scale=0.05,
warmup_steps=warmup_steps, hold_steps=hold_steps, decay_steps=decay_steps)
epoch_count = 0
step_count = 0
# 开始训练
stop_training = False
best_seld_score = float('inf') # 初始化最佳SELD分数
best_epoch = 0 # 初始化最佳epoch
best_checkpoint = '' # 初始化最佳checkpoint路径
patience = 40 # 早停耐心值
patience_counter = 0 # 早停计数器
while not stop_training:
train_loss = []
test_loss = []
epoch_count += 1
# 训练
start_time = time.time()
model.train()
for data in train_dataloader:
input = data['input'].to(device)
target = data['target'].to(device)
optimizer.zero_grad()
output = model(input)
loss = criterion(output, target)
loss.backward()
optimizer.step()
scheduler.step()
train_loss.append(loss.item())
step_count += 1
if step_count % args['result']['log_interval'] == 0:
lr = optimizer.param_groups[0]['lr']
logger.info('epoch: {}, step: {}/{}, lr:{:.6f}, train_loss:{:.4f}'.format(epoch_count, step_count, total_steps, lr, loss.item()))
if step_count >= total_steps:
stop_training = True
logger.info('Reached maximum number of steps')
break
torch.cuda.empty_cache()
train_time = time.time() - start_time
# 测试
start_time = time.time()
model.eval()
test_result = result_class(segment_length=args['data']['segment_len'])
for data in test_dataloader:
input = data['input'].to(device)
target = data['target'].to(device)
with torch.no_grad():
output = model(input)
loss = criterion(output, target)
test_loss.append(loss.item())
test_result.add_items(data['wav_names'], output)
output_dict = test_result.get_result()
test_time = time.time() - start_time
# 保存测试集CSV文件
dcase_output_val_dir = os.path.join(args['result']['dcase_output_dir'], 'epoch{}_step{}'.format(epoch_count, step_count))
os.makedirs(dcase_output_val_dir, exist_ok=True)
for csv_name, perfile_out_dict in output_dict.items():
output_file = os.path.join(dcase_output_val_dir, '{}.csv'.format(csv_name))
write_output_format_file(output_file, perfile_out_dict)
#根据保存的CSV文件进行结果评估
score_obj = ComputeSELDResults(ref_files_folder=args['data']['ref_files_dir'])
val_ER, val_F, val_LE, val_LR, val_seld_scr, classwise_val_scr = score_obj.get_SELD_Results(dcase_output_val_dir)
logger.info('epoch: {}, step: {}/{}, train_time:{:.2f}, test_time:{:.2f}, average_train_loss:{:.4f}, average_test_loss:{:.4f}'.format(epoch_count, step_count, total_steps, train_time, test_time, np.mean(train_loss), np.mean(test_loss)))
logger.info('ER/F/LE/LR/SELD: {}'.format('{:0.4f}/{:0.4f}/{:0.4f}/{:0.4f}/{:0.4f}'.format(val_ER, val_F, val_LE, val_LR, val_seld_scr)))
# 保存模型
checkpoint_output_dir = args['result']['checkpoint_output_dir']
os.makedirs(checkpoint_output_dir, exist_ok=True)
model_path = os.path.join(checkpoint_output_dir, 'checkpoint_epoch{}_step{}.pth'.format(epoch_count, step_count))
torch.save(model.state_dict(), model_path)
logger.info('save checkpoint: {}'.format(model_path))
# 更新最佳性能记录
if val_seld_scr < best_seld_score:
best_seld_score = val_seld_scr
best_epoch = epoch_count
best_checkpoint = model_path
patience_counter = 0 # 重置早停计数器
logger.info('New best model found SELD score: {:.4f}'.format(best_seld_score))
else:
patience_counter += 1
logger.info('No improvement for {} epochs. Best SELD score so far: {:.4f}'.format(
patience_counter, best_seld_score))
# 检查是否应该早停
if patience_counter >= patience:
logger.info('Early stopping triggered after {} epochs without improvement'.format(patience))
stop_training = True
# 训练结束后记录最佳性能
logger.info('='*50)
logger.info('Training completed!')
if patience_counter >= patience:
logger.info('Stopped due to: Early stopping criterion met')
else:
logger.info('Stopped due to: Maximum steps reached')
logger.info('Best performance:')
logger.info('Epoch: {}'.format(best_epoch))
logger.info('SELD score: {:.4f}'.format(best_seld_score))
logger.info('Checkpoint path: {}'.format(best_checkpoint))
logger.info('Total epochs trained: {}'.format(epoch_count))
logger.info('='*50)
if __name__ == "__main__":
parser = argparse.ArgumentParser('train')
parser.add_argument('-c', '--config_name', type=str, default='foa_dev_multi_accdoa_nopool', help='name of config')
input_args = parser.parse_args()
# 不同任务使用不同配置文件
with open(os.path.join('config', '{}.yaml'.format(input_args.config_name)), 'r') as f:
args = yaml.safe_load(f)
main(args)