# config.py
from transformers import PretrainedConfig


class ModelConfig(PretrainedConfig):
    """Hyperparameter container registered under the ``SongFormer`` model type.

    Each constructor argument is stored, unmodified, as an instance attribute
    of the same name. Keyword arguments that are not listed in the signature
    are forwarded to ``PretrainedConfig.__init__`` before any of the
    attributes below are assigned.

    NOTE(review): the section headings used inside ``__init__`` are derived
    from the parameter names only; the precise meaning of each value is
    defined by the code that consumes this config, which is not part of this
    file -- confirm there before relying on them.
    """

    model_type = "SongFormer"

    def __init__(
        self,
        input_dim=2048,
        input_dim_raw=4096,
        transformer_encoder_input_dim=1024,
        transformer_input_dim=512,
        num_transformer_layers=4,
        transformer_nhead=8,
        transformer_dropout=0.1,
        num_classes=128,
        num_dataset_classes=64,
        down_sample_conv_kernel_size=3,
        down_sample_conv_stride=3,
        down_sample_conv_dropout=0.1,
        down_sample_conv_padding=0,
        boundary_tv_loss_beta=0.6,
        boundary_tv_loss_lambda=0.4,
        boundary_tv_loss_boundary_threshold=0.01,
        boundary_tv_loss_reduction_weight=0.1,
        boundary_tvloss_weight=0.05,
        label_focal_loss_alpha=0.25,
        label_focal_loss_gamma=2.0,
        label_focal_loss_weight=0.2,
        loss_weight_section=0.2,
        loss_weight_function=0.8,
        learn_label=True,
        learn_segment=True,
        local_maxima_filter_size=3,
        frame_rates=8.333,
        **kwargs,
    ):
        """Store every hyperparameter on the instance.

        All named arguments are optional and are copied verbatim onto
        ``self``; ``**kwargs`` is handed to the ``PretrainedConfig`` base
        class.
        """
        # Base-class initialisation runs first so that the explicit
        # assignments below are applied after whatever it sets up.
        super().__init__(**kwargs)

        # --- input feature sizes -------------------------------------------
        self.input_dim = input_dim
        self.input_dim_raw = input_dim_raw

        # --- transformer settings ------------------------------------------
        self.transformer_encoder_input_dim = transformer_encoder_input_dim
        self.transformer_input_dim = transformer_input_dim
        self.num_transformer_layers = num_transformer_layers
        self.transformer_nhead = transformer_nhead
        self.transformer_dropout = transformer_dropout

        # --- class counts ---------------------------------------------------
        self.num_classes = num_classes
        self.num_dataset_classes = num_dataset_classes

        # --- down-sampling convolution settings ----------------------------
        self.down_sample_conv_kernel_size = down_sample_conv_kernel_size
        self.down_sample_conv_stride = down_sample_conv_stride
        self.down_sample_conv_dropout = down_sample_conv_dropout
        self.down_sample_conv_padding = down_sample_conv_padding

        # --- boundary TV-loss settings -------------------------------------
        self.boundary_tv_loss_beta = boundary_tv_loss_beta
        self.boundary_tv_loss_lambda = boundary_tv_loss_lambda
        self.boundary_tv_loss_boundary_threshold = (
            boundary_tv_loss_boundary_threshold
        )
        self.boundary_tv_loss_reduction_weight = (
            boundary_tv_loss_reduction_weight
        )
        self.boundary_tvloss_weight = boundary_tvloss_weight

        # --- label focal-loss settings -------------------------------------
        self.label_focal_loss_alpha = label_focal_loss_alpha
        self.label_focal_loss_gamma = label_focal_loss_gamma
        self.label_focal_loss_weight = label_focal_loss_weight

        # --- loss weights and learning switches ----------------------------
        self.loss_weight_section = loss_weight_section
        self.loss_weight_function = loss_weight_function
        self.learn_label = learn_label
        self.learn_segment = learn_segment

        # --- remaining settings --------------------------------------------
        self.local_maxima_filter_size = local_maxima_filter_size
        self.frame_rates = frame_rates